Tabular-GAN-Project-5Y-INSA/create_fake_data.ipynb

134 lines
3.9 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import pandas as pd\n",
"from torch import nn, optim\n",
"from torch.autograd.variable import Variable\n",
"from torchvision import transforms, datasets\n",
"from data_treatment import DataSet, DataAtts\n",
"from discriminator import *\n",
"from generator import *\n",
"import ipywidgets as widgets\n",
"from IPython.display import display\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import os\n",
"from format_data import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Glob pattern matching every saved generator checkpoint under models/diabetes.\n",
"folder_name = \"models/diabetes\"+\"/generator*.pt\"\n",
"# glob.glob() gives no ordering guarantee, so sort the matches: the dropdown\n",
"# list (and whichever entry it pre-selects) is then the same on every run.\n",
"generator_paths = sorted(glob.glob(folder_name))\n",
"model_widget = widgets.Dropdown(\n",
"    options=generator_paths,\n",
"    description='Generator:',\n",
"    disabled=False,\n",
")\n",
"display(model_widget)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The dataset name is what remains after the 7-character 'models/' prefix.\n",
"original_db_name = \"models/diabetes\"[7:]\n",
"# CSV holding the original (real) records for that dataset.\n",
"original_db_path = f\"original_data/{original_db_name}.csv\"\n",
"original_db = pd.read_csv(original_db_path)\n",
"# Row count of the original table.\n",
"original_db_size = len(original_db.index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the checkpoint chosen in the dropdown and rebuild its generator.\n",
"# The device is picked up front (GPU when usable, CPU otherwise) instead of\n",
"# retrying inside a bare `except:`, which would also catch KeyboardInterrupt\n",
"# and retry after any unrelated failure.\n",
"# NOTE: torch.load relies on pickle, so only load checkpoints you trust.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"checkpoint = torch.load(model_widget.value, map_location=device)\n",
"# Override the stored output width with the original dataset's column count.\n",
"checkpoint['model_attributes']['out_features'] = len(original_db.columns)\n",
"generator = GeneratorNet(**checkpoint['model_attributes'])\n",
"generator.load_state_dict(checkpoint['model_state_dict'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate as many synthetic rows as the original dataset has and save them.\n",
"size = original_db_size\n",
"new_data = generator.create_data(size)\n",
"df = pd.DataFrame(new_data, columns=original_db.columns)\n",
"# Shorten the checkpoint file name and append the sample size so the CSV name\n",
"# is easier to read.\n",
"# NOTE(review): the slice assumes '/' is the only path separator and fixed\n",
"# prefix/suffix lengths; glob may return backslash-separated paths on Windows,\n",
"# so confirm the resulting names there.\n",
"name = model_widget.value.split(\"/\")[-1][9:-4] + \"_size-\" + str(size)\n",
"# to_csv does not create missing folders, so make sure the target exists.\n",
"output_dir = \"fake_data/\" + original_db_name\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"df.to_csv(output_dir + \"/\" + name + \".csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Export one synthetic CSV for every generator checkpoint matching folder_name.\n",
"# The load device is chosen once (GPU when usable, CPU otherwise) instead of\n",
"# retrying inside a bare `except:`, which would also catch KeyboardInterrupt\n",
"# and retry after any unrelated failure.\n",
"# NOTE: torch.load relies on pickle, so only load checkpoints you trust.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"# Every export has as many rows as the original dataset.\n",
"size = original_db_size\n",
"# to_csv does not create missing folders, so make sure the target exists.\n",
"output_dir = \"fake_data/\" + original_db_name\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"for file in glob.glob(folder_name):\n",
"    # NOTE(review): the slice assumes '/' is the only path separator and fixed\n",
"    # prefix/suffix lengths; glob may return backslash-separated paths on\n",
"    # Windows, so confirm the resulting names there.\n",
"    name = file.split(\"/\")[-1][9:-4]\n",
"    print(name)\n",
"    checkpoint = torch.load(file, map_location=device)\n",
"    # NOTE(review): unlike the single-model cell, 'out_features' is not\n",
"    # overridden here; confirm the stored attributes already match the data.\n",
"    generator = GeneratorNet(**checkpoint['model_attributes'])\n",
"    generator.load_state_dict(checkpoint['model_state_dict'])\n",
"    new_data = generator.create_data(size)\n",
"    new_data = format_output(new_data)\n",
"    df = pd.DataFrame(new_data, columns=original_db.columns)\n",
"    df = format_output_db(df)\n",
"    name = name + \"_size-\" + str(size)\n",
"    df.to_csv(output_dir + \"/\" + name + \".csv\", index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tabular_gan",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]"
},
"vscode": {
"interpreter": {
"hash": "2f1136a7f15cd1225735fd9261403f7c342baa42a12d30e4630e4cfef11f2512"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}