From 3bb2685bd8ed3210d2c1b317f7a2d07c5a3002d8 Mon Sep 17 00:00:00 2001 From: kopytjuk Date: Tue, 25 Dec 2018 23:58:23 +0100 Subject: [PATCH] Export generator configuration for future data generation. --- main.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/main.py b/main.py index a9cf68f..36fc2d3 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ import torchvision.transforms as transforms import torchvision.utils as vutils from torch.autograd import Variable import os +import json import models.dcgan as dcgan import models.mlp as mlp @@ -98,6 +99,11 @@ if __name__=="__main__": nc = int(opt.nc) n_extra_layers = int(opt.n_extra_layers) + # write out generator config to generate images together wth training checkpoints (.pth) + generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G} + with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg: + gcfg.write(json.dumps(generator_config)+"\n") + # custom weights initialization called on netG and netD def weights_init(m): classname = m.__class__.__name__ @@ -114,6 +120,11 @@ if __name__=="__main__": else: netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers) + # write out generator config to generate images together wth training checkpoints (.pth) + generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G} + with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg: + gcfg.write(json.dumps(generator_config)+"\n") + netG.apply(weights_init) if opt.netG != '': # load checkpoint if needed netG.load_state_dict(torch.load(opt.netG))