Merge pull request #37 from eriche2016/patch-1
compitible with latest pytorch such that nn.parallel.data_parallel w…
This commit is contained in:
commit
ebe4b4ca60
|
@ -44,10 +44,11 @@ class DCGAN_D(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
gpu_ids = None
|
|
||||||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
||||||
gpu_ids = range(self.ngpu)
|
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
|
||||||
output = nn.parallel.data_parallel(self.main, input, gpu_ids)
|
else:
|
||||||
|
output = self.main(input)
|
||||||
|
|
||||||
output = output.mean(0)
|
output = output.mean(0)
|
||||||
return output.view(1)
|
return output.view(1)
|
||||||
|
|
||||||
|
@ -98,11 +99,11 @@ class DCGAN_G(nn.Module):
|
||||||
self.main = main
|
self.main = main
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
gpu_ids = None
|
|
||||||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
||||||
gpu_ids = range(self.ngpu)
|
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
|
||||||
return nn.parallel.data_parallel(self.main, input, gpu_ids)
|
else:
|
||||||
|
output = self.main(input)
|
||||||
|
return output
|
||||||
###############################################################################
|
###############################################################################
|
||||||
class DCGAN_D_nobn(nn.Module):
|
class DCGAN_D_nobn(nn.Module):
|
||||||
def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
|
def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
|
||||||
|
@ -143,10 +144,11 @@ class DCGAN_D_nobn(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
gpu_ids = None
|
|
||||||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
||||||
gpu_ids = range(self.ngpu)
|
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
|
||||||
output = nn.parallel.data_parallel(self.main, input, gpu_ids)
|
else:
|
||||||
|
output = self.main(input)
|
||||||
|
|
||||||
output = output.mean(0)
|
output = output.mean(0)
|
||||||
return output.view(1)
|
return output.view(1)
|
||||||
|
|
||||||
|
@ -190,7 +192,8 @@ class DCGAN_G_nobn(nn.Module):
|
||||||
self.main = main
|
self.main = main
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
gpu_ids = None
|
|
||||||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
|
||||||
gpu_ids = range(self.ngpu)
|
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
|
||||||
return nn.parallel.data_parallel(self.main, input, gpu_ids)
|
else:
|
||||||
|
output = self.main(input)
|
||||||
|
return output
|
||||||
|
|
Loading…
Reference in New Issue