Merge pull request #71 from kopytjuk/generate-images-script

Generate images script
This commit is contained in:
Soumith Chintala 2018-12-26 11:45:59 -05:00 committed by GitHub
commit f7a01e8200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 0 deletions

63
generate.py Normal file
View File

@ -0,0 +1,63 @@
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import json
import models.dcgan as dcgan
import models.mlp as mlp
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', required=True, type=str, help='path to generator config .json file')
parser.add_argument('-w', '--weights', required=True, type=str, help='path to generator weights .pth file')
parser.add_argument('-o', '--output_dir', required=True, type=str, help="path to to output directory")
parser.add_argument('-n', '--nimages', required=True, type=int, help="number of images to generate", default=1)
parser.add_argument('--cuda', action='store_true', help='enables cuda')
opt = parser.parse_args()
with open(opt.config, 'r') as gencfg:
generator_config = json.loads(gencfg.read())
imageSize = generator_config["imageSize"]
nz = generator_config["nz"]
nc = generator_config["nc"]
ngf = generator_config["ngf"]
noBN = generator_config["noBN"]
ngpu = generator_config["ngpu"]
mlp_G = generator_config["mlp_G"]
n_extra_layers = generator_config["n_extra_layers"]
if noBN:
netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
elif mlp_G:
netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
else:
netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
# load weights
netG.load_state_dict(torch.load(opt.weights))
# initialize noise
fixed_noise = torch.FloatTensor(opt.nimages, nz, 1, 1).normal_(0, 1)
if opt.cuda:
netG.cuda()
fixed_noise = fixed_noise.cuda()
fake = netG(fixed_noise)
fake.data = fake.data.mul(0.5).add(0.5)
for i in range(opt.nimages):
vutils.save_image(fake.data[i, ...].reshape((1, nc, imageSize, imageSize)), os.path.join(opt.output_dir, "generated_%02d.png"%i))

11
main.py
View File

@ -12,6 +12,7 @@ import torchvision.transforms as transforms
import torchvision.utils as vutils import torchvision.utils as vutils
from torch.autograd import Variable from torch.autograd import Variable
import os import os
import json
import models.dcgan as dcgan import models.dcgan as dcgan
import models.mlp as mlp import models.mlp as mlp
@ -98,6 +99,11 @@ if __name__=="__main__":
nc = int(opt.nc) nc = int(opt.nc)
n_extra_layers = int(opt.n_extra_layers) n_extra_layers = int(opt.n_extra_layers)
# write out generator config to generate images together wth training checkpoints (.pth)
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config)+"\n")
# custom weights initialization called on netG and netD # custom weights initialization called on netG and netD
def weights_init(m): def weights_init(m):
classname = m.__class__.__name__ classname = m.__class__.__name__
@ -114,6 +120,11 @@ if __name__=="__main__":
else: else:
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers) netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
# write out generator config to generate images together wth training checkpoints (.pth)
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config)+"\n")
netG.apply(weights_init) netG.apply(weights_init)
if opt.netG != '': # load checkpoint if needed if opt.netG != '': # load checkpoint if needed
netG.load_state_dict(torch.load(opt.netG)) netG.load_state_dict(torch.load(opt.netG))