from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import torch.nn as nn


class MLP_G(nn.Module):
    """MLP generator: maps a latent vector z to a flat image, reshaped to (nc, isize, isize)."""

    def __init__(self, isize, nz, nc, ngf, ngpu):
        super(MLP_G, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # Z goes into a linear of size: ngf
            nn.Linear(nz, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            # Project up to a flat image of nc * isize * isize values
            nn.Linear(ngf, nc * isize * isize),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        # Flatten (B, nz, 1, 1) noise to (B, nz) before the linear stack
        input = input.view(input.size(0), input.size(1))
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        # Reshape the flat output back into an image batch
        return output.view(output.size(0), self.nc, self.isize, self.isize)


class MLP_D(nn.Module):
    """MLP critic: maps an image batch to a single unbounded scalar score (no final sigmoid)."""

    def __init__(self, isize, nz, nc, ndf, ngpu):
        super(MLP_D, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # The flattened image (nc * isize * isize) goes into a linear of size: ndf
            nn.Linear(nc * isize * isize, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, 1),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        # Flatten (B, nc, isize, isize) images to (B, nc * isize * isize)
        input = input.view(input.size(0),
                           input.size(1) * input.size(2) * input.size(3))
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        # Average the per-sample scores into one scalar for the batch
        output = output.mean(0)
        return output.view(1)
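

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: the
    # hyperparameter values below (isize=32, nz=100, nc=3, ngf=ndf=512,
    # ngpu=1) are illustrative assumptions, chosen only to check that the
    # tensor shapes round-trip on CPU.
    isize, nz, nc, ngf, ndf, ngpu = 32, 100, 3, 512, 512, 1

    netG = MLP_G(isize, nz, nc, ngf, ngpu)
    netD = MLP_D(isize, nz, nc, ndf, ngpu)

    noise = torch.randn(8, nz, 1, 1)  # batch of 8 latent vectors
    fake = netG(noise)                # -> (8, nc, isize, isize)
    score = netD(fake)                # -> (1,), mean critic score over the batch

    print(fake.shape, score.shape)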