Add generate script.

This commit is contained in:
kopytjuk 2018-12-25 23:58:42 +01:00
parent 3bb2685bd8
commit ebc9e181ef
1 changed files with 62 additions and 0 deletions

62
generate.py Normal file
View File

@ -0,0 +1,62 @@
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import json
import models.dcgan as dcgan
import models.mlp as mlp
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', required=True, type=str, help='path to generator config .json file')
parser.add_argument('-w', '--weights', required=True, type=str, help='path to generator weights .pth file')
parser.add_argument('-o', '--output_dir', required=True, type=str, help="path to to output directory")
parser.add_argument('-n', '--nimages', required=True, type=int, help="number of images to generate", default=1)
parser.add_argument('--cuda', action='store_true', help='enables cuda')
opt = parser.parse_args()
with open(opt.config, 'r') as gencfg:
generator_config = json.loads(gencfg.read())
imageSize = generator_config["imageSize"]
nz = generator_config["nz"]
nc = generator_config["nc"]
ngf = generator_config["ngf"]
noBN = generator_config["noBN"]
ngpu = generator_config["ngpu"]
mlp_G = generator_config["mlp_G"]
n_extra_layers = generator_config["n_extra_layers"]
if noBN:
netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
elif mlp_G:
netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
else:
netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
# load weights
netG.load_state_dict(torch.load(opt.weights))
# initialize noise
fixed_noise = torch.FloatTensor(opt.nimages, nz, 1, 1).normal_(0, 1)
if opt.cuda:
netG.cuda()
fixed_noise = fixed_noise.cuda()
fake = netG(Variable(fixed_noise, volatile=True))
fake.data = fake.data.mul(0.5).add(0.5)
vutils.save_image(fake.data, os.path.join(opt.output_dir, "generated.png"))