# VAE-style encoder with a dense output head, plus optimizer and loss helpers.
import torch
|
|
import torchvision
|
|
from torch import nn
|
|
|
|
import utils
|
|
from models import VAE as VAEm
|
|
|
|
|
|
# Defining the model
|
|
|
|
|
|
class Model(nn.Module):
    """Encoder followed by a dense output head fed from a sampled latent.

    The encoder maps a flattened input to ``2 * d`` values, which are read as
    a mean and a log-variance of ``d`` entries each.  A latent vector is
    obtained from them by ``reparameterise`` and passed through
    ``output_network``.

    Args:
        d: Size of the latent vector.
        size_input: Input shape; ``utils.prod`` of it is the flattened input
            width.  Defaults to ``[28, 28]``.
        size_output: Output shape; ``utils.prod`` of it is the width of the
            default output head's last layer.
        model: Optional object whose ``encoder`` attribute is reused instead
            of building a new encoder.  Its output must be viewable as
            ``(-1, 2, d)``.
        output_network: Optional module used as the output head instead of
            the default MLP + softmax.
    """

    def __init__(self, d=20, size_input=None, size_output=10, model=None,
                 output_network=None):
        super().__init__()

        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every Model created with the default.
        if size_input is None:
            size_input = [28, 28]

        self.d = d
        self.size_input = size_input
        self.flatten_size_input = utils.prod(self.size_input)
        self.size_output = size_output
        self.flatten_size_output = utils.prod(self.size_output)

        # Kept for backward compatibility: the sizes are also published as
        # module-level globals, which code outside this class may read.
        global flatten_size_input, flatten_size_output
        flatten_size_input = self.flatten_size_input
        flatten_size_output = self.flatten_size_output

        if model is None:
            self.encoder = nn.Sequential(
                nn.Linear(self.flatten_size_input, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, d * 2),
            )
        else:
            self.encoder = model.encoder

        if output_network is None:
            self.output_network = nn.Sequential(
                nn.Linear(d, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, self.flatten_size_output),
                nn.Softmax(dim=1),
            )
        else:
            self.output_network = output_network

    def reparameterise(self, mu, logvar):
        """Return a latent sample in training mode, or ``mu`` otherwise.

        In training mode the result is ``mu + eps * exp(0.5 * logvar)`` with
        ``eps`` drawn from a standard normal of the same shape as ``logvar``.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like matches std's dtype/device and carries no gradient.
        eps = torch.randn_like(std)
        return eps * std + mu

    def encode(self, x):
        """Flatten ``x``, run the encoder and return a ``(-1, 2, d)`` tensor."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(-1, self.flatten_size_input)
        return self.encoder(flat).view(-1, 2, self.d)

    def dense_decoder(self, z):
        """Apply the output head to the latent ``z``."""
        return self.output_network(z)

    def forward(self, x):
        """Return ``(output, mu, logvar)`` for the input batch ``x``."""
        mu_logvar = self.encode(x)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.reparameterise(mu, logvar)
        return self.dense_decoder(z), mu, logvar
|
|
|
|
|
|
def optimizer(model, optim=torch.optim.Adam, learning_rate=1e-3):
    """Build an optimizer over the parameters of ``model.output_network``.

    Only the output head's parameters are handed to the optimizer; no other
    attribute of ``model`` is read.

    Args:
        model: Object exposing an ``output_network`` with ``parameters()``.
        optim: Optimizer constructor, called as ``optim(params, lr=...)``.
        learning_rate: Value passed as ``lr``.

    Returns:
        The optimizer instance created by ``optim``.
    """
    head_params = model.output_network.parameters()
    return optim(head_params, lr=learning_rate)
|
|
|
|
def loss_function(f=nn.functional.binary_cross_entropy, β=1):
    """Return a loss callable combining a data term and a weighted KL term.

    Args:
        f: Data-error function, called as ``f(y_hat, y, reduction='sum')``.
        β: Multiplier applied to the KL term.

    Returns:
        A function ``loss(y_hat, y, mu, logvar)`` that evaluates
        ``f(y_hat, y, reduction='sum') + β * KLD``, where
        ``KLD = 0.5 * sum(exp(logvar) - logvar - 1 + mu ** 2)``.
    """

    def loss(y_hat, y, mu, logvar):
        # Summed (not averaged) data error between prediction and target.
        reconstruction = f(y_hat, y, reduction='sum')
        # Element-wise KL contributions, summed over every entry.
        kl_terms = logvar.exp() - logvar - 1 + mu.pow(2)
        divergence = 0.5 * torch.sum(kl_terms)
        return reconstruction + β * divergence

    return loss