Skip to content

Commit df4beb4

Browse files
Add train/test split function
1 parent 7ca35de commit df4beb4

4 files changed

+199
-50
lines changed

Core Basics 2 - Train a Classifier on a Star Multi-Table Dataset.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,9 @@
344344
{
345345
"cell_type": "code",
346346
"execution_count": null,
347-
"metadata": {},
347+
"metadata": {
348+
"is_khiops_tutorial_solution": true
349+
},
348350
"outputs": [],
349351
"source": [
350352
"# To visualize uncomment the line below\n",

Sklearn Basics 2 - Train a Classifier on a Star Multi-Table Dataset.ipynb

Lines changed: 120 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
"import pandas as pd\n",
2323
"from khiops import core as kh\n",
2424
"from khiops.sklearn import KhiopsClassifier\n",
25+
"from khiops.utils.helpers import train_test_split_dataset\n",
26+
"from sklearn import metrics\n",
2527
"\n",
2628
"# If there are any issues you may Khiops status with the following command\n",
2729
"# kh.get_runner().print_status()"
@@ -106,8 +108,8 @@
106108
"metadata": {},
107109
"outputs": [],
108110
"source": [
109-
"headlines_train_df = headlines_df.drop(\"IsSarcasm\", axis=1)\n",
110-
"y_sarcasm_train = headlines_df[\"IsSarcasm\"]"
111+
"headlines_main_df = headlines_df.drop(\"IsSarcasm\", axis=1)\n",
112+
"y_sarcasm = headlines_df[\"IsSarcasm\"]"
111113
]
112114
},
113115
{
@@ -138,15 +140,36 @@
138140
"metadata": {},
139141
"outputs": [],
140142
"source": [
141-
"X_sarcasm_train = {\n",
143+
"X_sarcasm = {\n",
142144
" \"main_table\": \"headlines\",\n",
143145
" \"tables\": {\n",
144-
" \"headlines\": (headlines_train_df, \"HeadlineId\"),\n",
146+
" \"headlines\": (headlines_main_df, \"HeadlineId\"),\n",
145147
" \"headline_words\": (headlines_words_df, \"HeadlineId\"),\n",
146148
" },\n",
147149
"}"
148150
]
149151
},
152+
{
153+
"cell_type": "markdown",
154+
"metadata": {},
155+
"source": [
156+
"To separate this dataset into train and test, we user the `khiops-python` helper function `train_test_split_dataset`. This function allows to separate ``dict`` dataset specifications:"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": null,
162+
"metadata": {},
163+
"outputs": [],
164+
"source": [
165+
"(\n",
166+
" X_sarcasm_train,\n",
167+
" X_sarcasm_test,\n",
168+
" y_sarcasm_train,\n",
169+
" y_sarcasm_test,\n",
170+
") = train_test_split_dataset(X_sarcasm, y_sarcasm)"
171+
]
172+
},
150173
{
151174
"cell_type": "markdown",
152175
"metadata": {},
@@ -196,7 +219,7 @@
196219
"cell_type": "markdown",
197220
"metadata": {},
198221
"source": [
199-
"Now, we use our sarcasm classifier to obtain predictions on the training data. We normally do that on new test data, and again a multi-table dataset specification would have been needed."
222+
"Now, we use our sarcasm classifier to obtain predictions and probabilities on the test data:"
200223
]
201224
},
202225
{
@@ -205,9 +228,33 @@
205228
"metadata": {},
206229
"outputs": [],
207230
"source": [
208-
"sarcasm_predictions = khc_sarcasm.predict(X_sarcasm_train)\n",
209-
"print(\"HeadlineSarcasm train predictions (first 10 values):\")\n",
210-
"display(sarcasm_predictions[:10])"
231+
"y_sarcasm_test_predicted = khc_sarcasm.predict(X_sarcasm_test)\n",
232+
"probas_sarcasm_test = khc_sarcasm.predict_proba(X_sarcasm_test)\n",
233+
"\n",
234+
"print(\"HeadlineSarcasm test predictions (first 10 values):\")\n",
235+
"display(sarcasm_test_predicted[:10])\n",
236+
"print(\"HeadlineSarcasm test prediction probabilities (first 10 values):\")\n",
237+
"display(sarcasm_test_probas[:10])"
238+
]
239+
},
240+
{
241+
"cell_type": "markdown",
242+
"metadata": {},
243+
"source": [
244+
"Finally we may estimate the accuracy and AUC for the test data:"
245+
]
246+
},
247+
{
248+
"cell_type": "code",
249+
"execution_count": null,
250+
"metadata": {},
251+
"outputs": [],
252+
"source": [
253+
"sarcasm_test_accuracy = metrics.accuracy_score(y_sarcasm_test, y_sarcasm_test_predicted)\n",
254+
"sarcasm_test_auc = metrics.roc_auc_score(y_sarcasm_test, probas_sarcasm_test[:, 1])\n",
255+
"\n",
256+
"print(f\"Sarcasm test accuracy: {sarcasm_test_accuracy}\")\n",
257+
"print(f\"Sarcasm test auc : {sarcasm_test_auc}\")"
211258
]
212259
},
213260
{
@@ -249,13 +296,13 @@
249296
"accidents_dataset_dir = os.path.join(kh.get_samples_dir(), \"AccidentsSummary\")\n",
250297
"\n",
251298
"accidents_file = os.path.join(accidents_dataset_dir, \"Accidents.txt\")\n",
252-
"accidents_df = pd.read_csv(accidents_file, sep=\"\\t\", encoding=\"ISO-8859-1\")\n",
299+
"accidents_df = pd.read_csv(accidents_file, sep=\"\\t\", encoding=\"latin1\")\n",
253300
"print(f\"Accidents dataframe (first 10 rows):\")\n",
254301
"display(accidents_df.head(10))\n",
255302
"print()\n",
256303
"\n",
257304
"vehicles_file = os.path.join(accidents_dataset_dir, \"Vehicles.txt\")\n",
258-
"vehicles_df = pd.read_csv(vehicles_file, sep=\"\\t\", encoding=\"ISO-8859-1\")\n",
305+
"vehicles_df = pd.read_csv(vehicles_file, sep=\"\\t\", encoding=\"latin1\")\n",
259306
"print(f\"Vehicles dataframe (first 10 rows):\")\n",
260307
"display(vehicles_df.head(10))"
261308
]
@@ -278,7 +325,7 @@
278325
"outputs": [],
279326
"source": [
280327
"accidents_main_df = accidents_df.drop(\"Gravity\", axis=1)\n",
281-
"y_accidents_train = accidents_df[\"Gravity\"]"
328+
"y_accidents = accidents_df[\"Gravity\"]"
282329
]
283330
},
284331
{
@@ -298,7 +345,7 @@
298345
},
299346
"outputs": [],
300347
"source": [
301-
"X_accidents_train = {\n",
348+
"X_accidents = {\n",
302349
" \"main_table\": \"accidents\",\n",
303350
" \"tables\": {\n",
304351
" \"accidents\": (accidents_main_df, \"AccidentId\"),\n",
@@ -307,6 +354,29 @@
307354
"}"
308355
]
309356
},
357+
{
358+
"cell_type": "markdown",
359+
"metadata": {},
360+
"source": [
361+
"#### Split the dataset into train and test"
362+
]
363+
},
364+
{
365+
"cell_type": "code",
366+
"execution_count": null,
367+
"metadata": {
368+
"is_khiops_tutorial_solution": true
369+
},
370+
"outputs": [],
371+
"source": [
372+
"(\n",
373+
" X_accidents_train,\n",
374+
" X_accidents_test,\n",
375+
" y_accidents_train,\n",
376+
" y_accidents_test,\n",
377+
") = train_test_split_dataset(X_accidents, y_accidents)"
378+
]
379+
},
310380
{
311381
"cell_type": "markdown",
312382
"metadata": {},
@@ -333,13 +403,15 @@
333403
"cell_type": "markdown",
334404
"metadata": {},
335405
"source": [
336-
"#### Print the accuracy and auc of the model\n"
406+
"#### Print the train accuracy and auc of the model\n"
337407
]
338408
},
339409
{
340410
"cell_type": "code",
341411
"execution_count": null,
342-
"metadata": {},
412+
"metadata": {
413+
"is_khiops_tutorial_solution": true
414+
},
343415
"outputs": [],
344416
"source": [
345417
"accidents_train_performance = (\n",
@@ -353,9 +425,32 @@
353425
"cell_type": "markdown",
354426
"metadata": {},
355427
"source": [
356-
"#### Deploy the classifier to obtain predictions on the training data\n",
428+
"#### Deploy the classifier to obtain predictions and its probabilites on the test data"
429+
]
430+
},
431+
{
432+
"cell_type": "code",
433+
"execution_count": null,
434+
"metadata": {
435+
"is_khiops_tutorial_solution": true
436+
},
437+
"outputs": [],
438+
"source": [
439+
"y_accidents_test_predicted = khc_accidents.predict(X_accidents_test)\n",
440+
"probas_accidents_test = khc_accidents.predict_proba(X_accidents_test)\n",
357441
"\n",
358-
"*Note that usually one deploys the model on new test data. We deploy on the train dataset to keep the tutorial simple*.\n"
442+
"print(\"Accidents test predictions (first 10 values):\")\n",
443+
"display(y_accidents_test_predicted[:10])\n",
444+
"print(\"Accidentns test prediction probabilities (first 10 values):\")\n",
445+
"display(probas_accidents_test[:10])"
446+
]
447+
},
448+
{
449+
"cell_type": "markdown",
450+
"metadata": {},
451+
"source": [
452+
"#### Obtain the accuracy and AUC on the test dataset\n",
453+
"\n"
359454
]
360455
},
361456
{
@@ -366,7 +461,15 @@
366461
},
367462
"outputs": [],
368463
"source": [
369-
"khc_accidents.predict(X_accidents_train)"
464+
"accidents_test_accuracy = metrics.accuracy_score(\n",
465+
" y_accidents_test, y_accidents_test_predicted\n",
466+
")\n",
467+
"accidents_test_auc = metrics.roc_auc_score(\n",
468+
" y_accidents_test, probas_accidents_test[:, 1]\n",
469+
")\n",
470+
"\n",
471+
"print(f\"Accidents test accuracy: {accidents_test_accuracy}\")\n",
472+
"print(f\"Accidents test auc : {accidents_test_auc}\")"
370473
]
371474
}
372475
],

0 commit comments

Comments
 (0)