Don't reuse variables

This commit is contained in:
Soumith Chintala 2017-02-27 13:54:22 -05:00
parent a6c76da533
commit e553093d3b
1 changed file with 11 additions and 10 deletions

21
main.py
View File

@@ -182,17 +182,18 @@ for epoch in range(opt.niter):
real_cpu, _ = data
netD.zero_grad()
batch_size = real_cpu.size(0)
input.data.resize_(real_cpu.size()).copy_(real_cpu)
inputv = Variable(input)
errD_real = netD(input)
errD_real = netD(inputv)
errD_real.backward(one)
# train with fake
noise.data.resize_(batch_size, nz, 1, 1)
noise.data.normal_(0, 1)
fake = netG(noise)
input.data.copy_(fake.data)
errD_fake = netD(input)
noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
noisev = Variable(noise)
fake = netG(noisev)
inputv.data.copy_(fake.data)
errD_fake = netD(inputv)
errD_fake.backward(mone)
errD = errD_real - errD_fake
optimizerD.step()
@@ -205,9 +206,9 @@ for epoch in range(opt.niter):
netG.zero_grad()
# in case our last batch was the tail batch of the dataloader,
# make sure we feed a full batch of noise
noise.data.resize_(opt.batchSize, nz, 1, 1)
noise.data.normal_(0, 1)
fake = netG(noise)
noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
noisev = Variable(noise)
fake = netG(noisev)
errG = netD(fake)
errG.backward(one)
optimizerG.step()