variational-autoencoder/trainer/VAE_with_dense_decoder.py

105 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchvision
from torch import nn
import utils
class Trainer:
def __init__(self, model, optimizer, loss_function, device, train_loader, test_loader, epochs = 10, print_frequency = 10):
self.model = model
self.optimizer = optimizer
self.loss_function = loss_function
self.device = device
self.train_loader = train_loader
self.test_loader = test_loader
self.epochs = epochs
self.print_frequency = print_frequency
def _train_epoch(self, epoch):
# Training
if epoch > 0: # test untrained net first
codes = dict.fromkeys(["μ", "logσ2", "y", "loss", "AR"])
self.model.train()
means, logvars, labels, losses, accuracy_rate = list(), list(), list(), list(), list()
train_loss = 0
good_pred = 0
for batch_idx, (x, y) in enumerate(self.train_loader):
size_batch = list(x.shape)[0] if batch_idx==0 else size_batch
x = x.to(self.device)
y = utils.format_labels_prob(y)
y = y.to(self.device)
# ===================forward=====================
y_hat, mu, logvar = self.model(x)
loss = self.loss_function(y_hat, y, mu, logvar)
train_loss += loss.item()
good_pred += (torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)).sum().item()
# ===================backward====================
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# =====================log=======================
means.append(mu.detach())
logvars.append(logvar.detach())
labels.append(y.detach())
losses.append(train_loss / ((batch_idx + 1) * size_batch))
accuracy_rate.append(good_pred / ((batch_idx + 1) * size_batch))
if batch_idx % self.print_frequency == 0:
print(f'- Epoch: {epoch} [{batch_idx}/{len(self.train_loader.dataset)} ({100. * batch_idx / len(self.train_loader.dataset) :.0f}%)] Average loss: {train_loss / ((batch_idx + 1) * size_batch):.4f}\t Accuracy rate: {good_pred / ((batch_idx + 1) * size_batch) :.0f}%', end="\r")
# ===================log========================
codes['μ'] = torch.cat(means).cpu()
codes['logσ2'] = torch.cat(logvars).cpu()
codes['y'] = torch.cat(labels).cpu()
codes['loss'] = losses
codes['AR'] = accuracy_rate
print(f'====> Epoch: {epoch} Average loss: {train_loss / len(self.train_loader.dataset):.4f}\t Accuracy rate: {100 * good_pred / len(self.train_loader.dataset) :.0f}%')
return codes
def _test_epoch(self, epoch):
# Testing
codes = dict.fromkeys(["μ", "logσ2", "y", "loss", "AR"])
with torch.no_grad():
self.model.eval()
means, logvars, labels, losses, accuracy_rate = list(), list(), list(), list(), list()
test_loss = 0
good_pred = 0
for batch_idx, (x, y) in enumerate(self.test_loader):
size_batch = list(x.shape)[0] if batch_idx==0 else size_batch
x = x.to(self.device)
y = utils.format_labels_prob(y)
y = y.to(self.device)
# ===================forward=====================
y_hat, mu, logvar = self.model(x)
test_loss += self.loss_function(y_hat, y, mu, logvar).item()
good_pred += (torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)).sum().item()
# =====================log=======================
means.append(mu.detach())
logvars.append(logvar.detach())
labels.append(y.detach())
losses.append(test_loss / ((batch_idx + 1) * size_batch))
accuracy_rate.append(good_pred / ((batch_idx + 1) * size_batch))
# ===================log========================
codes['μ'] = torch.cat(means).cpu()
codes['logσ2'] = torch.cat(logvars).cpu()
codes['y'] = torch.cat(labels).cpu()
codes['loss'] = losses
codes['AR'] = accuracy_rate
test_loss /= len(self.test_loader.dataset)
print(f'====> Test set loss: {test_loss:.4f}\t Accuracy rate: {100. * good_pred / len(self.train_loader.dataset) :.0f}%')
return codes
def train(self):
codes_train = dict(μ=list(), logσ2=list(), y=list(), loss=list(), AR=list())
codes_test = dict(μ=list(), logσ2=list(), y=list(), loss=list(), AR=list())
for epoch in range(self.epochs + 1):
codes = self._train_epoch(epoch) if epoch>0 else dict.fromkeys(["μ", "logσ2", "y", "loss", "AR"])
for key in codes_train:
codes_train[key].append(codes[key])
codes = self._test_epoch(epoch)
for key in codes_test:
codes_test[key].append(codes[key])
if epoch != self.epochs:
print("---")
return codes_train, codes_test