Add generate script.
This commit is contained in:
parent
3bb2685bd8
commit
ebc9e181ef
|
@ -0,0 +1,62 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.parallel
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data
|
||||||
|
import torchvision.datasets as dset
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import torchvision.utils as vutils
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
import models.dcgan as dcgan
|
||||||
|
import models.mlp as mlp
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-c', '--config', required=True, type=str, help='path to generator config .json file')
|
||||||
|
parser.add_argument('-w', '--weights', required=True, type=str, help='path to generator weights .pth file')
|
||||||
|
parser.add_argument('-o', '--output_dir', required=True, type=str, help="path to to output directory")
|
||||||
|
parser.add_argument('-n', '--nimages', required=True, type=int, help="number of images to generate", default=1)
|
||||||
|
parser.add_argument('--cuda', action='store_true', help='enables cuda')
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
with open(opt.config, 'r') as gencfg:
|
||||||
|
generator_config = json.loads(gencfg.read())
|
||||||
|
|
||||||
|
imageSize = generator_config["imageSize"]
|
||||||
|
nz = generator_config["nz"]
|
||||||
|
nc = generator_config["nc"]
|
||||||
|
ngf = generator_config["ngf"]
|
||||||
|
noBN = generator_config["noBN"]
|
||||||
|
ngpu = generator_config["ngpu"]
|
||||||
|
mlp_G = generator_config["mlp_G"]
|
||||||
|
n_extra_layers = generator_config["n_extra_layers"]
|
||||||
|
|
||||||
|
if noBN:
|
||||||
|
netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
||||||
|
elif mlp_G:
|
||||||
|
netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
|
||||||
|
else:
|
||||||
|
netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
||||||
|
|
||||||
|
# load weights
|
||||||
|
netG.load_state_dict(torch.load(opt.weights))
|
||||||
|
|
||||||
|
# initialize noise
|
||||||
|
fixed_noise = torch.FloatTensor(opt.nimages, nz, 1, 1).normal_(0, 1)
|
||||||
|
|
||||||
|
if opt.cuda:
|
||||||
|
netG.cuda()
|
||||||
|
fixed_noise = fixed_noise.cuda()
|
||||||
|
|
||||||
|
fake = netG(Variable(fixed_noise, volatile=True))
|
||||||
|
fake.data = fake.data.mul(0.5).add(0.5)
|
||||||
|
|
||||||
|
vutils.save_image(fake.data, os.path.join(opt.output_dir, "generated.png"))
|
Loading…
Reference in New Issue