Import
This commit is contained in:
commit
304d56a1a6
|
@ -0,0 +1,10 @@
|
||||||
|
# Compiled files
|
||||||
|
*/__pycache__/*
|
||||||
|
|
||||||
|
# Data
|
||||||
|
data/*
|
||||||
|
*.png
|
||||||
|
*.wav
|
||||||
|
|
||||||
|
# Editor configuration
|
||||||
|
.vscode/*
|
|
@ -0,0 +1,2 @@
|
||||||
|
# Variational Autoencoder
|
||||||
|
This code was developed as part of a one-month research assistant internship in June 2021.
|
File diff suppressed because one or more lines are too long
Binary file not shown.
|
@ -0,0 +1,66 @@
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import utils
|
||||||
|
|
||||||
|
|
||||||
|
# Defining the model
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
    """Fully-connected variational autoencoder.

    The encoder maps a flattened input to ``2 * d`` values that are read as
    the mean and the log-variance of a ``d``-dimensional code; the decoder
    maps a code back to a flattened input.

    Args:
        d: Size of the latent code.
        size_input: Shape of one input sample, in any form accepted by
            ``utils.prod``. Only its product (the flattened size) is used.
            The default list is never mutated.
        encoder: Optional replacement encoder. ``encode`` reshapes its
            output to ``(-1, 2, d)``, so it must emit ``2 * d`` features.
        decoder: Optional replacement decoder, applied to the sampled code.
            Custom networks are used only when BOTH ``encoder`` and
            ``decoder`` are given; otherwise the default pair is built.
    """

    def __init__(self, d=20, size_input=[28,28], encoder=None, decoder=None):
        super().__init__()

        self.d = d
        self.size_input = size_input
        self.flatten_size_input = utils.prod(self.size_input)

        # Kept for backward compatibility: the flattened size is also
        # published as a module-level global that other code in this module
        # may read. It is overwritten each time a Model is constructed.
        global flatten_size_input
        flatten_size_input = self.flatten_size_input

        # Identity checks: ``== None`` would dispatch to the module's __eq__.
        if encoder is None or decoder is None:
            self.encoder = nn.Sequential(
                nn.Linear(self.flatten_size_input, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, d * 2)
            )

            self.decoder = nn.Sequential(
                nn.Linear(d, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, self.flatten_size_input),
                nn.Sigmoid(),
            )
        else:
            self.encoder, self.decoder = encoder, decoder

    def reparameterise(self, mu, logvar):
        """Sample a code with the reparameterisation trick.

        In training mode returns ``mu + eps * exp(logvar / 2)`` with ``eps``
        drawn from a standard normal; in eval mode returns ``mu`` unchanged.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the deprecated ``std.data.new(...).normal_()``
        # and keeps the noise on the same device / dtype as ``std``.
        eps = torch.randn_like(std)
        return eps * std + mu

    def encode(self,x):
        """Flatten ``x`` and return the encoder output shaped ``(-1, 2, d)``."""
        flat = x.view(-1, self.flatten_size_input)
        return self.encoder(flat).view(-1, 2, self.d)

    def decode(self,z):
        """Apply the decoder to the code ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu_logvar = self.encode(x)
        # Index 0 / 1 of the middle axis hold the mean / log-variance.
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.reparameterise(mu, logvar)
        return self.decode(z), mu, logvar
|
||||||
|
|
||||||
|
|
||||||
|
def optimizer(model, optim=torch.optim.Adam, learning_rate=1e-3):
    """Build an optimiser over all parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        optim: Optimiser class (or factory) called as ``optim(params, lr=...)``.
        learning_rate: Value forwarded as ``lr``.

    Returns:
        The optimiser instance created by ``optim``.
    """
    trainable = model.parameters()
    return optim(trainable, lr=learning_rate)
|
||||||
|
|
||||||
|
def loss_function(f=nn.functional.binary_cross_entropy, β=1):
    """Create the VAE loss: reconstruction error plus β-weighted KL term.

    Args:
        f: Reconstruction criterion, called as
            ``f(x_hat, target, reduction='sum')``.
        β: Weight applied to the KL divergence term.

    Returns:
        A function ``loss(x_hat, x, mu, logvar)`` returning
        ``f(...) + β * KLD`` where
        ``KLD = 0.5 * sum(exp(logvar) - logvar - 1 + mu ** 2)``.
    """
    def loss(x_hat, x, mu, logvar):
        # Shape the target like the reconstruction instead of reading a
        # module-level global that only exists once a Model has been built
        # (and that the most recently built Model silently overwrites).
        target = x.reshape(x_hat.shape)
        data_error = f(x_hat, target, reduction='sum')
        kld = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))
        return data_error + β * kld

    return loss
|
|
@ -0,0 +1,74 @@
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import utils
|
||||||
|
from models import VAE as VAEm
|
||||||
|
|
||||||
|
|
||||||
|
# Defining the model
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
    """Encoder followed by a dense output network applied to the sampled code.

    Args:
        d: Size of the latent code.
        size_input: Shape of one input sample, in any form accepted by
            ``utils.prod``; only its product is used. The default list is
            never mutated.
        size_output: Output size, also reduced with ``utils.prod``.
        model: Optional object whose ``encoder`` attribute is reused as this
            model's encoder. ``encode`` reshapes the encoder output to
            ``(-1, 2, d)``, so it must emit ``2 * d`` features.
        output_network: Optional replacement for the default output network.
    """

    def __init__(self, d=20, size_input=[28,28], size_output=10, model=None, output_network=None):
        super().__init__()

        self.d = d
        self.size_input = size_input
        self.flatten_size_input = utils.prod(self.size_input)
        self.size_output = size_output
        self.flatten_size_output = utils.prod(self.size_output)

        # Kept for backward compatibility: both sizes are also published as
        # module-level globals, overwritten on every construction.
        global flatten_size_input, flatten_size_output
        flatten_size_input = self.flatten_size_input
        flatten_size_output = self.flatten_size_output

        # Identity checks: ``== None`` would dispatch to the module's __eq__.
        if model is None:
            self.encoder = nn.Sequential(
                nn.Linear(self.flatten_size_input, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, d * 2)
            )
        else:
            self.encoder = model.encoder

        if output_network is None:
            self.output_network = nn.Sequential(
                nn.Linear(d, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, d ** 2),
                nn.ReLU(),
                nn.Linear(d ** 2, self.flatten_size_output),
                nn.Softmax(dim=1),
            )
        else:
            self.output_network = output_network

    def reparameterise(self, mu, logvar):
        """Sample a code with the reparameterisation trick.

        In training mode returns ``mu + eps * exp(logvar / 2)`` with ``eps``
        drawn from a standard normal; in eval mode returns ``mu`` unchanged.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the deprecated ``std.data.new(...).normal_()``
        # and keeps the noise on the same device / dtype as ``std``.
        eps = torch.randn_like(std)
        return eps * std + mu

    def encode(self,x):
        """Flatten ``x`` and return the encoder output shaped ``(-1, 2, d)``."""
        flat = x.view(-1, self.flatten_size_input)
        return self.encoder(flat).view(-1, 2, self.d)

    def dense_decoder(self,z):
        """Apply the output network to the code ``z``."""
        return self.output_network(z)

    def forward(self, x):
        """Return ``(output, mu, logvar)`` for the batch ``x``."""
        mu_logvar = self.encode(x)
        # Index 0 / 1 of the middle axis hold the mean / log-variance.
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.reparameterise(mu, logvar)
        return self.dense_decoder(z), mu, logvar
|
||||||
|
|
||||||
|
|
||||||
|
def optimizer(model, optim=torch.optim.Adam, learning_rate=1e-3):
    """Build an optimiser over the parameters of ``model.output_network`` only.

    Args:
        model: Object exposing an ``output_network`` with ``parameters()``.
        optim: Optimiser class (or factory) called as ``optim(params, lr=...)``.
        learning_rate: Value forwarded as ``lr``.

    Returns:
        The optimiser instance created by ``optim``.
    """
    head_params = model.output_network.parameters()
    return optim(head_params, lr=learning_rate)
|
||||||
|
|
||||||
|
def loss_function(f=nn.functional.binary_cross_entropy, β=1):
    """Create a loss combining a prediction error with a β-weighted KL term.

    Args:
        f: Criterion called as ``f(y_hat, y, reduction='sum')``.
        β: Weight applied to the KL divergence term.

    Returns:
        A function ``loss(y_hat, y, mu, logvar)`` returning
        ``f(...) + β * KLD`` where
        ``KLD = 0.5 * sum(exp(logvar) - logvar - 1 + mu ** 2)``.
    """
    def loss(y_hat, y, mu, logvar):
        prediction_term = f(y_hat, y, reduction='sum')
        divergence = logvar.exp() - logvar - 1 + mu.pow(2)
        kl_term = 0.5 * torch.sum(divergence)
        return prediction_term + β * kl_term

    return loss
|
|
@ -0,0 +1,99 @@
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import utils
|
||||||
|
|
||||||
|
class Trainer:
    """Run the train / test loop of a VAE-style model and record its codes.

    The model is called on a batch and must return ``(x_hat, mu, logvar)``.
    Both loaders must yield ``(x, y)`` pairs and expose a sized ``dataset``.

    Args:
        model: Module to train.
        optimizer: Optimiser instance stepping the model parameters.
        loss_function: Callable ``loss(x_hat, x, mu, logvar)`` returning a
            scalar tensor.
        device: Device the input batches are moved to.
        train_loader: Loader used for training.
        test_loader: Loader used for evaluation and for the displayed batch.
        epochs: Number of training epochs (an extra epoch 0 only evaluates).
        print_frequency: Print progress every this many batches.
    """

    # Keys of the record returned by ``_train_epoch`` / ``_test_epoch``.
    _CODE_KEYS = ("μ", "logσ2", "y", "loss")

    def __init__(self, model, optimizer, loss_function, device, train_loader, test_loader, epochs = 10, print_frequency = 10):
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        self.device = device
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.epochs = epochs
        self.print_frequency = print_frequency

    def _train_epoch(self, epoch):
        """Train for one epoch and return its record.

        Returns a dict with the concatenated means (``μ``), log-variances
        (``logσ2``) and labels (``y``) moved to CPU, plus the list of running
        average losses (``loss``). For ``epoch <= 0`` nothing is trained and
        every value is ``None``.
        """
        codes = dict.fromkeys(self._CODE_KEYS)
        if epoch <= 0:  # epoch 0 is reserved for testing the untrained net
            return codes

        self.model.train()
        means, logvars, labels, losses = [], [], [], []
        train_loss = 0
        seen = 0  # samples processed so far; the last batch may be smaller
        n_samples = len(self.train_loader.dataset)
        for batch_idx, (x, y) in enumerate(self.train_loader):
            x = x.to(self.device)
            seen += x.shape[0]
            # ===================forward=====================
            x_hat, mu, logvar = self.model(x)
            loss = self.loss_function(x_hat, x, mu, logvar)
            train_loss += loss.item()
            # ===================backward====================
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # =====================log=======================
            means.append(mu.detach())
            logvars.append(logvar.detach())
            labels.append(y.detach())
            losses.append(train_loss / seen)
            if batch_idx % self.print_frequency == 0:
                # Progress is measured in samples, matching the dataset size.
                print(f'- Epoch: {epoch} [{seen}/{n_samples} ({100. * seen / n_samples:.0f}%)] Average loss: {train_loss / seen:.4f}', end="\r")
        # ===================log========================
        codes['μ'] = torch.cat(means).cpu()
        codes['logσ2'] = torch.cat(logvars).cpu()
        codes['y'] = torch.cat(labels).cpu()
        codes['loss'] = losses
        print(f'====> Epoch: {epoch} Average loss: {train_loss / n_samples:.4f}')
        return codes

    def _test_epoch(self, epoch):
        """Evaluate on the test loader without gradients; return its record.

        The record has the same layout as the one of ``_train_epoch``.
        """
        codes = dict.fromkeys(self._CODE_KEYS)
        with torch.no_grad():
            self.model.eval()
            means, logvars, labels, losses = [], [], [], []
            test_loss = 0
            seen = 0  # samples processed so far; the last batch may be smaller
            for x, y in self.test_loader:
                x = x.to(self.device)
                seen += x.shape[0]
                # ===================forward=====================
                x_hat, mu, logvar = self.model(x)
                test_loss += self.loss_function(x_hat, x, mu, logvar).item()
                # =====================log=======================
                means.append(mu.detach())
                logvars.append(logvar.detach())
                labels.append(y.detach())
                losses.append(test_loss / seen)
        # ===================log========================
        codes['μ'] = torch.cat(means).cpu()
        codes['logσ2'] = torch.cat(logvars).cpu()
        codes['y'] = torch.cat(labels).cpu()
        codes['loss'] = losses
        test_loss /= len(self.test_loader.dataset)
        print(f'====> Test set loss: {test_loss:.4f}')
        return codes

    def train(self):
        """Run epochs ``0..epochs`` and return ``(codes_train, codes_test)``.

        Each returned dict maps a record key to a list holding one entry per
        epoch. Epoch 0 only evaluates, so its training entries are ``None``.
        After every epoch the last test batch and its reconstruction are
        passed to ``utils.display_images``.
        """
        codes_train = {key: [] for key in self._CODE_KEYS}
        codes_test = {key: [] for key in self._CODE_KEYS}

        for epoch in range(self.epochs + 1):
            codes = self._train_epoch(epoch)
            for key in codes_train:
                codes_train[key].append(codes[key])
            codes = self._test_epoch(epoch)
            for key in codes_test:
                codes_test[key].append(codes[key])

            if epoch != self.epochs:
                print("---")

            # Only the last test batch is displayed, so only that batch is
            # forwarded, and without building an autograd graph.
            last_x = None
            for last_x, _ in self.test_loader:
                pass
            if last_x is not None:
                with torch.no_grad():
                    x = last_x.to(self.device)
                    x_hat, _, _ = self.model(x)
                utils.display_images(x, x_hat, 1, f'Epoch {epoch}')
        return codes_train, codes_test
|
|
@ -0,0 +1,104 @@
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import utils
|
||||||
|
|
||||||
|
class Trainer:
    """Run the train / test loop of a classifier built on a VAE encoder.

    The model is called on a batch and must return ``(y_hat, mu, logvar)``.
    Both loaders must yield ``(x, y)`` pairs and expose a sized ``dataset``.
    Labels are converted with ``utils.format_labels_prob`` before use, and a
    prediction counts as correct when the argmax of ``y_hat`` matches the
    argmax of the converted label.

    Args:
        model: Module to train.
        optimizer: Optimiser instance stepping the trained parameters.
        loss_function: Callable ``loss(y_hat, y, mu, logvar)`` returning a
            scalar tensor.
        device: Device the batches are moved to.
        train_loader: Loader used for training.
        test_loader: Loader used for evaluation.
        epochs: Number of training epochs (an extra epoch 0 only evaluates).
        print_frequency: Print progress every this many batches.
    """

    # Keys of the record returned by ``_train_epoch`` / ``_test_epoch``.
    _CODE_KEYS = ("μ", "logσ2", "y", "loss", "AR")

    def __init__(self, model, optimizer, loss_function, device, train_loader, test_loader, epochs = 10, print_frequency = 10):
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        self.device = device
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.epochs = epochs
        self.print_frequency = print_frequency

    def _train_epoch(self, epoch):
        """Train for one epoch and return its record.

        Returns a dict with the concatenated means (``μ``), log-variances
        (``logσ2``) and converted labels (``y``) moved to CPU, plus the lists
        of running average losses (``loss``) and running accuracy rates as
        fractions (``AR``). For ``epoch <= 0`` nothing is trained and every
        value is ``None``.
        """
        codes = dict.fromkeys(self._CODE_KEYS)
        if epoch <= 0:  # epoch 0 is reserved for testing the untrained net
            return codes

        self.model.train()
        means, logvars, labels, losses, accuracy_rate = [], [], [], [], []
        train_loss = 0
        good_pred = 0
        seen = 0  # samples processed so far; the last batch may be smaller
        n_samples = len(self.train_loader.dataset)
        for batch_idx, (x, y) in enumerate(self.train_loader):
            x = x.to(self.device)
            y = utils.format_labels_prob(y)
            y = y.to(self.device)
            seen += x.shape[0]
            # ===================forward=====================
            y_hat, mu, logvar = self.model(x)
            loss = self.loss_function(y_hat, y, mu, logvar)
            train_loss += loss.item()
            good_pred += (torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)).sum().item()
            # ===================backward====================
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # =====================log=======================
            means.append(mu.detach())
            logvars.append(logvar.detach())
            labels.append(y.detach())
            losses.append(train_loss / seen)
            accuracy_rate.append(good_pred / seen)
            if batch_idx % self.print_frequency == 0:
                # Progress is measured in samples; accuracy is a percentage,
                # consistent with the end-of-epoch summary below.
                print(f'- Epoch: {epoch} [{seen}/{n_samples} ({100. * seen / n_samples:.0f}%)] Average loss: {train_loss / seen:.4f}\t Accuracy rate: {100 * good_pred / seen:.0f}%', end="\r")
        # ===================log========================
        codes['μ'] = torch.cat(means).cpu()
        codes['logσ2'] = torch.cat(logvars).cpu()
        codes['y'] = torch.cat(labels).cpu()
        codes['loss'] = losses
        codes['AR'] = accuracy_rate
        print(f'====> Epoch: {epoch} Average loss: {train_loss / n_samples:.4f}\t Accuracy rate: {100 * good_pred / n_samples :.0f}%')
        return codes

    def _test_epoch(self, epoch):
        """Evaluate on the test loader without gradients; return its record.

        The record has the same layout as the one of ``_train_epoch``.
        """
        codes = dict.fromkeys(self._CODE_KEYS)
        with torch.no_grad():
            self.model.eval()
            means, logvars, labels, losses, accuracy_rate = [], [], [], [], []
            test_loss = 0
            good_pred = 0
            seen = 0  # samples processed so far; the last batch may be smaller
            for x, y in self.test_loader:
                x = x.to(self.device)
                y = utils.format_labels_prob(y)
                y = y.to(self.device)
                seen += x.shape[0]
                # ===================forward=====================
                y_hat, mu, logvar = self.model(x)
                test_loss += self.loss_function(y_hat, y, mu, logvar).item()
                good_pred += (torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)).sum().item()
                # =====================log=======================
                means.append(mu.detach())
                logvars.append(logvar.detach())
                labels.append(y.detach())
                losses.append(test_loss / seen)
                accuracy_rate.append(good_pred / seen)
        # ===================log========================
        codes['μ'] = torch.cat(means).cpu()
        codes['logσ2'] = torch.cat(logvars).cpu()
        codes['y'] = torch.cat(labels).cpu()
        codes['loss'] = losses
        codes['AR'] = accuracy_rate
        # Normalise by the TEST set size: the counts come from the test loader.
        n_samples = len(self.test_loader.dataset)
        test_loss /= n_samples
        print(f'====> Test set loss: {test_loss:.4f}\t Accuracy rate: {100. * good_pred / n_samples :.0f}%')
        return codes

    def train(self):
        """Run epochs ``0..epochs`` and return ``(codes_train, codes_test)``.

        Each returned dict maps a record key to a list holding one entry per
        epoch. Epoch 0 only evaluates, so its training entries are ``None``.
        """
        codes_train = {key: [] for key in self._CODE_KEYS}
        codes_test = {key: [] for key in self._CODE_KEYS}
        for epoch in range(self.epochs + 1):
            codes = self._train_epoch(epoch)
            for key in codes_train:
                codes_train[key].append(codes[key])
            codes = self._test_epoch(epoch)
            for key in codes_test:
                codes_test[key].append(codes[key])
            if epoch != self.epochs:
                print("---")
        return codes_train, codes_test
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .util import *
|
||||||
|
from .run import *
|
||||||
|
from .plot_lib import *
|
|
@ -0,0 +1,30 @@
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from IPython.display import HTML, display
|
||||||
|
|
||||||
|
|
||||||
|
def set_default(figsize=(10, 10), dpi=100):
    """Apply the dark plotting defaults to matplotlib.

    Activates the 'dark_background' and 'bmh' styles, sets a black face
    colour on axes and figures, then sets the default figure size and dpi.

    Args:
        figsize: Default figure size forwarded to ``plt.rc('figure', ...)``.
        dpi: Default figure resolution forwarded to ``plt.rc('figure', ...)``.
    """
    plt.style.use(['dark_background', 'bmh'])
    # Same order as before: axes first, then figure.
    for group in ('axes', 'figure'):
        plt.rc(group, facecolor='k')
    plt.rc('figure', figsize=figsize, dpi=dpi)
|
||||||
|
|
||||||
|
|
||||||
|
def display_images(in_, out, n=1, label=None, count=False):
    """Plot rows of four input images and four output images.

    Both tensors are viewed as stacks of 28x28 images. For each of the ``n``
    rows, images ``4 * N`` to ``4 * N + 3`` are shown: first the inputs (when
    ``in_`` is given), then the outputs, each in its own figure.

    Args:
        in_: Input tensor, or ``None`` to show outputs only.
        out: Output tensor.
        n: Number of rows of four images to draw.
        label: Optional prefix of the input figure title. When ``None`` the
            title is shown without a prefix.
        count: When true, title each output image with its index.
    """
    # The default ``label=None`` used to raise TypeError on ``label + ...``.
    prefix = '' if label is None else f'{label} – '
    for N in range(n):
        if in_ is not None:
            in_pic = in_.data.cpu().view(-1, 28, 28)
            plt.figure(figsize=(18, 4))
            plt.suptitle(prefix + 'real test data / reconstructions', color='w', fontsize=16)
            for i in range(4):
                plt.subplot(1, 4, i + 1)
                plt.imshow(in_pic[i + 4 * N])
                plt.axis('off')
        out_pic = out.data.cpu().view(-1, 28, 28)
        plt.figure(figsize=(18, 6))
        for i in range(4):
            plt.subplot(1, 4, i + 1)
            plt.imshow(out_pic[i + 4 * N])
            plt.axis('off')
            if count:
                plt.title(str(4 * N + i), color='w')
|
|
@ -0,0 +1,18 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def prepare_device(n_gpu_use):
    """Pick the torch device and the GPU indices to use.

    The requested GPU count is clamped to the number reported by
    ``torch.cuda.device_count()``; a warning is printed whenever the request
    cannot be honoured.

    Args:
        n_gpu_use: Number of GPUs requested.

    Returns:
        ``(device, list_ids)`` where ``device`` is ``cuda:0`` when at least
        one GPU is used and ``cpu`` otherwise, and ``list_ids`` is
        ``list(range(n))`` for the clamped count ``n``.
    """
    n_gpu = torch.cuda.device_count()

    no_gpu_at_all = n_gpu == 0
    if no_gpu_at_all and n_gpu_use > 0:
        print("Warning: There\'s no GPU available on this machine,"
              "training will be performed on CPU.")
        n_gpu_use = 0

    if n_gpu_use > n_gpu:
        print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are "
              "available on this machine.")
        n_gpu_use = n_gpu

    if n_gpu_use > 0:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    return device, list(range(n_gpu_use))
|
|
@ -0,0 +1,22 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def prod(x):
    """Return the product of the entries of a size description.

    Args:
        x: A number (returned unchanged), a list or tuple of numbers
            (including ``torch.Size``, which is a tuple), or any object with
            a ``shape`` attribute, whose shape entries are multiplied.

    Returns:
        ``x`` itself for a number, otherwise the NumPy product.
    """
    if isinstance(x, (int, float)):
        return x
    # Tuples used to fall through to the ``.shape`` branch and raise
    # AttributeError; they are handled like lists now.
    if isinstance(x, (list, tuple)):
        return np.prod(np.array(x))
    return np.prod(np.array(list(x.shape)))
|
||||||
|
|
||||||
|
def format_labels_prob(y, nbr_dif_labels=10):
    """One-hot encode a tensor of integer labels.

    Args:
        y: Tensor of integer class labels; its first dimension is the number
            of samples.
        nbr_dif_labels: Number of distinct labels, i.e. the row width.

    Returns:
        A float tensor of shape ``(len(y), nbr_dif_labels)`` created with
        ``torch.zeros`` (default device), holding ``1.`` at column ``y[i]``
        of row ``i`` and ``0.`` elsewhere.
    """
    n = y.shape[0]
    y_bis = torch.zeros((n, nbr_dif_labels))
    # One indexed assignment replaces the per-sample Python loop. The labels
    # are brought to CPU and cast to int64 so that CUDA or narrow integer
    # tensors work, and so uint8 labels are not mistaken for a mask.
    columns = y.detach().cpu().long()
    y_bis[torch.arange(n), columns] = 1.
    return y_bis
|
Loading…
Reference in New Issue