diff --git a/models/mlp.py b/models/mlp.py
index beafa3a..4701026 100644
--- a/models/mlp.py
+++ b/models/mlp.py
@@ -27,11 +27,11 @@ class MLP_G(nn.Module):
 
     def forward(self, input):
         input = input.view(input.size(0), input.size(1))
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        out = nn.parallel.data_parallel(self.main, input, gpu_ids)
-        return out.view(out.size(0), self.nc, self.isize, self.isize)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output.view(output.size(0), self.nc, self.isize, self.isize)
 
 
 class MLP_D(nn.Module):
@@ -57,9 +57,9 @@ class MLP_D(nn.Module):
     def forward(self, input):
         input = input.view(input.size(0),
                            input.size(1) * input.size(2) * input.size(3))
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
         output = output.mean(0)
         return output.view(1)