Compatible with latest PyTorch so that nn.parallel.data_parallel won't complain

Xinwei He 2017-04-06 13:49:14 +08:00 committed by GitHub
parent 27e72795ad
commit 5ef486cec5
1 changed file with 16 additions and 13 deletions


@@ -44,10 +44,11 @@ class DCGAN_D(nn.Module):
     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+
         output = output.mean(0)
         return output.view(1)
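Why the change was needed: the old code always funneled the input through nn.parallel.data_parallel and passed gpu_ids = None on the single-GPU/CPU path. In newer PyTorch releases, device_ids=None makes data_parallel try to use every visible GPU, so the None shortcut is no longer a harmless no-op. The fix calls self.main directly when there is nothing to parallelize. Below is a minimal, self-contained sketch of the fixed pattern; TinyD and its layer sizes are hypothetical stand-ins for the repo's discriminators, not code from this commit:

    import torch
    import torch.nn as nn

    class TinyD(nn.Module):
        """Hypothetical stand-in illustrating the committed forward() pattern."""
        def __init__(self, ngpu):
            super(TinyD, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1), nn.LeakyReLU(0.2))

        def forward(self, input):
            # Route through data_parallel only when the input lives on the GPU
            # and more than one device is requested; otherwise run the submodule
            # directly instead of passing device_ids=None.
            if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
                output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
            else:
                output = self.main(input)
            return output

    netD = TinyD(ngpu=1)
    print(netD(torch.randn(2, 3, 8, 8)).shape)  # runs on CPU without complaint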
@@ -98,11 +99,11 @@ class DCGAN_G(nn.Module):
         self.main = main

     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        return nn.parallel.data_parallel(self.main, input, gpu_ids)
-
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output
 ###############################################################################
 class DCGAN_D_nobn(nn.Module):
     def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
@@ -143,10 +144,11 @@ class DCGAN_D_nobn(nn.Module):
     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+
         output = output.mean(0)
         return output.view(1)
@@ -190,7 +192,8 @@ class DCGAN_G_nobn(nn.Module):
         self.main = main

     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        return nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output
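The same branch-or-fall-through fix is applied to all four forward() methods. As an aside (not part of this commit), a similar effect can be had by wrapping the network once at construction time instead of calling the functional form inside forward(). A sketch using torch.nn.DataParallel, with made-up module and sizes:

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1), nn.LeakyReLU(0.2))
    x = torch.randn(2, 3, 8, 8)

    if torch.cuda.device_count() > 1:
        # Replicate the module and scatter the batch across all visible GPUs.
        net, x = net.cuda(), x.cuda()
        net = nn.DataParallel(net)

    out = net(x)  # same call on CPU, one GPU, or many GPUs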