Merge pull request #19 from elPistolero/resize_fix

fix some resize related bugs
This commit is contained in:
Soumith Chintala 2017-03-02 17:10:23 -05:00 committed by GitHub
commit 492c38def2
1 changed files with 5 additions and 2 deletions

View File

@ -179,6 +179,8 @@ for epoch in range(opt.niter):
netD.zero_grad() netD.zero_grad()
batch_size = real_cpu.size(0) batch_size = real_cpu.size(0)
if opt.cuda:
real_cpu = real_cpu.cuda()
input.resize_as_(real_cpu).copy_(real_cpu) input.resize_as_(real_cpu).copy_(real_cpu)
inputv = Variable(input) inputv = Variable(input)
@ -189,7 +191,8 @@ for epoch in range(opt.niter):
noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1) noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
noisev = Variable(noise) noisev = Variable(noise)
fake = netG(noisev) fake = netG(noisev)
inputv.data.copy_(fake.data) inputv = fake
inputv.detach()
errD_fake = netD(inputv) errD_fake = netD(inputv)
errD_fake.backward(mone) errD_fake.backward(mone)
errD = errD_real - errD_fake errD = errD_real - errD_fake
@ -216,7 +219,7 @@ for epoch in range(opt.niter):
errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0])) errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
if gen_iterations % 500 == 0: if gen_iterations % 500 == 0:
vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment)) vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
fake = netG(fixed_noise) fake = netG(Variable(fixed_noise, volatile=True))
vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations)) vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))
# do checkpointing # do checkpointing