Export generator configuration for future data generation.

This commit is contained in:
kopytjuk 2018-12-25 23:58:23 +01:00
parent 72853533a0
commit 23176f06f1
1 changed files with 11 additions and 0 deletions

11
main.py
View File

@ -12,6 +12,7 @@ import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import json
import models.dcgan as dcgan
import models.mlp as mlp
@ -98,6 +99,11 @@ if __name__=="__main__":
nc = int(opt.nc)
n_extra_layers = int(opt.n_extra_layers)
# write out generator config to generate images together wth training checkpoints (.pth)
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config)+"\n")
# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
@ -114,6 +120,11 @@ if __name__=="__main__":
else:
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
# write out generator config to generate images together wth training checkpoints (.pth)
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config)+"\n")
netG.apply(weights_init)
if opt.netG != '': # load checkpoint if needed
netG.load_state_dict(torch.load(opt.netG))