134 lines
3.9 KiB
Plaintext
134 lines
3.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"import torch\n",
"import pandas as pd\n",
"from torch import nn, optim\n",
"from torch.autograd.variable import Variable\n",
"from torchvision import transforms, datasets\n",
"from data_treatment import DataSet, DataAtts\n",
"from discriminator import *\n",
"from generator import *\n",
"import ipywidgets as widgets\n",
"from IPython.display import display\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import os\n",
"from format_data import *"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Glob pattern matching every saved generator checkpoint of the chosen dataset.\n",
"folder_name = \"models/diabetes\"+\"/generator*.pt\"\n",
"# glob.glob returns its matches in arbitrary order; sorting keeps the dropdown\n",
"# entries, and therefore its default selection, identical on every run.\n",
"generator_paths = sorted(glob.glob(folder_name))\n",
"# Dropdown used by the following cells to pick which checkpoint to load.\n",
"model_widget = widgets.Dropdown(\n",
"    options=generator_paths,\n",
"    description='Generator:',\n",
"    disabled=False,\n",
")\n",
"display(model_widget)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Dataset name = model folder without its leading \"models/\" (7 characters).\n",
"original_db_name = \"models/diabetes\"[7:]\n",
"# Location of the original CSV for that dataset.\n",
"original_db_path = \"\".join([\"original_data/\", original_db_name, \".csv\"])\n",
"original_db = pd.read_csv(original_db_path)\n",
"# Number of rows in the original data; later cells generate this many samples.\n",
"original_db_size = len(original_db.index)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Pick the device explicitly so that a genuine load failure (missing or corrupt\n",
"# checkpoint, interrupted kernel) surfaces instead of being retried on the CPU.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;\n",
"# only open checkpoints that come from a trusted source.\n",
"checkpoint = torch.load(model_widget.value, map_location=device)\n",
"# Override the saved 'out_features' attribute with the column count of the original data.\n",
"checkpoint['model_attributes']['out_features'] = len(original_db.columns)\n",
"# Rebuild the network from its saved constructor arguments, then restore its weights.\n",
"generator = GeneratorNet(**checkpoint['model_attributes'])\n",
"generator.load_state_dict(checkpoint['model_state_dict'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Generate as many synthetic rows as the original dataset has.\n",
"size = original_db_size\n",
"new_data = generator.create_data(size)\n",
"# NOTE(review): unlike the batch cell below, this path does not run the data\n",
"# through format_output / format_output_db - confirm that this is intentional.\n",
"df = pd.DataFrame(new_data, columns=original_db.columns)\n",
"# Changes the name to be easier to read: drop the 9-character \"generator\" prefix\n",
"# and the \".pt\" suffix. os.path.basename isolates the file name with the\n",
"# platform's own separators (glob returns backslashes on Windows), so the\n",
"# slice always applies to the file name and never to part of the folder.\n",
"name = os.path.basename(model_widget.value)[9:-4] + \"_size-\" + str(size)\n",
"# Create the output folder on demand so to_csv does not fail when it is missing.\n",
"output_dir = \"fake_data/\" + original_db_name\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"df.to_csv(output_dir + \"/\" + name + \".csv\", index=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Do the same thing as the cells above, but for every checkpoint matching folder_name.\n",
"\n",
"# Loop-invariant setup: target device, sample count and output folder.\n",
"# Choosing the device explicitly lets real load errors surface instead of\n",
"# being silently retried on the CPU.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"size = original_db_size\n",
"output_dir = \"fake_data/\" + original_db_name\n",
"# Create the output folder on demand so to_csv does not fail when it is missing.\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"# Sorted so the files are processed (and printed) in the same order on every run.\n",
"for file in sorted(glob.glob(folder_name)):\n",
"    # Drop the 9-character \"generator\" prefix and the \".pt\" suffix. basename\n",
"    # isolates the file name even when glob returns backslash separators (Windows).\n",
"    name = os.path.basename(file)[9:-4]\n",
"    print(name)\n",
"    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.\n",
"    checkpoint = torch.load(file, map_location=device)\n",
"    # NOTE(review): the single-model cell overrides 'out_features' with\n",
"    # len(original_db.columns) before building the network; here the saved\n",
"    # value is used as-is - confirm which behaviour is intended.\n",
"    generator = GeneratorNet(**checkpoint['model_attributes'])\n",
"    generator.load_state_dict(checkpoint['model_state_dict'])\n",
"    new_data = generator.create_data(size)\n",
"    new_data = format_output(new_data)\n",
"    df = pd.DataFrame(new_data, columns=original_db.columns)\n",
"    df = format_output_db(df)\n",
"    name = name + \"_size-\" + str(size)\n",
"    df.to_csv(output_dir + \"/\" + name + \".csv\", index=False)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "tabular_gan",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "2f1136a7f15cd1225735fd9261403f7c342baa42a12d30e4630e4cfef11f2512"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|