Merge pull request #37 from eriche2016/patch-1

compatible with latest pytorch such that nn.parallel.data_parallel w…
commit ebe4b4ca60
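Before the hunks: the old forward methods always passed a gpu_ids value (None or range(self.ngpu)) into nn.parallel.data_parallel; in recent PyTorch, device_ids=None means "replicate across all visible GPUs", which breaks CPU and single-GPU runs. The patch calls data_parallel only when self.ngpu > 1 and otherwise runs self.main(input) directly. A minimal standalone sketch of the new critic-side pattern (the Critic wrapper and its constructor arguments are illustrative, not names from the repo):

import torch
import torch.nn as nn

class Critic(nn.Module):
    """Illustrative stand-in for DCGAN_D; `main` is the stacked conv pipeline."""
    def __init__(self, main, ngpu):
        super(Critic, self).__init__()
        self.main = main
        self.ngpu = ngpu

    def forward(self, input):
        # Split the batch across GPUs only when more than one is requested;
        # otherwise run the module directly (works on CPU and single GPU).
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        output = output.mean(0)   # WGAN critic score: mean over the batch
        return output.view(1)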
@@ -44,10 +44,11 @@ class DCGAN_D(nn.Module):
     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
 
         output = output.mean(0)
         return output.view(1)
@@ -98,11 +99,11 @@ class DCGAN_G(nn.Module):
         self.main = main
 
     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        return nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output
 
 ###############################################################################
 class DCGAN_D_nobn(nn.Module):
     def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
@@ -143,10 +144,11 @@ class DCGAN_D_nobn(nn.Module):
     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
 
         output = output.mean(0)
         return output.view(1)
@@ -190,7 +192,8 @@ class DCGAN_G_nobn(nn.Module):
         self.main = main
 
     def forward(self, input):
-        gpu_ids = None
         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
-            gpu_ids = range(self.ngpu)
-        return nn.parallel.data_parallel(self.main, input, gpu_ids)
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output
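A quick smoke test of the patched forward paths, assuming the models module lives at models/dcgan.py and DCGAN_G mirrors the (isize, nz, nc, ndf/ngf, ngpu, n_extra_layers=0) constructor shown for DCGAN_D_nobn above (both assumptions, not confirmed by this diff):

import torch
from torch.autograd import Variable
import models.dcgan as dcgan  # assumed module path within the repo

# isize=64, nz=100, nc=3, ngf=64; ngpu=1 forces the new fallback branch
netG = dcgan.DCGAN_G(64, 100, 3, 64, 1)

noise = Variable(torch.randn(16, 100, 1, 1))  # CPU tensor, Variable-era API
fake = netG(noise)  # takes the else-branch: no data_parallel call on CPU

# With ngpu > 1 and a CUDA input, forward would instead dispatch to
# nn.parallel.data_parallel(self.main, input, range(self.ngpu)).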