compatible with latest pytorch on nn.parallel.data_parallel

Xinwei He 2017-04-06 13:51:31 +08:00 committed by GitHub
parent 27e72795ad
commit 44caab4855
1 changed file with 7 additions and 7 deletions


@@ -27,11 +27,11 @@ class MLP_G(nn.Module):
 
     def forward(self, input):
         input = input.view(input.size(0), input.size(1))
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        out = nn.parallel.data_parallel(self.main, input, gpu_ids)
-        return out.view(out.size(0), self.nc, self.isize, self.isize)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output.view(output.size(0), self.nc, self.isize, self.isize)
 
 
 class MLP_D(nn.Module):
@@ -57,9 +57,9 @@ class MLP_D(nn.Module):
     def forward(self, input):
         input = input.view(input.size(0),
                            input.size(1) * input.size(2) * input.size(3))
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
         output = output.mean(0)
         return output.view(1)
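
For reference, here is the pattern the commit settles on, pulled out of the diff into a self-contained sketch. The TinyMLP module, its layer sizes, and the usage lines are illustrative only and not from this repository; what matters is the dispatch in forward(): route the batch through nn.parallel.data_parallel only when the input is a CUDA tensor and more than one GPU is configured, and fall back to a plain call otherwise. The old code passed gpu_ids=None to data_parallel on the CPU/single-GPU path, which early PyTorch treated as "just run the module" but newer releases appear to treat as "use all visible GPUs", hence the explicit else branch.

    import torch
    import torch.nn as nn

    class TinyMLP(nn.Module):
        # Illustrative module; only the forward() dispatch mirrors the commit.
        def __init__(self, nin, nout, ngpu):
            super(TinyMLP, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                nn.Linear(nin, 64),
                nn.ReLU(True),
                nn.Linear(64, nout),
            )

        def forward(self, input):
            # Replicate self.main across GPUs and split the batch along dim 0
            # only when the input lives on the GPU and ngpu > 1; otherwise run
            # a plain forward pass, which covers CPU and single-GPU execution.
            if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
                output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
            else:
                output = self.main(input)
            return output

    net = TinyMLP(nin=128, nout=10, ngpu=torch.cuda.device_count())
    x = torch.randn(32, 128)   # CPU tensor, so the else branch runs
    y = net(x)

Besides the compatibility fix, keeping the plain-call branch avoids data_parallel's scatter/replicate/gather overhead when there is nothing to parallelize.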