30 lines
890 B
Python
30 lines
890 B
Python
from matplotlib import pyplot as plt
|
||
import numpy as np
|
||
import torch
|
||
from IPython.display import HTML, display
|
||
|
||
|
||
def set_default(figsize=(10, 10), dpi=100):
    """Apply the dark plotting theme as matplotlib's global defaults.

    Stacks the ``dark_background`` and ``bmh`` styles, then forces a black
    face colour on axes and figures and sets the default figure geometry.

    Args:
        figsize: default ``(width, height)`` for new figures, in inches.
        dpi: default resolution (dots per inch) for new figures.
    """
    themes = ['dark_background', 'bmh']
    plt.style.use(themes)
    black = 'k'
    plt.rc('axes', facecolor=black)
    # Face colour first, then size and resolution.
    plt.rc('figure', facecolor=black, figsize=figsize, dpi=dpi)
|
||
|
||
|
||
def _as_image_stack(t):
    """Return ``t`` detached from autograd, on the CPU, reshaped to ``(-1, 28, 28)``."""
    # .detach() is the supported replacement for the legacy .data accessor, and
    # .reshape() (unlike .view()) also accepts non-contiguous tensors.
    return t.detach().cpu().reshape(-1, 28, 28)


def _plot_row(pics, start, figsize, title=None, count=False):
    """Open a new figure and draw ``pics[start]`` .. ``pics[start + 3]`` side by side."""
    plt.figure(figsize=figsize)
    if title is not None:
        plt.suptitle(title, color='w', fontsize=16)
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.imshow(pics[start + i])
        plt.axis('off')
        if count:
            # Label each image with its index in the flattened image stack.
            plt.title(str(start + i), color='w')


def display_images(in_, out, n=1, label=None, count=False):
    """Plot ``n`` rows of four 28x28 images, inputs (optional) above outputs.

    For row ``N`` (0-based), images ``4*N`` .. ``4*N + 3`` are shown. When
    ``in_`` is given, they first appear in a titled figure, followed by the
    corresponding images of ``out`` in a second figure.

    Args:
        in_: tensor whose elements can be reshaped to ``(-1, 28, 28)``, or
            ``None`` to plot only ``out``.
        out: tensor whose elements can be reshaped to ``(-1, 28, 28)``.
        n: number of rows to plot. Each tensor must hold at least ``4 * n``
            images, otherwise indexing raises ``IndexError``.
        label: prefix for the title of the input figures. When ``None`` the
            title is shown without a prefix (previously this raised
            ``TypeError``).
        count: if true, title every output image with its index.
    """
    suffix = 'real test data / reconstructions'
    title = suffix if label is None else label + ' – ' + suffix
    # Convert once instead of on every loop iteration.
    in_pics = None if in_ is None else _as_image_stack(in_)
    out_pics = _as_image_stack(out)
    for N in range(n):
        start = 4 * N
        if in_pics is not None:
            _plot_row(in_pics, start, (18, 4), title=title)
        _plot_row(out_pics, start, (18, 6), count=count)