Merge pull request #37 from eriche2016/patch-1
compitible with latest pytorch such that nn.parallel.data_parallel w…
This commit is contained in:
		
						commit
						ebe4b4ca60
					
				| @ -44,10 +44,11 @@ class DCGAN_D(nn.Module): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     def forward(self, input): |     def forward(self, input): | ||||||
|         gpu_ids = None |  | ||||||
|         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: |         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | ||||||
|             gpu_ids = range(self.ngpu) |             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | ||||||
|         output = nn.parallel.data_parallel(self.main, input, gpu_ids) |         else:  | ||||||
|  |             output = self.main(input) | ||||||
|  |              | ||||||
|         output = output.mean(0) |         output = output.mean(0) | ||||||
|         return output.view(1) |         return output.view(1) | ||||||
| 
 | 
 | ||||||
| @ -98,11 +99,11 @@ class DCGAN_G(nn.Module): | |||||||
|         self.main = main |         self.main = main | ||||||
| 
 | 
 | ||||||
|     def forward(self, input): |     def forward(self, input): | ||||||
|         gpu_ids = None |  | ||||||
|         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: |         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | ||||||
|             gpu_ids = range(self.ngpu) |             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | ||||||
|         return nn.parallel.data_parallel(self.main, input, gpu_ids) |         else:  | ||||||
| 
 |             output = self.main(input) | ||||||
|  |         return output  | ||||||
| ############################################################################### | ############################################################################### | ||||||
| class DCGAN_D_nobn(nn.Module): | class DCGAN_D_nobn(nn.Module): | ||||||
|     def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): |     def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): | ||||||
| @ -143,10 +144,11 @@ class DCGAN_D_nobn(nn.Module): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     def forward(self, input): |     def forward(self, input): | ||||||
|         gpu_ids = None |  | ||||||
|         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: |         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | ||||||
|             gpu_ids = range(self.ngpu) |             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | ||||||
|         output = nn.parallel.data_parallel(self.main, input, gpu_ids) |         else:  | ||||||
|  |             output = self.main(input) | ||||||
|  |              | ||||||
|         output = output.mean(0) |         output = output.mean(0) | ||||||
|         return output.view(1) |         return output.view(1) | ||||||
| 
 | 
 | ||||||
| @ -190,7 +192,8 @@ class DCGAN_G_nobn(nn.Module): | |||||||
|         self.main = main |         self.main = main | ||||||
| 
 | 
 | ||||||
|     def forward(self, input): |     def forward(self, input): | ||||||
|         gpu_ids = None |  | ||||||
|         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: |         if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | ||||||
|             gpu_ids = range(self.ngpu) |             output = nn.parallel.data_parallel(self.main, input,  range(self.ngpu)) | ||||||
|         return nn.parallel.data_parallel(self.main, input, gpu_ids) |         else:  | ||||||
|  |             output = self.main(input) | ||||||
|  |         return output  | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Soumith Chintala
						Soumith Chintala