feat(upgrade): Solve issues from previous PyTorch version
This commit is contained in:
parent
3d765bfa64
commit
51e93db5e0
|
@ -31,7 +31,7 @@ Two main empirical claims:
|
||||||
**With DCGAN:**
|
**With DCGAN:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python main.py --dataset lsun --dataroot [lsun-train-folder] --cuda
|
python main.py --dataset folder --dataroot data/maps --cuda
|
||||||
```
|
```
|
||||||
|
|
||||||
**With MLP:**
|
**With MLP:**
|
||||||
|
|
10
main.py
10
main.py
|
@ -67,15 +67,15 @@ if __name__=="__main__":
|
||||||
# folder dataset
|
# folder dataset
|
||||||
dataset = dset.ImageFolder(root=opt.dataroot,
|
dataset = dset.ImageFolder(root=opt.dataroot,
|
||||||
transform=transforms.Compose([
|
transform=transforms.Compose([
|
||||||
transforms.Scale(opt.imageSize),
|
transforms.Resize(opt.imageSize),
|
||||||
transforms.CenterCrop(opt.imageSize),
|
transforms.CenterCrop(opt.imageSize),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
]))
|
]))
|
||||||
elif opt.dataset == 'lsun':
|
elif opt.dataset == 'lsun':
|
||||||
dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
|
dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
|
||||||
transform=transforms.Compose([
|
transform=transforms.Compose([
|
||||||
transforms.Scale(opt.imageSize),
|
transforms.Resize(opt.imageSize),
|
||||||
transforms.CenterCrop(opt.imageSize),
|
transforms.CenterCrop(opt.imageSize),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
@ -83,7 +83,7 @@ if __name__=="__main__":
|
||||||
elif opt.dataset == 'cifar10':
|
elif opt.dataset == 'cifar10':
|
||||||
dataset = dset.CIFAR10(root=opt.dataroot, download=True,
|
dataset = dset.CIFAR10(root=opt.dataroot, download=True,
|
||||||
transform=transforms.Compose([
|
transform=transforms.Compose([
|
||||||
transforms.Scale(opt.imageSize),
|
transforms.Resize(opt.imageSize),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
])
|
])
|
||||||
|
@ -185,7 +185,7 @@ if __name__=="__main__":
|
||||||
for p in netD.parameters():
|
for p in netD.parameters():
|
||||||
p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
|
p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
|
||||||
|
|
||||||
data = data_iter.next()
|
data = next(data_iter)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
# train with real
|
# train with real
|
||||||
|
|
|
@ -1,2 +1,3 @@
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
|
lmdb
|
||||||
|
|
Loading…
Reference in New Issue