18 lines
623 B
Python
18 lines
623 B
Python
|
import torch
|
||
|
|
||
|
def prepare_device(n_gpu_use: int):
    """Select the compute device and the GPU indices to use for DataParallel.

    Args:
        n_gpu_use: Number of GPUs requested. If no CUDA device is visible,
            a warning is printed and the CPU is used. If more GPUs are
            requested than are visible, a warning is printed and the count
            is clamped to the number available. A value of 0 or less
            selects the CPU.

    Returns:
        A ``(device, list_ids)`` tuple:
        - ``device`` is ``torch.device('cuda:0')`` when at least one GPU
          will be used, otherwise ``torch.device('cpu')``.
        - ``list_ids`` is ``[0, 1, ..., n - 1]`` for the ``n`` GPUs that
          will be used (an empty list when running on CPU).
    """
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, so without it the message reads "machine,training".
        print("Warning: There's no GPU available on this machine, "
              "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        print(f"Warning: The number of GPU's configured to use is {n_gpu_use}, "
              f"but only {n_gpu} are available on this machine.")
        n_gpu_use = n_gpu
    # DataParallel uses the first id as the primary device, hence cuda:0.
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    list_ids = list(range(n_gpu_use))
    return device, list_ids
|