feat(upgrade): Solve issues from previous PyTorch version

This commit is contained in:
Paul Corbalan 2023-12-16 12:35:13 +01:00
parent 3d765bfa64
commit 51e93db5e0
3 changed files with 7 additions and 6 deletions

View File

@ -31,7 +31,7 @@ Two main empirical claims:
**With DCGAN:**
```bash
python main.py --dataset lsun --dataroot [lsun-train-folder] --cuda
python main.py --dataset folder --dataroot data/maps --cuda
```
**With MLP:**

10
main.py
View File

@ -67,15 +67,15 @@ if __name__=="__main__":
# folder dataset
dataset = dset.ImageFolder(root=opt.dataroot,
transform=transforms.Compose([
transforms.Scale(opt.imageSize),
transforms.Resize(opt.imageSize),
transforms.CenterCrop(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
elif opt.dataset == 'lsun':
dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
transform=transforms.Compose([
transforms.Scale(opt.imageSize),
transforms.Resize(opt.imageSize),
transforms.CenterCrop(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
@ -83,7 +83,7 @@ if __name__=="__main__":
elif opt.dataset == 'cifar10':
dataset = dset.CIFAR10(root=opt.dataroot, download=True,
transform=transforms.Compose([
transforms.Scale(opt.imageSize),
transforms.Resize(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
@ -185,7 +185,7 @@ if __name__=="__main__":
for p in netD.parameters():
p.data.clamp_(opt.clamp_lower, opt.clamp_upper)
data = data_iter.next()
data = next(data_iter)
i += 1
# train with real

View File

@ -1,2 +1,3 @@
torch
torchvision
lmdb