1. 指明要是用的GPU
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,3"
上述代码说明,给本程序分配了编号“0,1,3”的三块GPU可供使用
PS: 本机必须是有上述声明的显卡,否则在使用时会出错
RuntimeError: cuda runtime error (38) : no CUDA-capable device is detected at ..\aten\src\THC\THCGeneral.cpp:50
2. GPU并行
PyTorch提供相应的函数,可实现简单高效的并行GPU计算。
class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
通过device_ids参数可以指定在哪些GPU上进行优化,返回一个新的module
import torch.nn as nn
model=Net() # 定义模型
model=nn. DataParallel(model,device_ids=[0,1])
model.cuda()
DataParallel并行的方式,是将输入一个batch的数据均分成多份,分别送到对应的GPU进行计算,各个GPU得到的梯度累加。