diff --git a/models/dcgan.py b/models/dcgan.py index 9a9ae4c..1dd8dbf 100644 --- a/models/dcgan.py +++ b/models/dcgan.py @@ -44,10 +44,11 @@ class DCGAN_D(nn.Module): def forward(self, input): - gpu_ids = None if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: - gpu_ids = range(self.ngpu) - output = nn.parallel.data_parallel(self.main, input, gpu_ids) + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = output.mean(0) return output.view(1) @@ -98,11 +99,11 @@ class DCGAN_G(nn.Module): self.main = main def forward(self, input): - gpu_ids = None if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: - gpu_ids = range(self.ngpu) - return nn.parallel.data_parallel(self.main, input, gpu_ids) - + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + return output ############################################################################### class DCGAN_D_nobn(nn.Module): def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): @@ -143,10 +144,11 @@ class DCGAN_D_nobn(nn.Module): def forward(self, input): - gpu_ids = None if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: - gpu_ids = range(self.ngpu) - output = nn.parallel.data_parallel(self.main, input, gpu_ids) + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + output = output.mean(0) return output.view(1) @@ -190,7 +192,8 @@ class DCGAN_G_nobn(nn.Module): self.main = main def forward(self, input): - gpu_ids = None if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: - gpu_ids = range(self.ngpu) - return nn.parallel.data_parallel(self.main, input, gpu_ids) + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) + else: + output = self.main(input) + return output