# Tabular-GAN-Project-5Y-INSA/data_treatment.py
#
# 101 lines
# 3.9 KiB
# Python

import torch
import pandas as pd
import os
from sklearn import preprocessing
from sklearn.utils import shuffle
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets, utils
from torch.utils.data import Dataset, DataLoader
class ToTensor(object):
    """Convert a pandas object or a NumPy ndarray into a torch Tensor."""

    def __call__(self, sample):
        """Return the data held by ``sample`` as a tensor.

        Args:
            sample: Either an object exposing its ndarray through a
                ``.values`` attribute (e.g. a row obtained with
                ``DataFrame.iloc[idx]``) or a raw NumPy ndarray.

        Returns:
            The tensor produced by ``torch.from_numpy``.
        """
        # pandas objects keep their ndarray in ``.values``; anything without
        # that attribute is handed to torch unchanged, so plain ndarrays work.
        array = getattr(sample, "values", sample)
        return torch.from_numpy(array)
class DataSet(Dataset):
    """Tabular dataset loaded from a CSV file and split into train/test rows.

    The first ``training_porcentage`` fraction of the rows is served through
    ``len()`` and indexing; the remaining rows are kept in ``test_data``.
    """

    def __init__(
        self,
        csv_file,
        root_dir,
        transform=transforms.Compose([ToTensor()]),
        training_porcentage=0.7,
        shuffle_db=False,
    ):
        """Read ``csv_file`` and split its rows.

        Args:
            csv_file (string): Path to the CSV file to load.
            root_dir (string): Stored as ``self.root_dir``; this class does
                not use it for anything else.
            transform (callable, optional): Applied to every row returned by
                ``__getitem__``. A falsy value returns the raw pandas row.
            training_porcentage (float): Fraction of the rows kept as
                training data (``self.data``).
            shuffle_db (bool): Shuffle the rows before splitting.
        """
        self.file = pd.read_csv(csv_file)
        # Test the flag, not the imported ``shuffle`` function: a function
        # object is always truthy, which made the shuffle unconditional.
        if shuffle_db:
            self.file = shuffle(self.file)

        # Split on a single boundary so the two parts always partition the
        # file; computing the tail size from ``1 - training_porcentage`` can
        # lose a row to floating-point rounding.
        split_at = int(self.file.shape[0] * training_porcentage)
        self.data = self.file.iloc[:split_at]
        self.test_data = self.file.iloc[split_at:]

        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of training rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return training row ``idx`` (positional), transformed if requested."""
        item = self.data.iloc[idx]
        if self.transform:
            item = self.transform(item)
        return item

    def get_columns(self):
        """Return the column labels of the training data."""
        return self.data.columns
class DataAtts():
    """Descriptive attributes of the datasets known to this project.

    After construction an instance exposes:

    * ``message``      -- human-readable dataset title
    * ``class_name``   -- name of the label column
    * ``values_names`` -- mapping from label value to display name
    * ``class_len``    -- presumably the number of columns in the CSV
      (TODO confirm against the callers)
    * ``fname``        -- file name without directory and extension
    """

    # Relative CSV path -> attributes assigned to the instance.
    _DATASETS = {
        "original_data/data.csv": {
            "message": "Breast Cancer Wisconsin (Diagnostic) Data Set",
            "class_name": "diagnosis",
            "values_names": {0: "Benign", 1: "Malignant"},
            "class_len": 32,
            "fname": "data",
        },
        "original_data/creditcard.csv": {
            "message": "Credit Card Fraud Detection",
            "class_name": "Class",
            "values_names": {0: "No Frauds", 1: "Frauds"},
            "class_len": 31,
            "fname": "creditcard",
        },
        "original_data/diabetes.csv": {
            "message": "Pima Indians Diabetes Database",
            "class_name": "Outcome",
            "values_names": {0: "Normal", 1: "Diabets"},
            "class_len": 9,
            "fname": "diabetes",
        },
        "original_data/data_escalonated.csv": {
            "message": "Breast Cancer Wisconsin (Diagnostic) Data Set eSCALONATED",
            "class_name": "diagnosis",
            "values_names": {0: "Benign", 1: "Malignant"},
            "class_len": 32,
            "fname": "data_escalonated",
        },
        "original_data/creditcard_escalonated.csv": {
            "message": "Credit Card Fraud Detection eSCALONATED",
            "class_name": "Class",
            "values_names": {0: "No Frauds", 1: "Frauds"},
            "class_len": 31,
            "fname": "creditcard_escalonated",
        },
        "original_data/diabetes_escalonated.csv": {
            "message": "Pima Indians Diabetes Database eSCALONATED",
            "class_name": "Outcome",
            "values_names": {0: "Normal", 1: "Diabets"},
            "class_len": 9,
            "fname": "diabetes_escalonated",
        },
        "original_data/creditcard_1s_escalonated.csv": {
            "message": "Credit Card Fraud Detection eSCALONATED",
            "class_name": "Class",
            "values_names": {0: "No Frauds", 1: "Frauds"},
            "class_len": 31,
            "fname": "creditcard_1s_escalonated",
        },
    }

    def __init__(self, file_name):
        """Populate the attributes for ``file_name``.

        Args:
            file_name: One of the relative CSV paths listed in ``_DATASETS``.

        Raises:
            SystemExit: With exit code 1, after printing a message, when
                ``file_name`` is not a known dataset path.
        """
        try:
            spec = self._DATASETS[file_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which the former
            # equality chain also treated as "unknown".
            print("File not found, exiting")
            # Raise SystemExit directly: the ``exit`` helper is injected by
            # the ``site`` module for interactive use and may be missing.
            raise SystemExit(1) from None

        self.message = spec["message"]
        self.class_name = spec["class_name"]
        # Copy so that instances never share (and cannot corrupt) the table.
        self.values_names = dict(spec["values_names"])
        self.class_len = spec["class_len"]
        self.fname = spec["fname"]