Container -> Module (remove depreceated code)

This commit is contained in:
soumith 2017-02-27 12:39:45 -08:00
parent 93f26f9e4c
commit d9ad2bb847
2 changed files with 6 additions and 6 deletions

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
class DCGAN_D(nn.Container): class DCGAN_D(nn.Module):
def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
super(DCGAN_D, self).__init__() super(DCGAN_D, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu
@ -51,7 +51,7 @@ class DCGAN_D(nn.Container):
output = output.mean(0) output = output.mean(0)
return output.view(1) return output.view(1)
class DCGAN_G(nn.Container): class DCGAN_G(nn.Module):
def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0): def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
super(DCGAN_G, self).__init__() super(DCGAN_G, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu
@ -104,7 +104,7 @@ class DCGAN_G(nn.Container):
return nn.parallel.data_parallel(self.main, input, gpu_ids) return nn.parallel.data_parallel(self.main, input, gpu_ids)
############################################################################### ###############################################################################
class DCGAN_D_nobn(nn.Container): class DCGAN_D_nobn(nn.Module):
def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
super(DCGAN_D_nobn, self).__init__() super(DCGAN_D_nobn, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu
@ -150,7 +150,7 @@ class DCGAN_D_nobn(nn.Container):
output = output.mean(0) output = output.mean(0)
return output.view(1) return output.view(1)
class DCGAN_G_nobn(nn.Container): class DCGAN_G_nobn(nn.Module):
def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0): def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
super(DCGAN_G_nobn, self).__init__() super(DCGAN_G_nobn, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu

View File

@ -5,7 +5,7 @@ from __future__ import unicode_literals
import torch import torch
import torch.nn as nn import torch.nn as nn
class MLP_G(nn.Container): class MLP_G(nn.Module):
def __init__(self, isize, nz, nc, ngf, ngpu): def __init__(self, isize, nz, nc, ngf, ngpu):
super(MLP_G, self).__init__() super(MLP_G, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu
@ -34,7 +34,7 @@ class MLP_G(nn.Container):
return out.view(out.size(0), self.nc, self.isize, self.isize) return out.view(out.size(0), self.nc, self.isize, self.isize)
class MLP_D(nn.Container): class MLP_D(nn.Module):
def __init__(self, isize, nz, nc, ndf, ngpu): def __init__(self, isize, nz, nc, ndf, ngpu):
super(MLP_D, self).__init__() super(MLP_D, self).__init__()
self.ngpu = ngpu self.ngpu = ngpu