
Commit 4274fb0

Add train/test split function
1 parent 7ca35de commit 4274fb0

4 files changed: +192 -50 lines changed

Core Basics 2 - Train a Classifier on a Star Multi-Table Dataset.ipynb

Lines changed: 3 additions & 1 deletion
@@ -344,7 +344,9 @@
344344
{
345345
"cell_type": "code",
346346
"execution_count": null,
347-
"metadata": {},
347+
"metadata": {
348+
"is_khiops_tutorial_solution": true
349+
},
348350
"outputs": [],
349351
"source": [
350352
"# To visualize uncomment the line below\n",

Sklearn Basics 2 - Train a Classifier on a Star Multi-Table Dataset.ipynb

Lines changed: 120 additions & 17 deletions
@@ -22,6 +22,8 @@
2222
"import pandas as pd\n",
2323
"from khiops import core as kh\n",
2424
"from khiops.sklearn import KhiopsClassifier\n",
25+
"from khiops.utils.helpers import train_test_split_dataset\n",
26+
"from sklearn import metrics\n",
2527
"\n",
2628
"# If there are any issues you may Khiops status with the following command\n",
2729
"# kh.get_runner().print_status()"
@@ -106,8 +108,8 @@
106108
"metadata": {},
107109
"outputs": [],
108110
"source": [
109-
"headlines_train_df = headlines_df.drop(\"IsSarcasm\", axis=1)\n",
110-
"y_sarcasm_train = headlines_df[\"IsSarcasm\"]"
111+
"headlines_main_df = headlines_df.drop(\"IsSarcasm\", axis=1)\n",
112+
"y_sarcasm = headlines_df[\"IsSarcasm\"]"
111113
]
112114
},
113115
{
@@ -138,15 +140,36 @@
138140
"metadata": {},
139141
"outputs": [],
140142
"source": [
141-
"X_sarcasm_train = {\n",
143+
"X_sarcasm = {\n",
142144
" \"main_table\": \"headlines\",\n",
143145
" \"tables\": {\n",
144-
" \"headlines\": (headlines_train_df, \"HeadlineId\"),\n",
146+
" \"headlines\": (headlines_main_df, \"HeadlineId\"),\n",
145147
" \"headline_words\": (headlines_words_df, \"HeadlineId\"),\n",
146148
" },\n",
147149
"}"
148150
]
149151
},
152+
{
153+
"cell_type": "markdown",
154+
"metadata": {},
155+
"source": [
156+
"To separate this dataset into train and test, we user the `khiops-python` helper function `train_test_split_dataset`. This function allows to separate ``dict`` dataset specifications:"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": null,
162+
"metadata": {},
163+
"outputs": [],
164+
"source": [
165+
"(\n",
166+
" X_sarcasm_train,\n",
167+
" X_sarcasm_test,\n",
168+
" y_sarcasm_train,\n",
169+
" y_sarcasm_test,\n",
170+
") = train_test_split_dataset(X_sarcasm, y_sarcasm)"
171+
]
172+
},
150173
{
151174
"cell_type": "markdown",
152175
"metadata": {},
@@ -196,7 +219,7 @@
196219
"cell_type": "markdown",
197220
"metadata": {},
198221
"source": [
199-
"Now, we use our sarcasm classifier to obtain predictions on the training data. We normally do that on new test data, and again a multi-table dataset specification would have been needed."
222+
"Now, we use our sarcasm classifier to obtain predictions and probabilities on the test data:"
200223
]
201224
},
202225
{
@@ -205,9 +228,33 @@
205228
"metadata": {},
206229
"outputs": [],
207230
"source": [
208-
"sarcasm_predictions = khc_sarcasm.predict(X_sarcasm_train)\n",
209-
"print(\"HeadlineSarcasm train predictions (first 10 values):\")\n",
210-
"display(sarcasm_predictions[:10])"
231+
"y_sarcasm_test_predicted = khc_sarcasm.predict(X_sarcasm_test)\n",
232+
"probas_sarcasm_test = khc_sarcasm.predict_proba(X_sarcasm_test)\n",
233+
"\n",
234+
"print(\"HeadlineSarcasm test predictions (first 10 values):\")\n",
235+
"display(y_sarcasm_test_predicted[:10])\n",
236+
"print(\"HeadlineSarcasm test prediction probabilities (first 10 values):\")\n",
237+
"display(probas_sarcasm_test[:10])"
238+
]
239+
},
240+
{
241+
"cell_type": "markdown",
242+
"metadata": {},
243+
"source": [
244+
"Finally we may estimate the accuracy and AUC for the test data:"
245+
]
246+
},
247+
{
248+
"cell_type": "code",
249+
"execution_count": null,
250+
"metadata": {},
251+
"outputs": [],
252+
"source": [
253+
"sarcasm_test_accuracy = metrics.accuracy_score(y_sarcasm_test, y_sarcasm_test_predicted)\n",
254+
"sarcasm_test_auc = metrics.roc_auc_score(y_sarcasm_test, probas_sarcasm_test[:, 1])\n",
255+
"\n",
256+
"print(f\"Sarcasm test accuracy: {sarcasm_test_accuracy}\")\n",
257+
"print(f\"Sarcasm test auc : {sarcasm_test_auc}\")"
211258
]
212259
},
213260
{
@@ -249,13 +296,13 @@
249296
"accidents_dataset_dir = os.path.join(kh.get_samples_dir(), \"AccidentsSummary\")\n",
250297
"\n",
251298
"accidents_file = os.path.join(accidents_dataset_dir, \"Accidents.txt\")\n",
252-
"accidents_df = pd.read_csv(accidents_file, sep=\"\\t\", encoding=\"ISO-8859-1\")\n",
299+
"accidents_df = pd.read_csv(accidents_file, sep=\"\\t\", encoding=\"latin1\")\n",
253300
"print(f\"Accidents dataframe (first 10 rows):\")\n",
254301
"display(accidents_df.head(10))\n",
255302
"print()\n",
256303
"\n",
257304
"vehicles_file = os.path.join(accidents_dataset_dir, \"Vehicles.txt\")\n",
258-
"vehicles_df = pd.read_csv(vehicles_file, sep=\"\\t\", encoding=\"ISO-8859-1\")\n",
305+
"vehicles_df = pd.read_csv(vehicles_file, sep=\"\\t\", encoding=\"latin1\")\n",
259306
"print(f\"Vehicles dataframe (first 10 rows):\")\n",
260307
"display(vehicles_df.head(10))"
261308
]
@@ -278,7 +325,7 @@
278325
"outputs": [],
279326
"source": [
280327
"accidents_main_df = accidents_df.drop(\"Gravity\", axis=1)\n",
281-
"y_accidents_train = accidents_df[\"Gravity\"]"
328+
"y_accidents = accidents_df[\"Gravity\"]"
282329
]
283330
},
284331
{
@@ -298,7 +345,7 @@
298345
},
299346
"outputs": [],
300347
"source": [
301-
"X_accidents_train = {\n",
348+
"X_accidents = {\n",
302349
" \"main_table\": \"accidents\",\n",
303350
" \"tables\": {\n",
304351
" \"accidents\": (accidents_main_df, \"AccidentId\"),\n",
@@ -307,6 +354,29 @@
307354
"}"
308355
]
309356
},
357+
{
358+
"cell_type": "markdown",
359+
"metadata": {},
360+
"source": [
361+
"#### Split the dataset into train and test"
362+
]
363+
},
364+
{
365+
"cell_type": "code",
366+
"execution_count": null,
367+
"metadata": {
368+
"is_khiops_tutorial_solution": true
369+
},
370+
"outputs": [],
371+
"source": [
372+
"(\n",
373+
" X_accidents_train,\n",
374+
" X_accidents_test,\n",
375+
" y_accidents_train,\n",
376+
" y_accidents_test,\n",
377+
") = train_test_split_dataset(X_accidents, y_accidents)"
378+
]
379+
},
310380
{
311381
"cell_type": "markdown",
312382
"metadata": {},
@@ -333,13 +403,15 @@
333403
"cell_type": "markdown",
334404
"metadata": {},
335405
"source": [
336-
"#### Print the accuracy and auc of the model\n"
406+
"#### Print the train accuracy and auc of the model\n"
337407
]
338408
},
339409
{
340410
"cell_type": "code",
341411
"execution_count": null,
342-
"metadata": {},
412+
"metadata": {
413+
"is_khiops_tutorial_solution": true
414+
},
343415
"outputs": [],
344416
"source": [
345417
"accidents_train_performance = (\n",
@@ -353,9 +425,32 @@
353425
"cell_type": "markdown",
354426
"metadata": {},
355427
"source": [
356-
"#### Deploy the classifier to obtain predictions on the training data\n",
428+
"#### Deploy the classifier to obtain predictions and its probabilites on the test data"
429+
]
430+
},
431+
{
432+
"cell_type": "code",
433+
"execution_count": null,
434+
"metadata": {
435+
"is_khiops_tutorial_solution": true
436+
},
437+
"outputs": [],
438+
"source": [
439+
"y_accidents_test_predicted = khc_accidents.predict(X_accidents_test)\n",
440+
"probas_accidents_test = khc_accidents.predict_proba(X_accidents_test)\n",
357441
"\n",
358-
"*Note that usually one deploys the model on new test data. We deploy on the train dataset to keep the tutorial simple*.\n"
442+
"print(\"Accidents test predictions (first 10 values):\")\n",
443+
"display(y_accidents_test_predicted[:10])\n",
444+
"print(\"Accidentns test prediction probabilities (first 10 values):\")\n",
445+
"display(probas_accidents_test[:10])"
446+
]
447+
},
448+
{
449+
"cell_type": "markdown",
450+
"metadata": {},
451+
"source": [
452+
"#### Obtain the accuracy and AUC on the test dataset\n",
453+
"\n"
359454
]
360455
},
361456
{
@@ -366,7 +461,15 @@
366461
},
367462
"outputs": [],
368463
"source": [
369-
"khc_accidents.predict(X_accidents_train)"
464+
"accidents_test_accuracy = metrics.accuracy_score(\n",
465+
" y_accidents_test, y_accidents_test_predicted\n",
466+
")\n",
467+
"accidents_test_auc = metrics.roc_auc_score(\n",
468+
" y_accidents_test, probas_accidents_test[:, 1]\n",
469+
")\n",
470+
"\n",
471+
"print(f\"Accidents test accuracy: {accidents_test_accuracy}\")\n",
472+
"print(f\"Accidents test auc : {accidents_test_auc}\")"
370473
]
371474
}
372475
],

Sklearn Basics 3 - Train a Classifier on a Snowflake Multi-Table Dataset.ipynb

Lines changed: 62 additions & 29 deletions
@@ -21,6 +21,8 @@
2121
"import pandas as pd\n",
2222
"from khiops import core as kh\n",
2323
"from khiops.sklearn import KhiopsClassifier\n",
24+
"from khiops.utils.helpers import train_test_split_dataset\n",
25+
"from sklearn import metrics\n",
2426
"\n",
2527
"# If there are any issues you may Khiops status with the following command\n",
2628
"# kh.get_runner().print_status()"
@@ -86,23 +88,6 @@
8688
"display(places_df.head(10))"
8789
]
8890
},
89-
{
90-
"cell_type": "markdown",
91-
"metadata": {},
92-
"source": [
93-
"#### Create the main feature matrix and the target vector for `Accidents`"
94-
]
95-
},
96-
{
97-
"cell_type": "code",
98-
"execution_count": null,
99-
"metadata": {},
100-
"outputs": [],
101-
"source": [
102-
"accidents_main_df = accidents_df.drop(\"Gravity\", axis=1)\n",
103-
"y_accidents_train = accidents_df[\"Gravity\"]"
104-
]
105-
},
10691
{
10792
"cell_type": "markdown",
10893
"metadata": {},
@@ -112,8 +97,7 @@
11297
"Note the main table `Accidents` and the secondary table `Places` have one key `AccidentId`.\n",
11398
"Tables `Vehicles` (the other secondary table) and `Users` (the tertiary table) have two keys: `AccidentId` and `VehicleId`.\n",
11499
"\n",
115-
"To describe relations between tables, the field `relations` must be added to the dictionary of table specifications. This field\n",
116-
"contains a list of tuples describing the relations between tables. The first two values (`str`) of each tuple correspond to names of both the parent and the child table involved in the relation. A third value (`bool`) can be optionally set as `True` to indicate that the relation is `1:1`. For example, if the tuple `(table1, table2, True)` is contained in this field, it means that:\n",
100+
"To describe relations between tables, we add the `relations` field must to the dataset spec. This field contains a list of tuples describing the relations between tables. The first two values (`str`) of each tuple correspond to names of both the parent and the child table involved in the relation. A third value (`bool`) can be optionally set as `True` to indicate that the relation is `1:1`. For example, if the tuple `(table1, table2, True)` is contained in this field, it means that:\n",
117101
"\n",
118102
" - `table1` and `table2` are in a `1:1` relationship\n",
119103
" - The key of `table1` is contained in that of `table2` (ie. keys are hierarchical)\n",
@@ -127,10 +111,10 @@
127111
"metadata": {},
128112
"outputs": [],
129113
"source": [
130-
"X_accidents_train = {\n",
114+
"X_accidents = {\n",
131115
" \"main_table\": \"Accidents\",\n",
132116
" \"tables\": {\n",
133-
" \"Accidents\": (accidents_main_df, \"AccidentId\"),\n",
117+
" \"Accidents\": (accidents_df.drop(\"Gravity\", axis=1), \"AccidentId\"),\n",
134118
" \"Vehicles\": (vehicles_df, [\"AccidentId\", \"VehicleId\"]),\n",
135119
" \"Users\": (users_df, [\"AccidentId\", \"VehicleId\"]),\n",
136120
" \"Places\": (places_df, [\"AccidentId\"]),\n",
@@ -140,7 +124,30 @@
140124
" (\"Vehicles\", \"Users\"),\n",
141125
" (\"Accidents\", \"Places\", True),\n",
142126
" ],\n",
143-
"}"
127+
"}\n",
128+
"y_accidents = accidents_df[\"Gravity\"]"
129+
]
130+
},
131+
{
132+
"cell_type": "markdown",
133+
"metadata": {},
134+
"source": [
135+
"#### Split the dataset into train and test\n",
136+
"We use the helper function `train_test_split_dataset` with the `X` dataset spec to obtain one spec for train and another for test."
137+
]
138+
},
139+
{
140+
"cell_type": "code",
141+
"execution_count": null,
142+
"metadata": {},
143+
"outputs": [],
144+
"source": [
145+
"(\n",
146+
" X_accidents_train,\n",
147+
" X_accidents_test,\n",
148+
" y_accidents_train,\n",
149+
" y_accidents_test,\n",
150+
") = train_test_split_dataset(X_accidents, y_accidents, test_size=0.3)"
144151
]
145152
},
146153
{
@@ -167,7 +174,7 @@
167174
"cell_type": "markdown",
168175
"metadata": {},
169176
"source": [
170-
"#### Print the accuracy and auc of the model\n"
177+
"#### Print the train accuracy and train auc of the model"
171178
]
172179
},
173180
{
@@ -187,20 +194,46 @@
187194
"cell_type": "markdown",
188195
"metadata": {},
189196
"source": [
190-
"#### Deploy the classifier to obtain predictions on the training data\n",
197+
"#### Deploy the classifier to obtain predictions and probabilities on the test data"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": null,
203+
"metadata": {},
204+
"outputs": [],
205+
"source": [
206+
"y_accidents_test_predicted = khc_accidents.predict(X_accidents_test)\n",
207+
"probas_accidents_test = khc_accidents.predict_proba(X_accidents_test)\n",
191208
"\n",
192-
"Note that usually one deploys the model on new test data. We deploy on the train dataset to keep the tutorial simple*.\n"
209+
"print(\"Accidents test predictions (first 10 values):\")\n",
210+
"display(y_accidents_test_predicted[:10])\n",
211+
"print(\"Accidentns test prediction probabilities (first 10 values):\")\n",
212+
"display(probas_accidents_test[:10])"
213+
]
214+
},
215+
{
216+
"cell_type": "markdown",
217+
"metadata": {},
218+
"source": [
219+
"#### Estimate the accuracy and AUC metrics on the test data"
193220
]
194221
},
195222
{
196223
"cell_type": "code",
197224
"execution_count": null,
198-
"metadata": {
199-
"is_khiops_tutorial_solution": true
200-
},
225+
"metadata": {},
201226
"outputs": [],
202227
"source": [
203-
"khc_accidents.predict(X_accidents_train)"
228+
"accidents_test_accuracy = metrics.accuracy_score(\n",
229+
" y_accidents_test, y_accidents_test_predicted\n",
230+
")\n",
231+
"accidents_test_auc = metrics.roc_auc_score(\n",
232+
" y_accidents_test, probas_accidents_test[:, 1]\n",
233+
")\n",
234+
"\n",
235+
"print(f\"Accidents test accuracy: {accidents_test_accuracy}\")\n",
236+
"print(f\"Accidents test auc : {accidents_test_auc}\")"
204237
]
205238
}
206239
],
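
For readers skimming this commit, the cells added above all follow the same pattern. Below is a minimal, self-contained sketch of that pattern; the toy table names and data are placeholders, the four-way split follows the scikit-learn train/test convention visible in the diff, and the `test_size` keyword comes from the snowflake notebook (any other defaults are assumptions):

import pandas as pd
from khiops.sklearn import KhiopsClassifier
from khiops.utils.helpers import train_test_split_dataset
from sklearn import metrics

# Toy star dataset: a main table and a secondary table joined on "Id" (placeholder data)
main_df = pd.DataFrame({"Id": range(100), "Feature": [i % 7 for i in range(100)]})
secondary_df = pd.DataFrame({"Id": [i // 2 for i in range(200)], "Word": ["a", "b"] * 100})
y = pd.Series([i % 2 for i in range(100)], name="Target")

X = {
    "main_table": "main",
    "tables": {
        "main": (main_df, "Id"),
        "secondary": (secondary_df, "Id"),
    },
}

# Split both the dict dataset spec and the target vector into train and test parts
X_train, X_test, y_train, y_test = train_test_split_dataset(X, y, test_size=0.3)

# Train on the train split, then deploy and evaluate on the held-out test split
khc = KhiopsClassifier()
khc.fit(X_train, y_train)
y_test_predicted = khc.predict(X_test)
probas_test = khc.predict_proba(X_test)
print("test accuracy:", metrics.accuracy_score(y_test, y_test_predicted))
print("test AUC     :", metrics.roc_auc_score(y_test, probas_test[:, 1]))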
