variational-autoencoder/utils/plot_lib.py

30 lines
890 B
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from matplotlib import pyplot as plt
import numpy as np
import torch
from IPython.display import HTML, display
def set_default(figsize=(10, 10), dpi=100):
    """Install the module's dark plotting theme as the global matplotlib default.

    Args:
        figsize: default figure size, stored via ``plt.rc('figure', figsize=...)``.
        dpi: default figure resolution, stored via ``plt.rc('figure', dpi=...)``.
    """
    plt.style.use(['dark_background', 'bmh'])
    # These rc calls run after the styles are applied, so they override whatever
    # axes/figure background colours the styles chose and force pure black.
    black = 'k'
    plt.rc('axes', facecolor=black)
    plt.rc('figure', facecolor=black, figsize=figsize, dpi=dpi)
def _show_image_row(pics, start, figsize, suptitle=None, count=False):
    """Draw ``pics[start] .. pics[start + 3]`` side by side in one new figure.

    Args:
        pics: indexable stack of 2-D images (here a ``(-1, 28, 28)`` tensor).
        start: index of the first image in the row.
        figsize: size passed to ``plt.figure``.
        suptitle: optional figure-level title, drawn in white.
        count: if True, title each subplot with its absolute image index.
    """
    plt.figure(figsize=figsize)
    if suptitle is not None:
        plt.suptitle(suptitle, color='w', fontsize=16)
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.imshow(pics[start + i])
        plt.axis('off')
        if count:
            plt.title(str(start + i), color='w')


def display_images(in_, out, n=1, label=None, count=False):
    """Plot rows of 28x28 images: optional inputs, then the matching outputs.

    For each of the ``n`` rows, the four images at indices ``4*row .. 4*row + 3``
    are drawn side by side. If ``in_`` is given, a figure with those input
    images (carrying the figure title) is drawn first, then a figure with the
    corresponding images from ``out``.

    Args:
        in_: tensor of input images, or None to show only ``out``. It is
            reshaped to ``(-1, 28, 28)``, so its element count must be a
            multiple of 784 and it must hold at least ``4 * n`` images.
        out: tensor of output/reconstructed images, reshaped the same way.
        n: number of rows (groups of four images) to display.
        label: optional text prefixed to the input figure's title.
        count: if True, title each output image with its index.
    """
    # Detach from autograd and move to CPU once instead of on every row;
    # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors.
    in_pic = None if in_ is None else in_.detach().cpu().reshape(-1, 28, 28)
    out_pic = out.detach().cpu().reshape(-1, 28, 28)

    # ``label`` defaults to None, so only prefix it when one is provided
    # (concatenating None with a str would raise TypeError).
    suffix = 'real test data / reconstructions'
    title = suffix if label is None else label + ' ' + suffix

    for row in range(n):
        start = 4 * row
        if in_pic is not None:
            _show_image_row(in_pic, start, (18, 4), suptitle=title)
        _show_image_row(out_pic, start, (18, 6), count=count)