compatible with latest pytorch on nn.parallel.data_parallel
This commit is contained in:
parent
27e72795ad
commit
44caab4855
|
@ -27,11 +27,11 @@ class MLP_G(nn.Module):
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
input = input.view(input.size(0), input.size(1))
|
input = input.view(input.size(0), input.size(1))
|
||||||
gpu_ids = None
|
|
||||||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
||||||
gpu_ids = range(self.ngpu)
|
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
|
||||||
out = nn.parallel.data_parallel(self.main, input, gpu_ids)
|
else:
|
||||||
return out.view(out.size(0), self.nc, self.isize, self.isize)
|
output = self.main(input)
|
||||||
|
return output.view(out.size(0), self.nc, self.isize, self.isize)
|
||||||
|
|
||||||
|
|
||||||
class MLP_D(nn.Module):
|
class MLP_D(nn.Module):
|
||||||
|
@ -57,9 +57,9 @@ class MLP_D(nn.Module):
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
input = input.view(input.size(0),
|
input = input.view(input.size(0),
|
||||||
input.size(1) * input.size(2) * input.size(3))
|
input.size(1) * input.size(2) * input.size(3))
|
||||||
gpu_ids = None
|
|
||||||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
||||||
gpu_ids = range(self.ngpu)
|
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
|
||||||
output = nn.parallel.data_parallel(self.main, input, gpu_ids)
|
else:
|
||||||
|
output = self.main(input)
|
||||||
output = output.mean(0)
|
output = output.mean(0)
|
||||||
return output.view(1)
|
return output.view(1)
|
||||||
|
|
Loading…
Reference in New Issue