variational-autoencoder/utils/util.py

22 lines
427 B
Python
Raw Permalink Normal View History

2023-08-21 15:13:45 +00:00
import numpy as np
import torch
from matplotlib import pyplot as plt
def prod(x):
    """Return the product of the dimensions described by ``x``.

    ``x`` may be:

    * a list or tuple of numbers (including ``torch.Size``, a tuple
      subclass) -> the product of its elements, computed with ``np.prod``;
    * an ``int`` or ``float`` -> returned unchanged;
    * any other object exposing ``.shape`` (e.g. a tensor or ndarray) ->
      the product of its shape, i.e. its number of elements.

    An empty list/tuple, or an object whose shape is empty, yields ``1.0``
    (``np.prod`` of an empty float array).
    """
    # isinstance rather than exact type() comparisons: tuples such as
    # ``tensor.shape`` previously fell through to ``x.shape`` and raised
    # AttributeError.
    if isinstance(x, (list, tuple)):
        return np.prod(np.array(x))
    if isinstance(x, (int, float)):
        return x
    return np.prod(np.array(list(x.shape)))
def format_labels_prob(y, nbr_dif_labels=10):
    """One-hot encode a batch of class labels.

    Args:
        y: Tensor (or anything ``torch.as_tensor`` accepts) whose first
            dimension is the batch size ``n``. Row ``i`` holds the class
            index of sample ``i``. A 1-D tensor of shape ``(n,)`` and a
            column tensor of shape ``(n, 1)`` both give standard one-hot
            rows; if a row holds several indices, every listed class is set
            (multi-hot). Values are cast to ``int64``.
        nbr_dif_labels: Number of classes, i.e. the width of the output.

    Returns:
        A ``(n, nbr_dif_labels)`` tensor of torch's default floating dtype,
        on ``y``'s device, with ``1.`` at each labelled class and ``0.``
        elsewhere.

    Raises:
        IndexError: if ``y`` has no first dimension, or if any label is
            negative or ``>= nbr_dif_labels``.
    """
    y = torch.as_tensor(y)
    n = y.shape[0]
    y_bis = torch.zeros((n, nbr_dif_labels), device=y.device)
    # Nothing to set (also avoids reshape(n, -1) failing on 0 elements).
    if y.numel() == 0:
        return y_bis
    # (n, k) index matrix: row i lists the classes to switch on for sample i.
    labels = y.reshape(n, -1).long()
    lo, hi = int(labels.min()), int(labels.max())
    # Reject instead of letting negative indices silently wrap to the last
    # classes (the old per-element indexing did that).
    if lo < 0 or hi >= nbr_dif_labels:
        raise IndexError(
            f"labels must lie in [0, {nbr_dif_labels}), got values in [{lo}, {hi}]"
        )
    y_bis.scatter_(1, labels, 1.0)
    return y_bis