import torch
import torchvision
from torch import nn

import utils
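
# NOTE: `utils` is a project-local helper module (not shown here). The only
# function used below, utils.format_labels_prob(y), is assumed to turn a
# batch of integer class labels into per-class probability vectors (e.g.
# one-hot), so that labels and predictions can both be compared via argmax.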

class Trainer:
    def __init__(self, model, optimizer, loss_function, device, train_loader, test_loader, epochs=10, print_frequency=10):
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        self.device = device
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.epochs = epochs
        self.print_frequency = print_frequency

    def _train_epoch(self, epoch):
        # Training. train() only calls this method for epoch > 0, so the
        # untrained network is evaluated (but never trained) at epoch 0;
        # the redundant `if epoch > 0` guard that used to live here is gone.
        # `codes` collects the latent statistics (μ, log σ²), the labels, and
        # the running loss/accuracy curves for this epoch.
        codes = dict.fromkeys(["μ", "logσ2", "y", "loss", "AR"])
        self.model.train()
        means, logvars, labels, losses, accuracy_rate = list(), list(), list(), list(), list()
        train_loss = 0
        good_pred = 0
        for batch_idx, (x, y) in enumerate(self.train_loader):
            # Batch size is read from the first batch; the last batch may be smaller.
            size_batch = x.size(0) if batch_idx == 0 else size_batch
            x = x.to(self.device)
            y = utils.format_labels_prob(y)
            y = y.to(self.device)
            # ===================forward=====================
            y_hat, mu, logvar = self.model(x)
            loss = self.loss_function(y_hat, y, mu, logvar)
            train_loss += loss.item()
            good_pred += (torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)).sum().item()
            # ===================backward====================
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # =====================log=======================
            means.append(mu.detach())
            logvars.append(logvar.detach())
            labels.append(y.detach())
            losses.append(train_loss / ((batch_idx + 1) * size_batch))
            accuracy_rate.append(good_pred / ((batch_idx + 1) * size_batch))
            if batch_idx % self.print_frequency == 0:
                # Progress is samples seen over dataset size; the accuracy rate
                # needs the factor of 100 to actually print as a percentage.
                print(f'- Epoch: {epoch} [{batch_idx * size_batch}/{len(self.train_loader.dataset)} ({100. * batch_idx / len(self.train_loader):.0f}%)] Average loss: {train_loss / ((batch_idx + 1) * size_batch):.4f}\t Accuracy rate: {100. * good_pred / ((batch_idx + 1) * size_batch):.0f}%', end="\r")
        # ===================log========================
        codes['μ'] = torch.cat(means).cpu()
        codes['logσ2'] = torch.cat(logvars).cpu()
        codes['y'] = torch.cat(labels).cpu()
        codes['loss'] = losses
        codes['AR'] = accuracy_rate
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(self.train_loader.dataset):.4f}\t Accuracy rate: {100. * good_pred / len(self.train_loader.dataset):.0f}%')
        return codes

    def _test_epoch(self, epoch):
        # Testing
        codes = dict.fromkeys(["μ", "logσ2", "y", "loss", "AR"])
        with torch.no_grad():
            self.model.eval()
            means, logvars, labels, losses, accuracy_rate = list(), list(), list(), list(), list()
            test_loss = 0
            good_pred = 0
            for batch_idx, (x, y) in enumerate(self.test_loader):
                size_batch = x.size(0) if batch_idx == 0 else size_batch
                x = x.to(self.device)
                y = utils.format_labels_prob(y)
                y = y.to(self.device)
                # ===================forward=====================
                y_hat, mu, logvar = self.model(x)
                test_loss += self.loss_function(y_hat, y, mu, logvar).item()
                good_pred += (torch.argmax(y_hat, dim=1) == torch.argmax(y, dim=1)).sum().item()
                # =====================log=======================
                means.append(mu.detach())
                logvars.append(logvar.detach())
                labels.append(y.detach())
                losses.append(test_loss / ((batch_idx + 1) * size_batch))
                accuracy_rate.append(good_pred / ((batch_idx + 1) * size_batch))
            # ===================log========================
            codes['μ'] = torch.cat(means).cpu()
            codes['logσ2'] = torch.cat(logvars).cpu()
            codes['y'] = torch.cat(labels).cpu()
            codes['loss'] = losses
            codes['AR'] = accuracy_rate
            test_loss /= len(self.test_loader.dataset)
            # Accuracy is over the test set (this previously divided by the
            # training set size) and needs the factor of 100 to print as a percentage.
            print(f'====> Test set loss: {test_loss:.4f}\t Accuracy rate: {100. * good_pred / len(self.test_loader.dataset):.0f}%')
        return codes

    def train(self):
        codes_train = dict(μ=list(), logσ2=list(), y=list(), loss=list(), AR=list())
        codes_test = dict(μ=list(), logσ2=list(), y=list(), loss=list(), AR=list())
        for epoch in range(self.epochs + 1):
            # Epoch 0 skips training so the untrained network is logged first.
            codes = self._train_epoch(epoch) if epoch > 0 else dict.fromkeys(["μ", "logσ2", "y", "loss", "AR"])
            for key in codes_train:
                codes_train[key].append(codes[key])
            codes = self._test_epoch(epoch)
            for key in codes_test:
                codes_test[key].append(codes[key])
            if epoch != self.epochs:
                print("---")
        return codes_train, codes_test
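

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the trainer itself). The model is
# assumed to return (y_hat, mu, logvar) and the loss to accept
# (y_hat, y, mu, logvar); SmallVAEClassifier and vae_classifier_loss below
# are hypothetical stand-ins written to match those signatures, not the
# actual model and loss this trainer was built for. MNIST is used only as a
# convenient dataset; utils.format_labels_prob is still required at runtime.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    class SmallVAEClassifier(nn.Module):
        """Toy encoder with a Gaussian latent code and a linear classifier head."""

        def __init__(self, d=20, n_classes=10):
            super().__init__()
            self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 400), nn.ReLU())
            self.mu_head = nn.Linear(400, d)
            self.logvar_head = nn.Linear(400, d)
            self.classifier = nn.Linear(d, n_classes)

        def forward(self, x):
            h = self.encoder(x)
            mu, logvar = self.mu_head(h), self.logvar_head(h)
            z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterisation trick
            return self.classifier(z), mu, logvar

    def vae_classifier_loss(y_hat, y, mu, logvar):
        # Cross-entropy against the probability-formatted labels, plus the
        # KL divergence of N(mu, sigma^2) from the standard normal prior.
        ce = -(y * F.log_softmax(y_hat, dim=1)).sum()
        kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
        return ce + kld

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tfm = transforms.ToTensor()
    train_loader = DataLoader(datasets.MNIST("data", train=True, download=True, transform=tfm),
                              batch_size=256, shuffle=True)
    test_loader = DataLoader(datasets.MNIST("data", train=False, transform=tfm), batch_size=256)

    model = SmallVAEClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trainer = Trainer(model, optimizer, vae_classifier_loss, device, train_loader, test_loader, epochs=10)
    codes_train, codes_test = trainer.train()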