Tabular-GAN-Project-5Y-INSA/utils.py

24 lines
600 B
Python
Raw Permalink Normal View History

2023-01-07 06:30:24 +00:00
from torch.autograd.variable import Variable
import torch
def random_noise(size: int, dim: int = 100) -> torch.Tensor:
    """Sample a batch of noise vectors from a standard normal distribution.

    Args:
        size: Number of noise vectors to draw (rows of the result).
        dim: Length of each noise vector. Defaults to 100, the value this
            helper always used before it became a parameter.

    Returns:
        A ``(size, dim)`` float tensor. It is moved to the current CUDA
        device when CUDA is available; otherwise it stays on the CPU.
    """
    # torch.autograd.Variable is deprecated (a no-op wrapper since PyTorch
    # 0.4), so the plain tensor is used directly. Sampling still happens on
    # the CPU generator before the move, preserving the original value stream.
    noise = torch.randn(size, dim)
    if torch.cuda.is_available():
        return noise.cuda()
    return noise
def real_data_target(size: int) -> torch.Tensor:
    """Build the all-ones label tensor used as the target for real data.

    Args:
        size: Number of labels (rows of the result).

    Returns:
        A ``(size, 1)`` float tensor filled with ones. It is moved to the
        current CUDA device when CUDA is available; otherwise it stays on
        the CPU.
    """
    # Plain tensor instead of the deprecated torch.autograd.Variable wrapper.
    data = torch.ones(size, 1)
    if torch.cuda.is_available():
        return data.cuda()
    return data
def fake_data_target(size: int) -> torch.Tensor:
    """Build the all-zeros label tensor used as the target for fake data.

    Args:
        size: Number of labels (rows of the result).

    Returns:
        A ``(size, 1)`` float tensor filled with zeros. It is moved to the
        current CUDA device when CUDA is available; otherwise it stays on
        the CPU.
    """
    # Plain tensor instead of the deprecated torch.autograd.Variable wrapper.
    data = torch.zeros(size, 1)
    if torch.cuda.is_available():
        return data.cuda()
    return data