{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate synthetic data from saved generator checkpoints\n",
    "\n",
    "Loads `GeneratorNet` checkpoints from `models/diabetes`, samples as many rows as the original dataset has, and writes each sample as a CSV file under `fake_data/diabetes/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "from torch import nn, optim\n",
    "from torch.autograd.variable import Variable\n",
    "from torchvision import transforms, datasets\n",
    "from data_treatment import DataSet, DataAtts\n",
    "from discriminator import *\n",
    "from generator import *\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "import matplotlib.pyplot as plt\n",
    "import glob\n",
    "from format_data import *\n",
    "# Imported after the wildcard imports so that `os` always refers to the standard-library module.\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration and checkpoint selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the saved generator checkpoints. Its last path component is\n",
    "# also used as the dataset name for the input CSV and for the output folder.\n",
    "MODEL_DIR = \"models/diabetes\"\n",
    "folder_name = MODEL_DIR + \"/generator*.pt\"\n",
    "model_widget = widgets.Dropdown(\n",
    "    # Sorted so the default selection does not depend on directory listing order.\n",
    "    options=sorted(glob.glob(folder_name)),\n",
    "    description='Generator:',\n",
    "    disabled=False,\n",
    ")\n",
    "display(model_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the original dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_db_name = os.path.basename(MODEL_DIR)\n",
    "original_db_path = \"original_data/\" + original_db_name + \".csv\"\n",
    "original_db = pd.read_csv(original_db_path)\n",
    "original_db_size = original_db.shape[0]\n",
    "\n",
    "# DataFrame.to_csv does not create missing directories, so make sure the target exists.\n",
    "output_dir = \"fake_data/\" + original_db_name\n",
    "os.makedirs(output_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_checkpoint(checkpoint_path):\n",
    "    \"\"\"Load a saved checkpoint onto the GPU when one is available, otherwise onto the CPU.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    checkpoint_path : str\n",
    "        Path of the checkpoint file to read with ``torch.load``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The object stored in the file. The cells below index it with\n",
    "    'model_attributes' and 'model_state_dict'.\n",
    "    \"\"\"\n",
    "    # Pick the device explicitly instead of catching every exception from a failed CUDA load,\n",
    "    # so unrelated errors (missing or corrupt file) surface immediately.\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    return torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "\n",
    "def checkpoint_label(checkpoint_path):\n",
    "    \"\"\"Return the short label used in output file names for a checkpoint path.\n",
    "\n",
    "    Takes the text after the last '/' and drops its first 9 and last 4 characters.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the fixed offsets are kept as they were so existing output names do not change.\n",
    "    # On Windows, glob may join the file name with a backslash, in which case the text after the\n",
    "    # last '/' still contains the folder name - TODO confirm which label is intended.\n",
    "    return checkpoint_path.split(\"/\")[-1][9:-4]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data from the selected checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dropdown has no value when the glob pattern matched nothing.\n",
    "if model_widget.value is None:\n",
    "    raise FileNotFoundError(\"No generator checkpoint matches \" + folder_name)\n",
    "checkpoint = load_checkpoint(model_widget.value)\n",
    "# Override out_features with the number of columns in the original dataset.\n",
    "checkpoint['model_attributes']['out_features'] = len(original_db.columns)\n",
    "generator = GeneratorNet(**checkpoint['model_attributes'])\n",
    "generator.load_state_dict(checkpoint['model_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample as many rows as the original dataset contains.\n",
    "size = original_db_size\n",
    "new_data = generator.create_data(size)\n",
    "df = pd.DataFrame(new_data, columns=original_db.columns)\n",
    "# Changes the name to be easier to read\n",
    "name = checkpoint_label(model_widget.value) + \"_size-\" + str(size)\n",
    "df.to_csv(output_dir + \"/\" + name + \".csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data from every checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do the same thing as the cells above but for all the files in the directory.\n",
    "# Differences from the single-checkpoint cells: out_features is not overridden, and the\n",
    "# sample goes through format_output / format_output_db before it is written.\n",
    "for file in sorted(glob.glob(folder_name)):\n",
    "    name = checkpoint_label(file)\n",
    "    print(name)\n",
    "    checkpoint = load_checkpoint(file)\n",
    "    generator = GeneratorNet(**checkpoint['model_attributes'])\n",
    "    generator.load_state_dict(checkpoint['model_state_dict'])\n",
    "    size = original_db_size\n",
    "    new_data = generator.create_data(size)\n",
    "    new_data = format_output(new_data)\n",
    "    df = pd.DataFrame(new_data, columns=original_db.columns)\n",
    "    df = format_output_db(df)\n",
    "    name = name + \"_size-\" + str(size)\n",
    "    df.to_csv(output_dir + \"/\" + name + \".csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tabular_gan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "2f1136a7f15cd1225735fd9261403f7c342baa42a12d30e4630e4cfef11f2512"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}