Export generator configuration for future data generation.
This commit is contained in:
parent
72853533a0
commit
23176f06f1
11
main.py
11
main.py
|
@ -12,6 +12,7 @@ import torchvision.transforms as transforms
|
||||||
import torchvision.utils as vutils
|
import torchvision.utils as vutils
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
import models.dcgan as dcgan
|
import models.dcgan as dcgan
|
||||||
import models.mlp as mlp
|
import models.mlp as mlp
|
||||||
|
@ -98,6 +99,11 @@ if __name__=="__main__":
|
||||||
nc = int(opt.nc)
|
nc = int(opt.nc)
|
||||||
n_extra_layers = int(opt.n_extra_layers)
|
n_extra_layers = int(opt.n_extra_layers)
|
||||||
|
|
||||||
|
# write out generator config to generate images together wth training checkpoints (.pth)
|
||||||
|
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
|
||||||
|
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
|
||||||
|
gcfg.write(json.dumps(generator_config)+"\n")
|
||||||
|
|
||||||
# custom weights initialization called on netG and netD
|
# custom weights initialization called on netG and netD
|
||||||
def weights_init(m):
|
def weights_init(m):
|
||||||
classname = m.__class__.__name__
|
classname = m.__class__.__name__
|
||||||
|
@ -114,6 +120,11 @@ if __name__=="__main__":
|
||||||
else:
|
else:
|
||||||
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
|
||||||
|
|
||||||
|
# write out generator config to generate images together wth training checkpoints (.pth)
|
||||||
|
generator_config = {"imageSize": opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf, "ngpu": ngpu, "n_extra_layers": n_extra_layers, "noBN": opt.noBN, "mlp_G": opt.mlp_G}
|
||||||
|
with open(os.path.join(opt.experiment, "generator_config.json"), 'w') as gcfg:
|
||||||
|
gcfg.write(json.dumps(generator_config)+"\n")
|
||||||
|
|
||||||
netG.apply(weights_init)
|
netG.apply(weights_init)
|
||||||
if opt.netG != '': # load checkpoint if needed
|
if opt.netG != '': # load checkpoint if needed
|
||||||
netG.load_state_dict(torch.load(opt.netG))
|
netG.load_state_dict(torch.load(opt.netG))
|
||||||
|
|
Loading…
Reference in New Issue