Export generator configuration for future data generation.

This commit is contained in:
kopytjuk 2018-12-25 23:58:23 +01:00
parent 72853533a0
commit 23176f06f1
1 changed files with 11 additions and 0 deletions

11
main.py
View File

@ -12,6 +12,7 @@ import torchvision.transforms as transforms
import torchvision.utils as vutils import torchvision.utils as vutils
from torch.autograd import Variable from torch.autograd import Variable
import os import os
import json
import models.dcgan as dcgan import models.dcgan as dcgan
import models.mlp as mlp import models.mlp as mlp
@ -98,6 +99,11 @@ if __name__=="__main__":
nc = int(opt.nc) nc = int(opt.nc)
n_extra_layers = int(opt.n_extra_layers) n_extra_layers = int(opt.n_extra_layers)
# write out generator config to generate images together wth training checkpoints (.pth)
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config)+"\n")
# custom weights initialization called on netG and netD # custom weights initialization called on netG and netD
def weights_init(m): def weights_init(m):
classname = m.__class__.__name__ classname = m.__class__.__name__
@ -114,6 +120,11 @@ if __name__=="__main__":
else: else:
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers) netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
# write out generator config to generate images together wth training checkpoints (.pth)
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config)+"\n")
netG.apply(weights_init) netG.apply(weights_init)
if opt.netG != '': # load checkpoint if needed if opt.netG != '': # load checkpoint if needed
netG.load_state_dict(torch.load(opt.netG)) netG.load_state_dict(torch.load(opt.netG))