2017-01-30 14:19:57 +00:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2017-02-27 20:39:45 +00:00
|
|
|
class MLP_G(nn.Module):
    """Fully-connected generator mapping a noise batch to a batch of images.

    Args:
        isize: height and width of the square output image.
        nz: length of the input noise vector.
        nc: number of output image channels.
        ngf: width of each hidden linear layer.
        ngpu: number of GPUs; when > 1 and the input is on a CUDA device,
            the forward pass is split across device ids 0..ngpu-1.
    """

    def __init__(self, isize, nz, nc, ngf, ngpu):
        super(MLP_G, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # Z goes into a linear of size: ngf
            nn.Linear(nz, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            # Project to one flat value per (channel, row, column) of the image.
            nn.Linear(ngf, nc * isize * isize),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        """Generate a batch of images from noise.

        Args:
            input: noise tensor whose first two dims are (batch, nz); any
                trailing dims must have size 1, e.g. (batch, nz, 1, 1).

        Returns:
            Tensor of shape (batch, nc, isize, isize).
        """
        # reshape (rather than view) also accepts non-contiguous noise tensors.
        input = input.reshape(input.size(0), input.size(1))
        # is_cuda is true for every CUDA dtype; the old
        # isinstance(..., torch.cuda.FloatTensor) test only matched float32,
        # so half/double CUDA inputs silently bypassed multi-GPU execution.
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(
                self.main, input, list(range(self.ngpu)))
        else:
            output = self.main(input)
        return output.view(output.size(0), self.nc, self.isize, self.isize)
|
2017-01-30 14:19:57 +00:00
|
|
|
|
|
|
|
|
2017-02-27 20:39:45 +00:00
|
|
|
class MLP_D(nn.Module):
    """Fully-connected discriminator scoring a batch of images.

    Args:
        isize: height and width of the square input image.
        nz: stored as ``self.nz``; not used by the network itself.
        nc: number of input image channels.
        ndf: width of each hidden linear layer.
        ngpu: number of GPUs; when > 1 and the input is on a CUDA device,
            the forward pass is split across device ids 0..ngpu-1.
    """

    def __init__(self, isize, nz, nc, ndf, ngpu):
        super(MLP_D, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # The flattened image (nc * isize * isize values) goes into a
            # linear of size: ndf
            nn.Linear(nc * isize * isize, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            # One scalar score per sample.
            nn.Linear(ndf, 1),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: image batch whose per-sample element count is
                nc * isize * isize, e.g. (batch, nc, isize, isize) or an
                already-flattened (batch, nc * isize * isize).

        Returns:
            Tensor of shape (1,) holding the mean score over the batch.
        """
        # Flatten every non-batch dim; reshape also handles non-contiguous
        # input (e.g. transposed images), which view() rejects.
        input = input.reshape(input.size(0), -1)
        # is_cuda is true for every CUDA dtype; the old
        # isinstance(..., torch.cuda.FloatTensor) test only matched float32.
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(
                self.main, input, list(range(self.ngpu)))
        else:
            output = self.main(input)
        # Average the per-sample scores over the batch dimension.
        output = output.mean(0)
        return output.view(1)
|