131 lines
4.3 KiB
Plaintext
131 lines
4.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Results Tables\n",
|
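"Proportion: percentage of training rows with outcome=0 / outcome=1. Test Error: mean squared difference between the predicted and true class on the held-out tail of the original data."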
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"import glob\n",
"import os\n",
"\n",
"import ipywidgets as widgets\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pydotplus # Decision tree plotting\n",
"from IPython.display import Image\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.tree import DecisionTreeClassifier as DT\n",
"from sklearn.tree import export_graphviz # Decision tree from sklearn\n",
"\n",
"from data_treatment import DataAtts\n",
"\n",
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"file_name = 'original_data/diabetes.csv'\n",
"dataAtts = DataAtts(file_name)\n",
"data = pd.read_csv(file_name)\n",
"# Dataset name: the file name without its directory or extension ('diabetes').\n",
"folder_name = os.path.splitext(os.path.basename(file_name))[0]\n",
"\n",
"# Creates the training set: the first 70% of the original rows are the\n",
"# baseline training set and the last 30% are the test set.\n",
"training_data = [[\"original\", data.head(int(data.shape[0]*0.7))]]\n",
"test = data.tail(int(data.shape[0]*0.3))\n",
"\n",
"# Build the pattern with os.path.join so it works on any OS, and sort the\n",
"# matches because glob returns them in arbitrary order.\n",
"for file in sorted(glob.glob(os.path.join(\"fake_data\", folder_name, \"*.csv\"))):\n",
"    # Each generated set is labelled by the first character of its file name.\n",
"    name = \"fake\" + os.path.basename(file)[0]\n",
"    fake_data = pd.read_csv(file)\n",
"    # Binarise the class column at 0.5 (the fake CSVs presumably hold\n",
"    # continuous scores - TODO confirm). Bracket indexing is used instead of\n",
"    # getattr so a column name that clashes with a DataFrame attribute still works.\n",
"    fake_data.loc[fake_data[dataAtts.class_name] >= 0.5, dataAtts.class_name] = 1\n",
"    fake_data.loc[fake_data[dataAtts.class_name] < 0.5, dataAtts.class_name] = 0\n",
"    fake_training = fake_data.head(int(fake_data.shape[0]*0.7))\n",
"    training_data.append([name, fake_training])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def class_share(labels, value):\n",
"    \"\"\"Return the share of `labels` equal to `value`, in percent, as a string.\n",
"\n",
"    A class that never occurs is reported as the literal '0'.\n",
"    \"\"\"\n",
"    counts = labels.value_counts()\n",
"    if value not in counts.index:\n",
"        return \"0\"\n",
"    return str(round(counts[value] / len(labels) * 100, 2))\n",
"\n",
"\n",
"# The test rows are the same for every training set, so split them once.\n",
"# `columns=` replaces the positional axis argument, which pandas 2.0 removed.\n",
"testX = test.drop(columns=dataAtts.class_name)\n",
"y_test = test[dataAtts.class_name]\n",
"\n",
"print(\"| Database \\t| Proportion \\t| Test Error \\t|\")\n",
"print(\"| ---------\\t| ---------: \\t| :--------- \\t|\")\n",
"\n",
"for name, train in training_data:\n",
"    # Percentage of class 0 and class 1 rows in this training set.\n",
"    share_0 = class_share(train[dataAtts.class_name], 0)\n",
"    share_1 = class_share(train[dataAtts.class_name], 1)\n",
"\n",
"    trainX = train.drop(columns=dataAtts.class_name)\n",
"    y_train = train[dataAtts.class_name]\n",
"\n",
"    # Fixed seed so that repeated runs build the same tree.\n",
"    clf1 = DT(max_depth=3, min_samples_leaf=1, random_state=0)\n",
"    clf1 = clf1.fit(trainX, y_train)\n",
"    # The .dot file is overwritten on every iteration; only the last tree is kept.\n",
"    export_graphviz(clf1, out_file=\"models/tree.dot\", feature_names=trainX.columns, class_names=[\"0\",\"1\"], filled=True, rounded=True)\n",
"    g = pydotplus.graph_from_dot_file(path=\"models/tree.dot\")\n",
"\n",
"    # predict() returns class labels directly, including the case where the\n",
"    # training set contained a single class.\n",
"    pred = clf1.predict(testX)\n",
"\n",
"    # With 0/1 labels the mean squared error equals the misclassification rate.\n",
"    mse = round(((pred - y_test.values)**2).mean(axis=0), 4)\n",
"\n",
"    print(\"| \" + name + \" \\t| \" + share_0 + \"/\" + share_1 + \" \\t| \" + str(mse) + \" \\t|\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "tabular_gan",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "2f1136a7f15cd1225735fd9261403f7c342baa42a12d30e4630e4cfef11f2512"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|