Bugfix to run the code on Windows machines: wrap the top-level training script in an if __name__ == "__main__" guard.
parent f81eafd2aa
commit afd82ac354

main.py (382 changed lines)
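The underlying issue: Windows has no fork(), so PyTorch starts DataLoader workers (num_workers > 0) with the spawn start method, which re-imports the main module in each worker process. With the training code at module scope, every worker re-executes the whole script; moving it under an if __name__ == "__main__": guard keeps it in the parent process only. A minimal sketch of the pattern, separate from this repository's code:

# Minimal sketch of the guard pattern (illustrative, not code from this commit).
import torch
import torch.utils.data as data

if __name__ == "__main__":
    # A toy dataset stands in for the real ImageFolder/LSUN/CIFAR10 pipelines.
    dataset = data.TensorDataset(torch.randn(64, 3))
    # num_workers=2 starts worker processes; on Windows each one re-imports
    # this module, so all top-level work must sit inside the guard.
    loader = data.DataLoader(dataset, batch_size=8, num_workers=2)
    for (batch,) in loader:
        print(batch.size())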
@@ -16,214 +16,216 @@ import os
 import models.dcgan as dcgan
 import models.mlp as mlp

-parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
-parser.add_argument('--dataroot', required=True, help='path to dataset')
-parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
-parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
-parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
-parser.add_argument('--nc', type=int, default=3, help='input image channels')
-parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
-parser.add_argument('--ngf', type=int, default=64)
-parser.add_argument('--ndf', type=int, default=64)
-parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
-parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
-parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for Generator, default=0.00005')
-parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
-parser.add_argument('--cuda'  , action='store_true', help='enables cuda')
-parser.add_argument('--ngpu'  , type=int, default=1, help='number of GPUs to use')
-parser.add_argument('--netG', default='', help="path to netG (to continue training)")
-parser.add_argument('--netD', default='', help="path to netD (to continue training)")
-parser.add_argument('--clamp_lower', type=float, default=-0.01)
-parser.add_argument('--clamp_upper', type=float, default=0.01)
-parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
-parser.add_argument('--noBN', action='store_true', help='use batchnorm or not (only for DCGAN)')
-parser.add_argument('--mlp_G', action='store_true', help='use MLP for G')
-parser.add_argument('--mlp_D', action='store_true', help='use MLP for D')
-parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
-parser.add_argument('--experiment', default=None, help='Where to store samples and models')
-parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
-opt = parser.parse_args()
-print(opt)
+if __name__=="__main__":

-if opt.experiment is None:
-    opt.experiment = 'samples'
-os.system('mkdir {0}'.format(opt.experiment))
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
+    parser.add_argument('--dataroot', required=True, help='path to dataset')
+    parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
+    parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
+    parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
+    parser.add_argument('--nc', type=int, default=3, help='input image channels')
+    parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
+    parser.add_argument('--ngf', type=int, default=64)
+    parser.add_argument('--ndf', type=int, default=64)
+    parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
+    parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
+    parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for Generator, default=0.00005')
+    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
+    parser.add_argument('--cuda'  , action='store_true', help='enables cuda')
+    parser.add_argument('--ngpu'  , type=int, default=1, help='number of GPUs to use')
+    parser.add_argument('--netG', default='', help="path to netG (to continue training)")
+    parser.add_argument('--netD', default='', help="path to netD (to continue training)")
+    parser.add_argument('--clamp_lower', type=float, default=-0.01)
+    parser.add_argument('--clamp_upper', type=float, default=0.01)
+    parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
+    parser.add_argument('--noBN', action='store_true', help='use batchnorm or not (only for DCGAN)')
+    parser.add_argument('--mlp_G', action='store_true', help='use MLP for G')
+    parser.add_argument('--mlp_D', action='store_true', help='use MLP for D')
+    parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
+    parser.add_argument('--experiment', default=None, help='Where to store samples and models')
+    parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
+    opt = parser.parse_args()
+    print(opt)

-opt.manualSeed = random.randint(1, 10000) # fix seed
-print("Random Seed: ", opt.manualSeed)
-random.seed(opt.manualSeed)
-torch.manual_seed(opt.manualSeed)
+    if opt.experiment is None:
+        opt.experiment = 'samples'
+    os.system('mkdir {0}'.format(opt.experiment))

-cudnn.benchmark = True
+    opt.manualSeed = random.randint(1, 10000) # fix seed
+    print("Random Seed: ", opt.manualSeed)
+    random.seed(opt.manualSeed)
+    torch.manual_seed(opt.manualSeed)

-if torch.cuda.is_available() and not opt.cuda:
-    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
+    cudnn.benchmark = True

-if opt.dataset in ['imagenet', 'folder', 'lfw']:
-    # folder dataset
-    dataset = dset.ImageFolder(root=opt.dataroot,
-                               transform=transforms.Compose([
-                                   transforms.Scale(opt.imageSize),
-                                   transforms.CenterCrop(opt.imageSize),
-                                   transforms.ToTensor(),
-                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-                               ]))
-elif opt.dataset == 'lsun':
-    dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
-                        transform=transforms.Compose([
-                            transforms.Scale(opt.imageSize),
-                            transforms.CenterCrop(opt.imageSize),
-                            transforms.ToTensor(),
-                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-                        ]))
-elif opt.dataset == 'cifar10':
-    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
-                           transform=transforms.Compose([
-                               transforms.Scale(opt.imageSize),
-                               transforms.ToTensor(),
-                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-                           ])
-    )
-assert dataset
-dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
-                                         shuffle=True, num_workers=int(opt.workers))
+    if torch.cuda.is_available() and not opt.cuda:
+        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

-ngpu = int(opt.ngpu)
-nz = int(opt.nz)
-ngf = int(opt.ngf)
-ndf = int(opt.ndf)
-nc = int(opt.nc)
-n_extra_layers = int(opt.n_extra_layers)
+    if opt.dataset in ['imagenet', 'folder', 'lfw']:
+        # folder dataset
+        dataset = dset.ImageFolder(root=opt.dataroot,
+                                transform=transforms.Compose([
+                                    transforms.Scale(opt.imageSize),
+                                    transforms.CenterCrop(opt.imageSize),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                                ]))
+    elif opt.dataset == 'lsun':
+        dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
+                            transform=transforms.Compose([
+                                transforms.Scale(opt.imageSize),
+                                transforms.CenterCrop(opt.imageSize),
+                                transforms.ToTensor(),
+                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                            ]))
+    elif opt.dataset == 'cifar10':
+        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
+                            transform=transforms.Compose([
+                                transforms.Scale(opt.imageSize),
+                                transforms.ToTensor(),
+                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                            ])
+        )
+    assert dataset
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
+                                            shuffle=True, num_workers=int(opt.workers))

-# custom weights initialization called on netG and netD
-def weights_init(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+    ngpu = int(opt.ngpu)
+    nz = int(opt.nz)
+    ngf = int(opt.ngf)
+    ndf = int(opt.ndf)
+    nc = int(opt.nc)
+    n_extra_layers = int(opt.n_extra_layers)

-if opt.noBN:
-    netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
-elif opt.mlp_G:
-    netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
-else:
-    netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
+    # custom weights initialization called on netG and netD
+    def weights_init(m):
+        classname = m.__class__.__name__
+        if classname.find('Conv') != -1:
+            m.weight.data.normal_(0.0, 0.02)
+        elif classname.find('BatchNorm') != -1:
+            m.weight.data.normal_(1.0, 0.02)
+            m.bias.data.fill_(0)

-netG.apply(weights_init)
-if opt.netG != '': # load checkpoint if needed
-    netG.load_state_dict(torch.load(opt.netG))
-print(netG)
+    if opt.noBN:
+        netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
+    elif opt.mlp_G:
+        netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
+    else:
+        netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)

-if opt.mlp_D:
-    netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
-else:
-    netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers)
-    netD.apply(weights_init)
+    netG.apply(weights_init)
+    if opt.netG != '': # load checkpoint if needed
+        netG.load_state_dict(torch.load(opt.netG))
+    print(netG)

-if opt.netD != '':
-    netD.load_state_dict(torch.load(opt.netD))
-print(netD)
+    if opt.mlp_D:
+        netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
+    else:
+        netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers)
+        netD.apply(weights_init)

-input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
-noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
-fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
-one = torch.FloatTensor([1])
-mone = one * -1
+    if opt.netD != '':
+        netD.load_state_dict(torch.load(opt.netD))
+    print(netD)

-if opt.cuda:
-    netD.cuda()
-    netG.cuda()
-    input = input.cuda()
-    one, mone = one.cuda(), mone.cuda()
-    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
+    input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
+    noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
+    fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
+    one = torch.FloatTensor([1])
+    mone = one * -1

-# setup optimizer
-if opt.adam:
-    optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
-    optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999))
-else:
-    optimizerD = optim.RMSprop(netD.parameters(), lr = opt.lrD)
-    optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lrG)
+    if opt.cuda:
+        netD.cuda()
+        netG.cuda()
+        input = input.cuda()
+        one, mone = one.cuda(), mone.cuda()
+        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

-gen_iterations = 0
-for epoch in range(opt.niter):
-    data_iter = iter(dataloader)
-    i = 0
-    while i < len(dataloader):
-        ############################
-        # (1) Update D network
-        ###########################
-        for p in netD.parameters(): # reset requires_grad
-            p.requires_grad = True # they are set to False below in netG update
+    # setup optimizer
+    if opt.adam:
+        optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
+        optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999))
+    else:
+        optimizerD = optim.RMSprop(netD.parameters(), lr = opt.lrD)
+        optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lrG)

-        # train the discriminator Diters times
-        if gen_iterations < 25 or gen_iterations % 500 == 0:
-            Diters = 100
-        else:
-            Diters = opt.Diters
-        j = 0
-        while j < Diters and i < len(dataloader):
-            j += 1
+    gen_iterations = 0
+    for epoch in range(opt.niter):
+        data_iter = iter(dataloader)
+        i = 0
+        while i < len(dataloader):
+            ############################
+            # (1) Update D network
+            ###########################
+            for p in netD.parameters(): # reset requires_grad
+                p.requires_grad = True # they are set to False below in netG update

-            # clamp parameters to a cube
+            # train the discriminator Diters times
+            if gen_iterations < 25 or gen_iterations % 500 == 0:
+                Diters = 100
+            else:
+                Diters = opt.Diters
+            j = 0
+            while j < Diters and i < len(dataloader):
+                j += 1
+
+                # clamp parameters to a cube
+                for p in netD.parameters():
+                    p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
+
+                data = data_iter.next()
+                i += 1
+
+                # train with real
+                real_cpu, _ = data
+                netD.zero_grad()
+                batch_size = real_cpu.size(0)
+
+                if opt.cuda:
+                    real_cpu = real_cpu.cuda()
+                input.resize_as_(real_cpu).copy_(real_cpu)
+                inputv = Variable(input)
+
+                errD_real = netD(inputv)
+                errD_real.backward(one)
+
+                # train with fake
+                noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
+                noisev = Variable(noise, volatile = True) # totally freeze netG
+                fake = Variable(netG(noisev).data)
+                inputv = fake
+                errD_fake = netD(inputv)
+                errD_fake.backward(mone)
+                errD = errD_real - errD_fake
+                optimizerD.step()
+
+            ############################
+            # (2) Update G network
+            ###########################
             for p in netD.parameters():
-                p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
-
-            data = data_iter.next()
-            i += 1
-
-            # train with real
-            real_cpu, _ = data
-            netD.zero_grad()
-            batch_size = real_cpu.size(0)
-
-            if opt.cuda:
-                real_cpu = real_cpu.cuda()
-            input.resize_as_(real_cpu).copy_(real_cpu)
-            inputv = Variable(input)
-
-            errD_real = netD(inputv)
-            errD_real.backward(one)
-
-            # train with fake
+                p.requires_grad = False # to avoid computation
+            netG.zero_grad()
+            # in case our last batch was the tail batch of the dataloader,
+            # make sure we feed a full batch of noise
             noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
-            noisev = Variable(noise, volatile = True) # totally freeze netG
-            fake = Variable(netG(noisev).data)
-            inputv = fake
-            errD_fake = netD(inputv)
-            errD_fake.backward(mone)
-            errD = errD_real - errD_fake
-            optimizerD.step()
+            noisev = Variable(noise)
+            fake = netG(noisev)
+            errG = netD(fake)
+            errG.backward(one)
+            optimizerG.step()
+            gen_iterations += 1

-        ############################
-        # (2) Update G network
-        ###########################
-        for p in netD.parameters():
-            p.requires_grad = False # to avoid computation
-        netG.zero_grad()
-        # in case our last batch was the tail batch of the dataloader,
-        # make sure we feed a full batch of noise
-        noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
-        noisev = Variable(noise)
-        fake = netG(noisev)
-        errG = netD(fake)
-        errG.backward(one)
-        optimizerG.step()
-        gen_iterations += 1
+            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
+                % (epoch, opt.niter, i, len(dataloader), gen_iterations,
+                errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
+            if gen_iterations % 500 == 0:
+                real_cpu = real_cpu.mul(0.5).add(0.5)
+                vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
+                fake = netG(Variable(fixed_noise, volatile=True))
+                fake.data = fake.data.mul(0.5).add(0.5)
+                vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))

-        print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
-            % (epoch, opt.niter, i, len(dataloader), gen_iterations,
-            errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))
-        if gen_iterations % 500 == 0:
-            real_cpu = real_cpu.mul(0.5).add(0.5)
-            vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
-            fake = netG(Variable(fixed_noise, volatile=True))
-            fake.data = fake.data.mul(0.5).add(0.5)
-            vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))
-
-    # do checkpointing
-    torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
-    torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))
+        # do checkpointing
+        torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
+        torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch))
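With the guard in place the script is launched as before, for example: python main.py --dataset cifar10 --dataroot ./data --cuda (the data path here is a placeholder; the flags are the argparse options defined above). If worker processes still cause trouble on Windows, passing --workers 0 makes the DataLoader load batches in the main process and sidesteps multiprocessing entirely.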