From adc1da3e49df519c0e4020d1d04c9f8e6c1ec95e Mon Sep 17 00:00:00 2001 From: arevelle-ensae Date: Sun, 10 Mar 2024 11:30:57 +0000 Subject: [PATCH] adjust pipeline --- Sport/Modelization/2_Modelization_sport.ipynb | 611 +++++++++++++++++- 1 file changed, 588 insertions(+), 23 deletions(-) diff --git a/Sport/Modelization/2_Modelization_sport.ipynb b/Sport/Modelization/2_Modelization_sport.ipynb index a3d0476..2922b21 100644 --- a/Sport/Modelization/2_Modelization_sport.ipynb +++ b/Sport/Modelization/2_Modelization_sport.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 201, "id": "f271eb45-1470-4764-8c2e-31374efa1fe5", "metadata": {}, "outputs": [], @@ -22,7 +22,7 @@ "import re\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier\n", - "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report\n", + "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, recall_score\n", "from sklearn.utils import class_weight\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.pipeline import Pipeline\n", @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 202, "id": "3fecb606-22e5-4dee-8efa-f8dff0832299", "metadata": {}, "outputs": [], @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 203, "id": "59dd4694-a812-4923-b995-a2ee86c74f85", "metadata": {}, "outputs": [], @@ -76,15 +76,15 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 204, "id": "017f7e9a-3ba0-40fa-bdc8-51b98cc1fdb3", "metadata": {}, "outputs": [], "source": [ "def load_train_test():\n", " BUCKET = \"projet-bdc2324-team1/Generalization/sport\"\n", - " File_path_train = BUCKET + \"/\" + \"Train_set.csv\"\n", - " File_path_test = BUCKET + \"/\" + \"Test_set.csv\"\n", + " File_path_train = BUCKET + \"/Train_set/\" + \"dataset_train5.csv\"\n", + " File_path_test = BUCKET + \"/Test_set/\" + \"dataset_test5.csv\"\n", " \n", " with fs.open( File_path_train, mode=\"rb\") as file_in:\n", " dataset_train = pd.read_csv(file_in, sep=\",\")\n", @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 205, "id": "825d14a3-6967-4733-bfd4-64bf61c2bd43", "metadata": {}, "outputs": [], @@ -119,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 206, "id": "c479b230-b4bd-4cfb-b76b-d9faf6d95772", "metadata": {}, "outputs": [], @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 207, "id": "69eaec12-b30f-4d30-a461-ea520d5cbf77", "metadata": {}, "outputs": [], @@ -137,6 +137,26 @@ "X_train, X_test, y_train, y_test = features_target_split(dataset_train, dataset_test)" ] }, + { + "cell_type": "code", + "execution_count": 208, + "id": "d039f31d-0093-46c6-9743-ddec1381f758", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape train : (330117, 8)\n", + "Shape test : (141480, 8)\n" + ] + } + ], + "source": [ + "print(\"Shape train : \", X_train.shape)\n", + "print(\"Shape test : \", X_test.shape)" + ] + }, { "cell_type": "markdown", "id": "a1d6de94-4e11-481a-a0ce-412bf29f692c", @@ -147,10 +167,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 209, "id": "b808da43-c444-4e94-995a-7ec6ccd01e2d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{0.0: 0.5381774965030861, 1.0: 7.048360235716116}" + ] + }, + "execution_count": 209, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Compute Weights\n", "weights = class_weight.compute_class_weight(class_weight = 'balanced', classes = np.unique(y_train['y_has_purchased']),\n", @@ -162,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 210, "id": "b32a79ea-907f-4dfc-9832-6c74bef3200c", "metadata": {}, "outputs": [], @@ -193,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 211, "id": "9809a688-bfbc-4685-a77f-17a8b2b79ab3", "metadata": {}, "outputs": [], @@ -201,12 +232,12 @@ "# Set loss\n", "\n", "balanced_scorer = make_scorer(balanced_accuracy_score)\n", - "f1_scorer = make_scorer(f1_score)\n" + "recall_scorer = make_scorer(recall_score)\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 212, "id": "206d9a95-7c37-4506-949b-e77d225e42c5", "metadata": {}, "outputs": [], @@ -214,27 +245,519 @@ "# Hyperparameter\n", "\n", "param_grid = {'logreg__C': np.logspace(-10, 6, 17, base=2),\n", - " 'logreg__penalty': ['l2', 'L1'],\n", + " 'logreg__penalty': ['l1', 'l2'],\n", " 'logreg__class_weight': ['balanced', weight_dict]} " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 213, "id": "7ff2f7bd-efc1-4f7c-a3c9-caa916aa2f2b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
Pipeline(steps=[('preprocessor',\n",
+       "                 ColumnTransformer(transformers=[('num',\n",
+       "                                                  Pipeline(steps=[('scaler',\n",
+       "                                                                   StandardScaler())]),\n",
+       "                                                  ['nb_tickets', 'nb_purchases',\n",
+       "                                                   'total_amount',\n",
+       "                                                   'nb_suppliers',\n",
+       "                                                   'nb_tickets_internet',\n",
+       "                                                   'nb_campaigns',\n",
+       "                                                   'nb_campaigns_opened']),\n",
+       "                                                 ('cat',\n",
+       "                                                  Pipeline(steps=[('onehot',\n",
+       "                                                                   OneHotEncoder(handle_unknown='ignore',\n",
+       "                                                                                 sparse_output=False))]),\n",
+       "                                                  ['opt_in'])])),\n",
+       "                ('logreg',\n",
+       "                 LogisticRegression(class_weight={0.0: 0.5381774965030861,\n",
+       "                                                  1.0: 7.048360235716116},\n",
+       "                                    max_iter=5000, solver='saga'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "Pipeline(steps=[('preprocessor',\n", + " ColumnTransformer(transformers=[('num',\n", + " Pipeline(steps=[('scaler',\n", + " StandardScaler())]),\n", + " ['nb_tickets', 'nb_purchases',\n", + " 'total_amount',\n", + " 'nb_suppliers',\n", + " 'nb_tickets_internet',\n", + " 'nb_campaigns',\n", + " 'nb_campaigns_opened']),\n", + " ('cat',\n", + " Pipeline(steps=[('onehot',\n", + " OneHotEncoder(handle_unknown='ignore',\n", + " sparse_output=False))]),\n", + " ['opt_in'])])),\n", + " ('logreg',\n", + " LogisticRegression(class_weight={0.0: 0.5381774965030861,\n", + " 1.0: 7.048360235716116},\n", + " max_iter=5000, solver='saga'))])" + ] + }, + "execution_count": 213, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Pipeline\n", "\n", "pipeline = Pipeline(steps=[\n", " ('preprocessor', preproc),\n", - " ('logreg', LogisticRegression(solver='saga', max_iter=1000)) \n", + " ('logreg', LogisticRegression(solver='saga', class_weight = weight_dict,\n", + " max_iter=5000)) \n", "])\n", "\n", "pipeline.set_output(transform=\"pandas\")" ] }, + { + "cell_type": "markdown", + "id": "ed415f60-9663-4179-877b-233faf6e1645", + "metadata": {}, + "source": [ + "## Baseline" + ] + }, { "cell_type": "code", "execution_count": null, @@ -255,8 +778,14 @@ "y_pred = pipeline.predict(X_test)\n", "\n", "# Calculate the F1 score\n", + "acc = accuracy_score(y_test, y_pred)\n", + "print(f\"Accuracy Score: {acc}\")\n", + "\n", "f1 = f1_score(y_test, y_pred)\n", - "print(f\"F1 Score: {f1}\")" + "print(f\"F1 Score: {f1}\")\n", + "\n", + "recall = recall_score(y_test, y_pred)\n", + "print(f\"Recall Score: {recall}\")" ] }, { @@ -274,6 +803,32 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "580b58d7-596f-4207-8c99-4365aba2bc9f", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred_prob = pipeline.predict_proba(X_test)[:, 1]\n", + "\n", + "# Calcul des taux de faux positifs (FPR) et de vrais positifs (TPR)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob, pos_label=1)\n", + "\n", + "# Calcul de l'aire sous la courbe ROC (AUC)\n", + "roc_auc = auc(fpr, tpr)\n", + "\n", + "plt.figure(figsize = (14, 8))\n", + "plt.plot(fpr, tpr, label=\"ROC curve(area = %0.3f)\" % roc_auc)\n", + "plt.plot([0, 1], [0, 1], color=\"red\",label=\"Random Baseline\", linestyle=\"--\")\n", + "plt.grid(color='gray', linestyle='--', linewidth=0.5)\n", + "plt.xlabel('Taux de faux positifs (FPR)')\n", + "plt.ylabel('Taux de vrais positifs (TPR)')\n", + "plt.title('Courbe ROC : modèle logistique')\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "id": "ae8e9bd3-0f6a-4f82-bb4c-470cbdc8d6bb", @@ -282,6 +837,16 @@ "## Cross Validation" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f0535de-34f1-4e97-b993-b429ecf0a554", + "metadata": {}, + "outputs": [], + "source": [ + "y_train = y_train['y_has_purchased']" + ] + }, { "cell_type": "code", "execution_count": null, @@ -290,8 +855,8 @@ "outputs": [], "source": [ "# Cross validation\n", - "y_train = y_train['y_has_purchased']\n", - "grid_search = GridSearchCV(pipeline, param_grid, cv=5, scoring=f1_scorer, error_score='raise',\n", + "\n", + "grid_search = GridSearchCV(pipeline, param_grid, cv=3, scoring=f1_scorer, error_score='raise',\n", " n_jobs=-1)\n", "\n", "grid_search.fit(X_train, y_train)\n",