fix some resize related bugs
This commit is contained in:
		
							parent
							
								
									7eee65dc9c
								
							
						
					
					
						commit
						df2873dee8
					
				
							
								
								
									
										7
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.py
									
									
									
									
									
								
							@ -179,6 +179,8 @@ for epoch in range(opt.niter):
 | 
				
			|||||||
            netD.zero_grad()
 | 
					            netD.zero_grad()
 | 
				
			||||||
            batch_size = real_cpu.size(0)
 | 
					            batch_size = real_cpu.size(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if opt.cuda:
 | 
				
			||||||
 | 
					                real_cpu = real_cpu.cuda()
 | 
				
			||||||
            input.resize_as_(real_cpu).copy_(real_cpu)
 | 
					            input.resize_as_(real_cpu).copy_(real_cpu)
 | 
				
			||||||
            inputv = Variable(input)
 | 
					            inputv = Variable(input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -189,7 +191,8 @@ for epoch in range(opt.niter):
 | 
				
			|||||||
            noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
 | 
					            noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
 | 
				
			||||||
            noisev = Variable(noise)
 | 
					            noisev = Variable(noise)
 | 
				
			||||||
            fake = netG(noisev)
 | 
					            fake = netG(noisev)
 | 
				
			||||||
            inputv.data.copy_(fake.data)
 | 
					            inputv = fake
 | 
				
			||||||
 | 
					            inputv.detach()
 | 
				
			||||||
            errD_fake = netD(inputv)
 | 
					            errD_fake = netD(inputv)
 | 
				
			||||||
            errD_fake.backward(mone)
 | 
					            errD_fake.backward(mone)
 | 
				
			||||||
            errD = errD_real - errD_fake
 | 
					            errD = errD_real - errD_fake
 | 
				
			||||||
@ -216,7 +219,7 @@ for epoch in range(opt.niter):
 | 
				
			|||||||
            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
 | 
					            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
 | 
				
			||||||
        if gen_iterations % 500 == 0:
 | 
					        if gen_iterations % 500 == 0:
 | 
				
			||||||
            vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
 | 
					            vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
 | 
				
			||||||
            fake = netG(fixed_noise)
 | 
					            fake = netG(Variable(fixed_noise, volatile=True))
 | 
				
			||||||
            vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))
 | 
					            vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # do checkpointing
 | 
					    # do checkpointing
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user