Merge pull request #71 from kopytjuk/generate-images-script
Generate images script
This commit is contained in:
commit
f7a01e8200
|
@ -0,0 +1,63 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data
|
||||||
|
import torchvision.datasets as dset
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import torchvision.utils as vutils
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
import models.dcgan as dcgan
|
||||||
|
import models.mlp as mlp
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-c', '--config', required=True, type=str, help='path to generator config .json file')
|
||||||
|
parser.add_argument('-w', '--weights', required=True, type=str, help='path to generator weights .pth file')
|
||||||
|
parser.add_argument('-o', '--output_dir', required=True, type=str, help="path to to output directory")
|
||||||
|
parser.add_argument('-n', '--nimages', required=True, type=int, help="number of images to generate", default=1)
|
||||||
|
parser.add_argument('--cuda', action='store_true', help='enables cuda')
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
with open(opt.config, 'r') as gencfg:
|
||||||
|
generator_config = json.loads(gencfg.read())
|
||||||
|
|
||||||
|
imageSize = generator_config["imageSize"]
|
||||||
|
nz = generator_config["nz"]
|
||||||
|
nc = generator_config["nc"]
|
||||||
|
ngf = generator_config["ngf"]
|
||||||
|
noBN = generator_config["noBN"]
|
||||||
|
ngpu = generator_config["ngpu"]
|
||||||
|
mlp_G = generator_config["mlp_G"]
|
||||||
|
n_extra_layers = generator_config["n_extra_layers"]
|
||||||
|
|
||||||
|
if noBN:
|
||||||
|
netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
||||||
|
elif mlp_G:
|
||||||
|
netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
|
||||||
|
else:
|
||||||
|
netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
||||||
|
|
||||||
|
# load weights
|
||||||
|
netG.load_state_dict(torch.load(opt.weights))
|
||||||
|
|
||||||
|
# initialize noise
|
||||||
|
fixed_noise = torch.FloatTensor(opt.nimages, nz, 1, 1).normal_(0, 1)
|
||||||
|
|
||||||
|
if opt.cuda:
|
||||||
|
netG.cuda()
|
||||||
|
fixed_noise = fixed_noise.cuda()
|
||||||
|
|
||||||
|
fake = netG(fixed_noise)
|
||||||
|
fake.data = fake.data.mul(0.5).add(0.5)
|
||||||
|
|
||||||
|
for i in range(opt.nimages):
|
||||||
|
vutils.save_image(fake.data[i, ...].reshape((1, nc, imageSize, imageSize)), os.path.join(opt.output_dir, "generated_%02d.png"%i))
|
11
main.py
11
main.py
|
@ -12,6 +12,7 @@ import torchvision.transforms as transforms
|
||||||
import torchvision.utils as vutils
|
import torchvision.utils as vutils
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
import models.dcgan as dcgan
|
import models.dcgan as dcgan
|
||||||
import models.mlp as mlp
|
import models.mlp as mlp
|
||||||
|
@ -98,6 +99,11 @@ if __name__=="__main__":
|
||||||
nc = int(opt.nc)
|
nc = int(opt.nc)
|
||||||
n_extra_layers = int(opt.n_extra_layers)
|
n_extra_layers = int(opt.n_extra_layers)
|
||||||
|
|
||||||
|
# write out generator config to generate images together wth training checkpoints (.pth)
|
||||||
|
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
|
||||||
|
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
|
||||||
|
gcfg.write(json.dumps(generator_config)+"\n")
|
||||||
|
|
||||||
# custom weights initialization called on netG and netD
|
# custom weights initialization called on netG and netD
|
||||||
def weights_init(m):
|
def weights_init(m):
|
||||||
classname = m.__class__.__name__
|
classname = m.__class__.__name__
|
||||||
|
@ -114,6 +120,11 @@ if __name__=="__main__":
|
||||||
else:
|
else:
|
||||||
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
||||||
|
|
||||||
|
# write out generator config to generate images together wth training checkpoints (.pth)
|
||||||
|
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
|
||||||
|
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
|
||||||
|
gcfg.write(json.dumps(generator_config)+"\n")
|
||||||
|
|
||||||
netG.apply(weights_init)
|
netG.apply(weights_init)
|
||||||
if opt.netG != '': # load checkpoint if needed
|
if opt.netG != '': # load checkpoint if needed
|
||||||
netG.load_state_dict(torch.load(opt.netG))
|
netG.load_state_dict(torch.load(opt.netG))
|
||||||
|
|
Loading…
Reference in New Issue