Add generate script.
This commit is contained in:
		
							parent
							
								
									3bb2685bd8
								
							
						
					
					
						commit
						ebc9e181ef
					
				
							
								
								
									
										62
									
								
								generate.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								generate.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
				
			|||||||
 | 
					from __future__ import print_function
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.parallel
 | 
				
			||||||
 | 
					import torch.backends.cudnn as cudnn
 | 
				
			||||||
 | 
					import torch.optim as optim
 | 
				
			||||||
 | 
					import torch.utils.data
 | 
				
			||||||
 | 
					import torchvision.datasets as dset
 | 
				
			||||||
 | 
					import torchvision.transforms as transforms
 | 
				
			||||||
 | 
					import torchvision.utils as vutils
 | 
				
			||||||
 | 
					from torch.autograd import Variable
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import models.dcgan as dcgan
 | 
				
			||||||
 | 
					import models.mlp as mlp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__=="__main__":
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser.add_argument('-c', '--config', required=True, type=str, help='path to generator config .json file')
 | 
				
			||||||
 | 
					    parser.add_argument('-w', '--weights', required=True, type=str, help='path to generator weights .pth file')
 | 
				
			||||||
 | 
					    parser.add_argument('-o', '--output_dir', required=True, type=str, help="path to to output directory")
 | 
				
			||||||
 | 
					    parser.add_argument('-n', '--nimages', required=True, type=int, help="number of images to generate", default=1)
 | 
				
			||||||
 | 
					    parser.add_argument('--cuda', action='store_true', help='enables cuda')
 | 
				
			||||||
 | 
					    opt = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open(opt.config, 'r') as gencfg:
 | 
				
			||||||
 | 
					        generator_config = json.loads(gencfg.read())
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    imageSize = generator_config["imageSize"]
 | 
				
			||||||
 | 
					    nz = generator_config["nz"]
 | 
				
			||||||
 | 
					    nc = generator_config["nc"]
 | 
				
			||||||
 | 
					    ngf = generator_config["ngf"]
 | 
				
			||||||
 | 
					    noBN = generator_config["noBN"]
 | 
				
			||||||
 | 
					    ngpu = generator_config["ngpu"]
 | 
				
			||||||
 | 
					    mlp_G = generator_config["mlp_G"]
 | 
				
			||||||
 | 
					    n_extra_layers = generator_config["n_extra_layers"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if noBN:
 | 
				
			||||||
 | 
					        netG = dcgan.DCGAN_G_nobn(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
 | 
				
			||||||
 | 
					    elif mlp_G:
 | 
				
			||||||
 | 
					        netG = mlp.MLP_G(imageSize, nz, nc, ngf, ngpu)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        netG = dcgan.DCGAN_G(imageSize, nz, nc, ngf, ngpu, n_extra_layers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # load weights
 | 
				
			||||||
 | 
					    netG.load_state_dict(torch.load(opt.weights))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # initialize noise
 | 
				
			||||||
 | 
					    fixed_noise = torch.FloatTensor(opt.nimages, nz, 1, 1).normal_(0, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if opt.cuda:
 | 
				
			||||||
 | 
					        netG.cuda()
 | 
				
			||||||
 | 
					        fixed_noise = fixed_noise.cuda()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fake = netG(Variable(fixed_noise, volatile=True))
 | 
				
			||||||
 | 
					    fake.data = fake.data.mul(0.5).add(0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    vutils.save_image(fake.data, os.path.join(opt.output_dir, "generated.png"))
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user