diff --git a/main.py b/main.py index b943b62..948a6f3 100644 --- a/main.py +++ b/main.py @@ -179,6 +179,8 @@ for epoch in range(opt.niter): netD.zero_grad() batch_size = real_cpu.size(0) + if opt.cuda: + real_cpu = real_cpu.cuda() input.resize_as_(real_cpu).copy_(real_cpu) inputv = Variable(input) @@ -189,7 +191,8 @@ for epoch in range(opt.niter): noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1) noisev = Variable(noise) fake = netG(noisev) - inputv.data.copy_(fake.data) + inputv = fake + inputv.detach() errD_fake = netD(inputv) errD_fake.backward(mone) errD = errD_real - errD_fake @@ -216,7 +219,7 @@ for epoch in range(opt.niter): errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0])) if gen_iterations % 500 == 0: vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment)) - fake = netG(fixed_noise) + fake = netG(Variable(fixed_noise, volatile=True)) vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations)) # do checkpointing