{ "cells": [ { "cell_type": "markdown", "id": "ff8cc602-e733-4a31-bf46-a31087511fe0", "metadata": {}, "source": [ "# Predict sales - sports companies" ] }, { "cell_type": "markdown", "id": "415e466a-1a71-4150-bff7-2f8904766df4", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "b5aaf421-850a-4a86-8e99-2c1f0723bd6c", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import os\n", "import s3fs\n", "import re\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, recall_score\n", "from sklearn.utils import class_weight\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.preprocessing import OneHotEncoder\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.preprocessing import StandardScaler, MaxAbsScaler, MinMaxScaler\n", "from sklearn.metrics import make_scorer, f1_score, balanced_accuracy_score\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score\n", "from sklearn.exceptions import ConvergenceWarning, DataConversionWarning\n", "\n", "import pickle\n", "import warnings" ] }, { "cell_type": "markdown", "id": "c2f44070-451e-4109-9a08-3b80011d610f", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 2, "id": "b5f8135f-b6e7-4d6d-b8e1-da185b944aff", "metadata": {}, "outputs": [], "source": [ "# Create filesystem object\n", "S3_ENDPOINT_URL = \"https://\" + os.environ[\"AWS_S3_ENDPOINT\"]\n", "fs = s3fs.S3FileSystem(client_kwargs={'endpoint_url': S3_ENDPOINT_URL})" ] }, { "cell_type": "code", 
"execution_count": 3, "id": "2668a243-4ff8-40c6-9de2-5c9c07bcf714", "metadata": {}, "outputs": [], "source": [
 "def load_train_test(bucket=\"projet-bdc2324-team1/Generalization/sport\"):\n",
 "    \"\"\"Load the train and test sets from S3.\n",
 "\n",
 "    Reads ``<bucket>/Train_set.csv`` and ``<bucket>/Test_set.csv`` through the\n",
 "    notebook-level ``fs`` filesystem object.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    bucket : str\n",
 "        S3 prefix that contains both CSV files.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    tuple of pandas.DataFrame\n",
 "        ``(dataset_train, dataset_test)``.\n",
 "    \"\"\"\n",
 "    def read_csv_from_s3(path):\n",
 "        # One reader for both files so they are always parsed the same way.\n",
 "        with fs.open(path, mode=\"rb\") as file_in:\n",
 "            # low_memory=False infers each column dtype over the whole file,\n",
 "            # which avoids the mixed-types DtypeWarning of chunked parsing.\n",
 "            return pd.read_csv(file_in, sep=\",\", low_memory=False)\n",
 "\n",
 "    dataset_train = read_csv_from_s3(bucket + \"/Train_set.csv\")\n",
 "    dataset_test = read_csv_from_s3(bucket + \"/Test_set.csv\")\n",
 "    return dataset_train, dataset_test" ] }, { "cell_type": "code", "execution_count": 4, "id": "13eba3e1-3ea5-435b-8b05-6d7d5744cbe2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_462/2459610029.py:7: DtypeWarning: Columns (38) have mixed types. 
Specify dtype option on import or set low_memory=False.\n", " dataset_train = pd.read_csv(file_in, sep=\",\")\n" ] }, { "data": { "text/plain": [ "customer_id 0\n", "nb_tickets 0\n", "nb_purchases 0\n", "total_amount 0\n", "nb_suppliers 0\n", "vente_internet_max 0\n", "purchase_date_min 0\n", "purchase_date_max 0\n", "time_between_purchase 0\n", "nb_tickets_internet 0\n", "street_id 0\n", "structure_id 222825\n", "mcp_contact_id 70874\n", "fidelity 0\n", "tenant_id 0\n", "is_partner 0\n", "deleted_at 224213\n", "gender 0\n", "is_email_true 0\n", "opt_in 0\n", "last_buying_date 66139\n", "max_price 66139\n", "ticket_sum 0\n", "average_price 66023\n", "average_purchase_delay 66139\n", "average_price_basket 66139\n", "average_ticket_basket 66139\n", "total_price 116\n", "purchase_count 0\n", "first_buying_date 66139\n", "country 23159\n", "gender_label 0\n", "gender_female 0\n", "gender_male 0\n", "gender_other 0\n", "country_fr 23159\n", "nb_campaigns 0\n", "nb_campaigns_opened 0\n", "time_to_open 123159\n", "y_has_purchased 0\n", "dtype: int64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_train, dataset_test = load_train_test()\n", "dataset_train.isna().sum()" ] }, { "cell_type": "code", "execution_count": 18, "id": "e46622e7-0fc1-43f8-a7e7-34a5e90068b2", "metadata": {}, "outputs": [], "source": [
 "def features_target_split(dataset_train, dataset_test, features=None, target='y_has_purchased'):\n",
 "    \"\"\"Split the train and test sets into feature frames and target frames.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    dataset_train, dataset_test : pandas.DataFrame\n",
 "        Data sets holding the feature columns and the target column.\n",
 "    features : list of str, optional\n",
 "        Feature columns to keep. Defaults to the 14 columns listed in the body.\n",
 "    target : str, default 'y_has_purchased'\n",
 "        Name of the target column.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    X_train, X_test, y_train, y_test : pandas.DataFrame\n",
 "        The targets are returned as one-column DataFrames.\n",
 "    \"\"\"\n",
 "    if features is None:\n",
 "        # fidelity, time_between_purchase and gender_other are left out\n",
 "        # (collinearity issue); is_partner is not used either\n",
 "        features = ['nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers', 'vente_internet_max',\n",
 "                    'purchase_date_min', 'purchase_date_max', 'nb_tickets_internet', 'is_email_true',\n",
 "                    'opt_in', 'gender_female', 'gender_male', 'nb_campaigns', 'nb_campaigns_opened']\n",
 "\n",
 "    X_train = dataset_train[features]\n",
 "    y_train = dataset_train[[target]]\n",
 "\n",
 "    X_test = dataset_test[features]\n",
 "    y_test = dataset_test[[target]]\n",
 "    return X_train, X_test, y_train, y_test" ] }, { "cell_type": "code", "execution_count": 19, "id": "cec4f386-e643-4bd8-b8cd-8917d2c1b3d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape train : (224213, 14)\n", "Shape test : (96096, 14)\n" ] } ], "source": [
 "X_train, X_test, y_train, y_test = features_target_split(dataset_train, dataset_test)\n",
 "print(\"Shape train : \", X_train.shape)\n",
 "print(\"Shape test : \", X_test.shape)" ] }, { "cell_type": "markdown", "id": "c9e8edbd-7ff6-42f9-a8eb-10d27ca19c8a", "metadata": {}, "source": [ "## Prepare preprocessing and Hyperparameters" ] }, { "cell_type": "code", "execution_count": 20, "id": "639b432a-c39c-4bf8-8ee2-e136d156e0dd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0.0: 0.5837086520288036, 1.0: 3.486549107420539}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Compute balanced class weights from the training target\n",
 "classes = np.unique(y_train['y_has_purchased'])\n",
 "weights = class_weight.compute_class_weight(class_weight='balanced', classes=classes,\n",
 "                                            y=y_train['y_has_purchased'])\n",
 "\n",
 "# Map each class label to its weight (same order as `classes`)\n",
 "weight_dict = dict(zip(classes, weights))\n",
 "weight_dict" ] }, { "cell_type": "code", "execution_count": 21, "id": "34644a00-85a5-41c9-98df-41178cb3ac69", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nb_ticketsnb_purchasestotal_amountnb_suppliersvente_internet_maxpurchase_date_minpurchase_date_maxnb_tickets_internetis_email_trueopt_ingender_femalegender_malenb_campaignsnb_campaigns_opened
02.01.060.001.00.0355.268981355.2689810.0TrueFalse010.00.0
18.03.0140.001.00.0373.540289219.2622690.0TrueFalse010.00.0
22.01.050.001.00.05.2024425.2024420.0TrueFalse010.00.0
33.01.090.001.00.05.1789585.1789580.0TrueFalse010.00.0
42.01.078.001.00.05.1740395.1740390.0TrueFalse100.00.0
.............................................
2242080.00.00.000.00.0550.000000550.0000000.0TrueFalse0134.03.0
2242091.01.020.001.01.0392.501030392.5010301.0TrueFalse0123.06.0
2242100.00.00.000.00.0550.000000550.0000000.0TrueTrue018.04.0
2242111.01.097.111.01.0172.334074172.3340741.0TrueFalse0113.05.0
2242120.00.00.000.00.0550.000000550.0000000.0TrueFalse014.04.0
\n", "

224213 rows × 14 columns

\n", "
" ], "text/plain": [ " nb_tickets nb_purchases total_amount nb_suppliers \\\n", "0 2.0 1.0 60.00 1.0 \n", "1 8.0 3.0 140.00 1.0 \n", "2 2.0 1.0 50.00 1.0 \n", "3 3.0 1.0 90.00 1.0 \n", "4 2.0 1.0 78.00 1.0 \n", "... ... ... ... ... \n", "224208 0.0 0.0 0.00 0.0 \n", "224209 1.0 1.0 20.00 1.0 \n", "224210 0.0 0.0 0.00 0.0 \n", "224211 1.0 1.0 97.11 1.0 \n", "224212 0.0 0.0 0.00 0.0 \n", "\n", " vente_internet_max purchase_date_min purchase_date_max \\\n", "0 0.0 355.268981 355.268981 \n", "1 0.0 373.540289 219.262269 \n", "2 0.0 5.202442 5.202442 \n", "3 0.0 5.178958 5.178958 \n", "4 0.0 5.174039 5.174039 \n", "... ... ... ... \n", "224208 0.0 550.000000 550.000000 \n", "224209 1.0 392.501030 392.501030 \n", "224210 0.0 550.000000 550.000000 \n", "224211 1.0 172.334074 172.334074 \n", "224212 0.0 550.000000 550.000000 \n", "\n", " nb_tickets_internet is_email_true opt_in gender_female \\\n", "0 0.0 True False 0 \n", "1 0.0 True False 0 \n", "2 0.0 True False 0 \n", "3 0.0 True False 0 \n", "4 0.0 True False 1 \n", "... ... ... ... ... \n", "224208 0.0 True False 0 \n", "224209 1.0 True False 0 \n", "224210 0.0 True True 0 \n", "224211 1.0 True False 0 \n", "224212 0.0 True False 0 \n", "\n", " gender_male nb_campaigns nb_campaigns_opened \n", "0 1 0.0 0.0 \n", "1 1 0.0 0.0 \n", "2 1 0.0 0.0 \n", "3 1 0.0 0.0 \n", "4 0 0.0 0.0 \n", "... ... ... ... 
\n", "224208 1 34.0 3.0 \n", "224209 1 23.0 6.0 \n", "224210 1 8.0 4.0 \n", "224211 1 13.0 5.0 \n", "224212 1 4.0 4.0 \n", "\n", "[224213 rows x 14 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train" ] }, { "cell_type": "code", "execution_count": 83, "id": "295676df-36ac-43d8-8b31-49ff08efd6e7", "metadata": {}, "outputs": [], "source": [
 "# preprocess data\n",
 "# numeric features - standardize\n",
 "# categorical features - encode\n",
 "# encoded features - pass through unchanged\n",
 "\n",
 "numeric_features = ['nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers', 'vente_internet_max',\n",
 "                    'purchase_date_min', 'purchase_date_max', 'nb_tickets_internet', 'nb_campaigns',\n",
 "                    'nb_campaigns_opened']\n",
 "\n",
 "numeric_transformer = Pipeline(steps=[\n",
 "    #(\"imputer\", SimpleImputer(strategy=\"mean\")),\n",
 "    (\"scaler\", StandardScaler())\n",
 "])\n",
 "\n",
 "categorical_features = ['opt_in', 'is_email_true']\n",
 "\n",
 "# Transformer for the categorical features\n",
 "categorical_transformer = Pipeline(steps=[\n",
 "    #(\"imputer\", SimpleImputer(strategy=\"most_frequent\")), # Impute missing values with the most frequent\n",
 "    (\"onehot\", OneHotEncoder(handle_unknown='ignore', sparse_output=False))\n",
 "])\n",
 "\n",
 "# Already 0/1 encoded columns. They must be listed explicitly: ColumnTransformer\n",
 "# drops every column that is not assigned to a transformer (remainder='drop').\n",
 "encoded_features = ['gender_female', 'gender_male']\n",
 "\n",
 "preproc = ColumnTransformer(\n",
 "    transformers=[\n",
 "        (\"num\", numeric_transformer, numeric_features),\n",
 "        (\"cat\", categorical_transformer, categorical_features),\n",
 "        (\"enc\", \"passthrough\", encoded_features)\n",
 "    ]\n",
 ")" ] }, { "cell_type": "code", "execution_count": 80, "id": "f46fb56e-c908-40b4-868f-9684d1ae01c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "nb_tickets 0\n", "nb_purchases 0\n", "total_amount 0\n", "nb_suppliers 0\n", "vente_internet_max 0\n", "purchase_date_min 0\n", "purchase_date_max 0\n", "nb_tickets_internet 0\n", "nb_campaigns 0\n", "nb_campaigns_opened 0\n", "dtype: int64" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "X_train[numeric_features].isna().sum()" ] }, { "cell_type": "code", "execution_count": 52, "id": "e729781b-4d65-42c5-bdc5-82b4d653aaf0", "metadata": {}, "outputs": [], "source": [
 "# Scorers that can be passed as `scoring` to GridSearchCV\n",
 "balanced_scorer = make_scorer(balanced_accuracy_score)\n",
 "recall_scorer = make_scorer(recall_score)\n",
 "f1_scorer = make_scorer(f1_score)" ] }, { "cell_type": "code", "execution_count": 24, "id": "a7ebbe6f-70ba-4276-be18-f10e7bfd7423", "metadata": {}, "outputs": [], "source": [
 "def draw_confusion_matrix(y_test, y_pred):\n",
 "    \"\"\"Plot the confusion matrix of `y_pred` against `y_test` as an annotated heatmap.\"\"\"\n",
 "    conf_matrix = confusion_matrix(y_test, y_pred)\n",
 "    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])\n",
 "    plt.xlabel('Predicted')\n",
 "    plt.ylabel('Actual')\n",
 "    plt.title('Confusion Matrix')\n",
 "    plt.show()\n",
 "\n",
 "\n",
 "def draw_roc_curve(X_test, y_test, model=None):\n",
 "    \"\"\"Plot the ROC curve of a fitted classifier and show its AUC in the legend.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    X_test, y_test\n",
 "        Features and true labels used to evaluate the classifier.\n",
 "    model : optional\n",
 "        Fitted estimator exposing `predict_proba`. When omitted, the\n",
 "        notebook-level `pipeline` is used (previous behaviour).\n",
 "    \"\"\"\n",
 "    if model is None:\n",
 "        model = pipeline\n",
 "    y_pred_prob = model.predict_proba(X_test)[:, 1]\n",
 "\n",
 "    # False positive rate (FPR) and true positive rate (TPR)\n",
 "    fpr, tpr, _ = roc_curve(y_test, y_pred_prob, pos_label=1)\n",
 "\n",
 "    # Area under the ROC curve (AUC)\n",
 "    roc_auc = auc(fpr, tpr)\n",
 "\n",
 "    plt.figure(figsize = (14, 8))\n",
 "    plt.plot(fpr, tpr, label=\"ROC curve(area = %0.3f)\" % roc_auc)\n",
 "    plt.plot([0, 1], [0, 1], color=\"red\",label=\"Random Baseline\", linestyle=\"--\")\n",
 "    plt.grid(color='gray', linestyle='--', linewidth=0.5)\n",
 "    plt.xlabel('Taux de faux positifs (FPR)')\n",
 "    plt.ylabel('Taux de vrais positifs (TPR)')\n",
 "    plt.title('Courbe ROC : modèle logistique')\n",
 "    plt.legend(loc=\"lower right\")\n",
 "    plt.show()" ] }, { "cell_type": "code", "execution_count": 25, "id": "2334eb51-e6ea-4fd0-89ce-f54cd474d332", "metadata": {}, "outputs": [], "source": [
 "def draw_features_importance(pipeline, model):\n",
 "    \"\"\"Plot the coefficients of a fitted pipeline step as horizontal bars.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    pipeline : sklearn.pipeline.Pipeline\n",
 "        Fitted pipeline whose selected step exposes `coef_` and `feature_names_in_`.\n",
 "    model\n",
 "        Name of the step to inspect. Any value that is not a step name of\n",
 "        `pipeline` falls back to 'logreg' (previous behaviour).\n",
 "    \"\"\"\n",
 "    step_name = model if isinstance(model, str) and model in pipeline.named_steps else 'logreg'\n",
 "    estimator = pipeline.named_steps[step_name]\n",
 "    coefficients = estimator.coef_[0]\n",
 "    feature_names = estimator.feature_names_in_\n",
 "\n",
 "    # Plot the feature importances\n",
 "    plt.figure(figsize=(10, 6))\n",
 "    plt.barh(feature_names, coefficients, color='skyblue')\n",
 "    plt.xlabel('Importance des caractéristiques')\n",
 "    plt.ylabel('Caractéristiques')\n",
 "    plt.title('Importance des caractéristiques dans le modèle de régression logistique')\n",
 "    plt.grid(True)\n",
 "    plt.show()\n",
 "\n",
 "def draw_prob_distribution(X_test, model=None):\n",
 "    \"\"\"Plot a histogram of the predicted probabilities of class 1.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    X_test\n",
 "        Features to score.\n",
 "    model : optional\n",
 "        Fitted estimator exposing `predict_proba`. When omitted, the\n",
 "        notebook-level `pipeline` is used (previous behaviour).\n",
 "    \"\"\"\n",
 "    if model is None:\n",
 "        model = pipeline\n",
 "    y_pred_prob = model.predict_proba(X_test)[:, 1]\n",
 "    plt.figure(figsize=(8, 6))\n",
 "    plt.hist(y_pred_prob, bins=10, range=(0, 1), color='blue', alpha=0.7)\n",
 "\n",
 "    plt.xlim(0, 1)\n",
 "    plt.ylim(0, None)\n",
 "\n",
 "    plt.title('Histogramme des probabilités pour la classe 1')\n",
 "    plt.xlabel('Probabilité')\n",
 "    plt.ylabel('Fréquence')\n",
 "    plt.grid(True)\n",
 "    plt.show()" ] }, { "cell_type": "code", "execution_count": 92, "id": "83917b97-4d9b-4e3c-ba27-1e546ce885d3", "metadata": {}, "outputs": [], "source": [
 "# Hyperparameter\n",
 "\n",
 "# 15 values of C, log-spaced in base 2 from 2**-10 to 2**4\n",
 "param_c = np.logspace(-10, 4, 15, base=2)\n",
 "# param_penalty_type = ['l1', 'l2', 'elasticnet']\n",
 "param_penalty_type = ['l1']\n",
 "param_grid = {'logreg__C': param_c,\n",
 "              'logreg__penalty': param_penalty_type}" ] }, { "cell_type": "code", "execution_count": 26, "id": "3ae25049-920c-4a6d-a59d-c26e3b45dec6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1024" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "2 ** 10" ] }, { "cell_type": "code", "execution_count": 95, "id": "ba4cde9f-a614-4a43-81b9-e16e78aa6c4c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('preprocessor',\n",
       "                 ColumnTransformer(transformers=[('num',\n",
       "                                                  Pipeline(steps=[('scaler',\n",
       "                                                                   StandardScaler())]),\n",
       "                                                  ['nb_tickets', 'nb_purchases',\n",
       "                                                   'total_amount',\n",
       "                                                   'nb_suppliers',\n",
       "                                                   'vente_internet_max',\n",
       "                                                   'purchase_date_min',\n",
       "                                                   'purchase_date_max',\n",
       "                                                   'nb_tickets_internet',\n",
       "                                                   'nb_campaigns',\n",
       "                                                   'nb_campaigns_opened']),\n",
       "                                                 ('cat',\n",
       "                                                  Pipeline(steps=[('onehot',\n",
       "                                                                   OneHotEncoder(handle_unknown='ignore',\n",
       "                                                                                 sparse_output=False))]),\n",
       "                                                  ['opt_in',\n",
       "                                                   'is_email_true'])])),\n",
       "                ('logreg',\n",
       "                 LogisticRegression(class_weight={0.0: 0.5837086520288036,\n",
       "                                                  1.0: 3.486549107420539},\n",
       "                                    max_iter=5000, solver='saga'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocessor',\n", " ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler',\n", " StandardScaler())]),\n", " ['nb_tickets', 'nb_purchases',\n", " 'total_amount',\n", " 'nb_suppliers',\n", " 'vente_internet_max',\n", " 'purchase_date_min',\n", " 'purchase_date_max',\n", " 'nb_tickets_internet',\n", " 'nb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[('onehot',\n", " OneHotEncoder(handle_unknown='ignore',\n", " sparse_output=False))]),\n", " ['opt_in',\n", " 'is_email_true'])])),\n", " ('logreg',\n", " LogisticRegression(class_weight={0.0: 0.5837086520288036,\n", " 1.0: 3.486549107420539},\n", " max_iter=5000, solver='saga'))])" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Pipeline: preprocessing followed by a class-weighted logistic regression\n",
 "pipeline = Pipeline(steps=[\n",
 "    ('preprocessor', preproc),\n",
 "    # saga shuffles the data while fitting: fix the seed so that fits are reproducible\n",
 "    ('logreg', LogisticRegression(solver='saga', class_weight = weight_dict,\n",
 "                                  max_iter=5000, random_state=42))\n",
 "])\n",
 "\n",
 "# Make every transformer return DataFrames, so the final step sees column names\n",
 "pipeline.set_output(transform=\"pandas\")" ] }, { "cell_type": "code", "execution_count": 40, "id": "1e4c1be5-176d-4222-9b3c-fe27225afe36", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nb_ticketsnb_purchasestotal_amountnb_suppliersvente_internet_maxpurchase_date_minpurchase_date_maxnb_tickets_internetis_email_trueopt_ingender_femalegender_malenb_campaignsnb_campaigns_opened
6446611.04.0281.41.01.0238.33059130.28504011.0TrueFalse100.00.0
1413270.00.00.00.00.0550.000000550.0000000.0TrueTrue0010.00.0
599992.01.00.01.01.0350.288926350.2889262.0TrueFalse100.00.0
268820.00.00.00.00.0550.000000550.0000000.0TrueFalse104.01.0
6295211.03.0325.01.01.0424.486781237.28226211.0TrueFalse000.00.0
.............................................
1413180.00.00.00.00.0550.000000550.0000000.0TrueTrue0016.01.0
1138383.02.015.01.01.0153.15294590.2770993.0TrueTrue0131.014.0
1849260.00.00.00.00.0550.000000550.0000000.0TrueTrue0118.00.0
146171.01.020.01.00.0239.258970239.2589700.0TrueTrue010.00.0
216854.01.088.01.00.0240.355162240.3551620.0TrueTrue010.00.0
\n", "

10000 rows × 14 columns

\n", "
" ], "text/plain": [ " nb_tickets nb_purchases total_amount nb_suppliers \\\n", "64466 11.0 4.0 281.4 1.0 \n", "141327 0.0 0.0 0.0 0.0 \n", "59999 2.0 1.0 0.0 1.0 \n", "26882 0.0 0.0 0.0 0.0 \n", "62952 11.0 3.0 325.0 1.0 \n", "... ... ... ... ... \n", "141318 0.0 0.0 0.0 0.0 \n", "113838 3.0 2.0 15.0 1.0 \n", "184926 0.0 0.0 0.0 0.0 \n", "14617 1.0 1.0 20.0 1.0 \n", "21685 4.0 1.0 88.0 1.0 \n", "\n", " vente_internet_max purchase_date_min purchase_date_max \\\n", "64466 1.0 238.330591 30.285040 \n", "141327 0.0 550.000000 550.000000 \n", "59999 1.0 350.288926 350.288926 \n", "26882 0.0 550.000000 550.000000 \n", "62952 1.0 424.486781 237.282262 \n", "... ... ... ... \n", "141318 0.0 550.000000 550.000000 \n", "113838 1.0 153.152945 90.277099 \n", "184926 0.0 550.000000 550.000000 \n", "14617 0.0 239.258970 239.258970 \n", "21685 0.0 240.355162 240.355162 \n", "\n", " nb_tickets_internet is_email_true opt_in gender_female \\\n", "64466 11.0 True False 1 \n", "141327 0.0 True True 0 \n", "59999 2.0 True False 1 \n", "26882 0.0 True False 1 \n", "62952 11.0 True False 0 \n", "... ... ... ... ... \n", "141318 0.0 True True 0 \n", "113838 3.0 True True 0 \n", "184926 0.0 True True 0 \n", "14617 0.0 True True 0 \n", "21685 0.0 True True 0 \n", "\n", " gender_male nb_campaigns nb_campaigns_opened \n", "64466 0 0.0 0.0 \n", "141327 0 10.0 0.0 \n", "59999 0 0.0 0.0 \n", "26882 0 4.0 1.0 \n", "62952 0 0.0 0.0 \n", "... ... ... ... 
\n", "141318 0 16.0 1.0 \n", "113838 1 31.0 14.0 \n", "184926 1 18.0 0.0 \n", "14617 1 0.0 0.0 \n", "21685 1 0.0 0.0 \n", "\n", "[10000 rows x 14 columns]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# reduce X_train to reduce the training time\n",
 "\n",
 "# Fixed seed: the same rows are drawn on every run. The size is capped so that\n",
 "# sampling does not fail on a training set smaller than 10000 rows.\n",
 "X_train_subsample = X_train.sample(n=min(10000, len(X_train)), random_state=42)\n",
 "# Keep the targets aligned with the sampled rows through the shared index\n",
 "y_train_subsample = y_train.loc[X_train_subsample.index]\n",
 "X_train_subsample" ] }, { "cell_type": "code", "execution_count": 41, "id": "2b09c2cd-fd5c-49b3-be66-cec6c5ec1351", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
y_has_purchased
644660.0
1413270.0
599990.0
268820.0
629520.0
......
1413180.0
1138380.0
1849260.0
146170.0
216850.0
\n", "

10000 rows × 1 columns

\n", "
" ], "text/plain": [ " y_has_purchased\n", "64466 0.0\n", "141327 0.0\n", "59999 0.0\n", "26882 0.0\n", "62952 0.0\n", "... ...\n", "141318 0.0\n", "113838 0.0\n", "184926 0.0\n", "14617 0.0\n", "21685 0.0\n", "\n", "[10000 rows x 1 columns]" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train_subsample" ] }, { "cell_type": "code", "execution_count": 42, "id": "6c33fcd8-17d8-4390-b836-faec9ada9acd", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('preprocessor',\n",
       "                 ColumnTransformer(transformers=[('num',\n",
       "                                                  Pipeline(steps=[('scaler',\n",
       "                                                                   StandardScaler())]),\n",
       "                                                  ['nb_tickets', 'nb_purchases',\n",
       "                                                   'total_amount',\n",
       "                                                   'nb_suppliers',\n",
       "                                                   'vente_internet_max',\n",
       "                                                   'purchase_date_min',\n",
       "                                                   'purchase_date_max',\n",
       "                                                   'nb_tickets_internetnb_campaigns',\n",
       "                                                   'nb_campaigns_opened']),\n",
       "                                                 ('cat',\n",
       "                                                  Pipeline(steps=[('onehot',\n",
       "                                                                   OneHotEncoder(handle_unknown='ignore',\n",
       "                                                                                 sparse_output=False))]),\n",
       "                                                  ['opt_in',\n",
       "                                                   'is_email_true'])])),\n",
       "                ('logreg',\n",
       "                 LogisticRegression(class_weight={0.0: 0.5837086520288036,\n",
       "                                                  1.0: 3.486549107420539},\n",
       "                                    max_iter=5000, solver='saga'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('preprocessor',\n", " ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler',\n", " StandardScaler())]),\n", " ['nb_tickets', 'nb_purchases',\n", " 'total_amount',\n", " 'nb_suppliers',\n", " 'vente_internet_max',\n", " 'purchase_date_min',\n", " 'purchase_date_max',\n", " 'nb_tickets_internetnb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[('onehot',\n", " OneHotEncoder(handle_unknown='ignore',\n", " sparse_output=False))]),\n", " ['opt_in',\n", " 'is_email_true'])])),\n", " ('logreg',\n", " LogisticRegression(class_weight={0.0: 0.5837086520288036,\n", " 1.0: 3.486549107420539},\n", " max_iter=5000, solver='saga'))])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline" ] }, { "cell_type": "code", "execution_count": 43, "id": "710ccccc-50c9-4aba-8cf1-11483dbbdd1c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n", " 1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n", " 2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n", " 4.000000e+00, 8.000000e+00, 1.600000e+01]),\n", " 'logreg__penalty': ['l1', 'l2', 'elasticnet']}" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "param_grid" ] }, { "cell_type": "code", "execution_count": 46, "id": "ab078cf8-0d4c-4b23-9f33-2483cf605b06", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "make_scorer(f1_score, response_method='predict')" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f1_scorer" ] }, { "cell_type": "code", "execution_count": 50, "id": "8062169e-8305-42b0-aeff-8f714117da40", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nb_ticketsnb_purchasestotal_amountnb_suppliersvente_internet_maxpurchase_date_minpurchase_date_maxnb_tickets_internetis_email_trueopt_ingender_femalegender_malenb_campaignsnb_campaigns_opened
6446611.04.0281.41.01.0238.33059130.28504011.0TrueFalse100.00.0
1413270.00.00.00.00.0550.000000550.0000000.0TrueTrue0010.00.0
599992.01.00.01.01.0350.288926350.2889262.0TrueFalse100.00.0
268820.00.00.00.00.0550.000000550.0000000.0TrueFalse104.01.0
6295211.03.0325.01.01.0424.486781237.28226211.0TrueFalse000.00.0
.............................................
1413180.00.00.00.00.0550.000000550.0000000.0TrueTrue0016.01.0
1138383.02.015.01.01.0153.15294590.2770993.0TrueTrue0131.014.0
1849260.00.00.00.00.0550.000000550.0000000.0TrueTrue0118.00.0
146171.01.020.01.00.0239.258970239.2589700.0TrueTrue010.00.0
216854.01.088.01.00.0240.355162240.3551620.0TrueTrue010.00.0
\n", "

10000 rows × 14 columns

\n", "
" ], "text/plain": [ " nb_tickets nb_purchases total_amount nb_suppliers \\\n", "64466 11.0 4.0 281.4 1.0 \n", "141327 0.0 0.0 0.0 0.0 \n", "59999 2.0 1.0 0.0 1.0 \n", "26882 0.0 0.0 0.0 0.0 \n", "62952 11.0 3.0 325.0 1.0 \n", "... ... ... ... ... \n", "141318 0.0 0.0 0.0 0.0 \n", "113838 3.0 2.0 15.0 1.0 \n", "184926 0.0 0.0 0.0 0.0 \n", "14617 1.0 1.0 20.0 1.0 \n", "21685 4.0 1.0 88.0 1.0 \n", "\n", " vente_internet_max purchase_date_min purchase_date_max \\\n", "64466 1.0 238.330591 30.285040 \n", "141327 0.0 550.000000 550.000000 \n", "59999 1.0 350.288926 350.288926 \n", "26882 0.0 550.000000 550.000000 \n", "62952 1.0 424.486781 237.282262 \n", "... ... ... ... \n", "141318 0.0 550.000000 550.000000 \n", "113838 1.0 153.152945 90.277099 \n", "184926 0.0 550.000000 550.000000 \n", "14617 0.0 239.258970 239.258970 \n", "21685 0.0 240.355162 240.355162 \n", "\n", " nb_tickets_internet is_email_true opt_in gender_female \\\n", "64466 11.0 True False 1 \n", "141327 0.0 True True 0 \n", "59999 2.0 True False 1 \n", "26882 0.0 True False 1 \n", "62952 11.0 True False 0 \n", "... ... ... ... ... \n", "141318 0.0 True True 0 \n", "113838 3.0 True True 0 \n", "184926 0.0 True True 0 \n", "14617 0.0 True True 0 \n", "21685 0.0 True True 0 \n", "\n", " gender_male nb_campaigns nb_campaigns_opened \n", "64466 0 0.0 0.0 \n", "141327 0 10.0 0.0 \n", "59999 0 0.0 0.0 \n", "26882 0 4.0 1.0 \n", "62952 0 0.0 0.0 \n", "... ... ... ... 
\n", "141318 0 16.0 1.0 \n", "113838 1 31.0 14.0 \n", "184926 1 18.0 0.0 \n", "14617 1 0.0 0.0 \n", "21685 1 0.0 0.0 \n", "\n", "[10000 rows x 14 columns]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train_subsample" ] }, { "cell_type": "code", "execution_count": 86, "id": "0270013a-6523-4cf8-8de0-569c0d1c5db5", "metadata": {}, "outputs": [], "source": [ "warnings.filterwarnings('ignore')\n", "warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n", "warnings.filterwarnings(\"ignore\", category=DataConversionWarning)" ] }, { "cell_type": "code", "execution_count": 88, "id": "7a49d78a-5a9b-44a9-95cf-3fca1b3febfa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Returned hyperparameter: {'logreg__C': 0.03125, 'logreg__penalty': 'l1'}\n", "Best classification F1 score in train is: 0.47785817197986385\n" ] } ], "source": [ "# run the pipeline on the subsample\n", "\n", "logit_grid = GridSearchCV(pipeline, param_grid, cv=3, scoring = f1_scorer #, error_score=\"raise\"\n", " )\n", "logit_grid.fit(X_train_subsample, y_train_subsample)\n", "\n", "# print results\n", "print('Returned hyperparameter: {}'.format(logit_grid.best_params_))\n", "print('Best classification F1 score in train is: {}'.format(logit_grid.best_score_))\n", "# print('Classification F1 score on test is: {}'.format(logit_grid.score(X_test, y_test)))" ] }, { "cell_type": "code", "execution_count": 89, "id": "b1d5e71d-1078-4370-86e8-52b1ae378898", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n", " 1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n", " 2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n", " 4.000000e+00, 8.000000e+00, 1.600000e+01])" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "param_c" ] }, { "cell_type": "code", "execution_count": 96, "id": 
"cfe04739-fe9c-4802-9d34-885a8cfce0dc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
GridSearchCV(cv=3,\n",
       "             estimator=Pipeline(steps=[('preprocessor',\n",
       "                                        ColumnTransformer(transformers=[('num',\n",
       "                                                                         Pipeline(steps=[('scaler',\n",
       "                                                                                          StandardScaler())]),\n",
       "                                                                         ['nb_tickets',\n",
       "                                                                          'nb_purchases',\n",
       "                                                                          'total_amount',\n",
       "                                                                          'nb_suppliers',\n",
       "                                                                          'vente_internet_max',\n",
       "                                                                          'purchase_date_min',\n",
       "                                                                          'purchase_date_max',\n",
       "                                                                          'nb_tickets_internet',\n",
       "                                                                          'nb_campaigns',\n",
       "                                                                          'nb_campaigns_opened']),\n",
       "                                                                        ('cat',\n",
       "                                                                         Pipeline(steps=[(...\n",
       "                                                                         1.0: 3.486549107420539},\n",
       "                                                           max_iter=5000,\n",
       "                                                           solver='saga'))]),\n",
       "             param_grid={'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n",
       "       1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n",
       "       2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n",
       "       4.000000e+00, 8.000000e+00, 1.600000e+01]),\n",
       "                         'logreg__penalty': ['l1']},\n",
       "             scoring=make_scorer(f1_score, response_method='predict'))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GridSearchCV(cv=3,\n", " estimator=Pipeline(steps=[('preprocessor',\n", " ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler',\n", " StandardScaler())]),\n", " ['nb_tickets',\n", " 'nb_purchases',\n", " 'total_amount',\n", " 'nb_suppliers',\n", " 'vente_internet_max',\n", " 'purchase_date_min',\n", " 'purchase_date_max',\n", " 'nb_tickets_internet',\n", " 'nb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[(...\n", " 1.0: 3.486549107420539},\n", " max_iter=5000,\n", " solver='saga'))]),\n", " param_grid={'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n", " 1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n", " 2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n", " 4.000000e+00, 8.000000e+00, 1.600000e+01]),\n", " 'logreg__penalty': ['l1']},\n", " scoring=make_scorer(f1_score, response_method='predict'))" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit_grid = GridSearchCV(pipeline, param_grid, cv=3, scoring = f1_scorer #, error_score=\"raise\"\n", " )\n", "logit_grid" ] }, { "cell_type": "code", "execution_count": 97, "id": "6debc66c-a56d-41fa-8ef8-ba388e0e14fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n", " 1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n", " 2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n", " 4.000000e+00, 8.000000e+00, 1.600000e+01]),\n", " 'logreg__penalty': ['l1']}" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "source": [ "param_grid" ] }, { "cell_type": "code", "execution_count": 98, "id": "e394cc04-5d0b-4a64-9aa0-415dc8a3cbbc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Returned hyperparameter: {'logreg__C': 0.03125, 'logreg__penalty': 'l1'}\n", "Best classification accuracy in train is: 
0.42160313383818665\n", "Classification accuracy on test is: 0.47078982841737305\n" ] } ], "source": [ "# run the pipeline on the full sample\n", "\n", "logit_grid = GridSearchCV(pipeline, param_grid, cv=3, scoring = f1_scorer #, error_score=\"raise\"\n", " )\n", "logit_grid.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 99, "id": "8e6cf558-a4f4-4159-9835-364ee3bb1ed2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Returned hyperparameter: {'logreg__C': 0.03125, 'logreg__penalty': 'l1'}\n", "Best classification F1 score in train is: 0.42160313383818665\n", "Classification F1 score on test is: 0.47078982841737305\n" ] } ], "source": [ "# print results\n", "print('Returned hyperparameter: {}'.format(logit_grid.best_params_))\n", "print('Best classification F1 score in train is: {}'.format(logit_grid.best_score_))\n", "print('Classification F1 score on test is: {}'.format(logit_grid.score(X_test, y_test)))" ] }, { "cell_type": "code", "execution_count": 100, "id": "e2ff26cb-f137-4a23-9add-bdb61bebdf9c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
GridSearchCV(cv=3,\n",
       "             estimator=Pipeline(steps=[('preprocessor',\n",
       "                                        ColumnTransformer(transformers=[('num',\n",
       "                                                                         Pipeline(steps=[('scaler',\n",
       "                                                                                          StandardScaler())]),\n",
       "                                                                         ['nb_tickets',\n",
       "                                                                          'nb_purchases',\n",
       "                                                                          'total_amount',\n",
       "                                                                          'nb_suppliers',\n",
       "                                                                          'vente_internet_max',\n",
       "                                                                          'purchase_date_min',\n",
       "                                                                          'purchase_date_max',\n",
       "                                                                          'nb_tickets_internet',\n",
       "                                                                          'nb_campaigns',\n",
       "                                                                          'nb_campaigns_opened']),\n",
       "                                                                        ('cat',\n",
       "                                                                         Pipeline(steps=[(...\n",
       "                                                                         1.0: 3.486549107420539},\n",
       "                                                           max_iter=5000,\n",
       "                                                           solver='saga'))]),\n",
       "             param_grid={'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n",
       "       1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n",
       "       2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n",
       "       4.000000e+00, 8.000000e+00, 1.600000e+01]),\n",
       "                         'logreg__penalty': ['l1']},\n",
       "             scoring=make_scorer(f1_score, response_method='predict'))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GridSearchCV(cv=3,\n", " estimator=Pipeline(steps=[('preprocessor',\n", " ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler',\n", " StandardScaler())]),\n", " ['nb_tickets',\n", " 'nb_purchases',\n", " 'total_amount',\n", " 'nb_suppliers',\n", " 'vente_internet_max',\n", " 'purchase_date_min',\n", " 'purchase_date_max',\n", " 'nb_tickets_internet',\n", " 'nb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[(...\n", " 1.0: 3.486549107420539},\n", " max_iter=5000,\n", " solver='saga'))]),\n", " param_grid={'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n", " 1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n", " 2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n", " 4.000000e+00, 8.000000e+00, 1.600000e+01]),\n", " 'logreg__penalty': ['l1']},\n", " scoring=make_scorer(f1_score, response_method='predict'))" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit_grid" ] }, { "cell_type": "code", "execution_count": 105, "id": "5d553da2-5c2a-491a-b4d2-f31c30c201a6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'scoring': make_scorer(f1_score, response_method='predict'),\n", " 'estimator': Pipeline(steps=[('preprocessor',\n", " ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler',\n", " StandardScaler())]),\n", " ['nb_tickets', 'nb_purchases',\n", " 'total_amount',\n", " 'nb_suppliers',\n", " 'vente_internet_max',\n", " 'purchase_date_min',\n", " 'purchase_date_max',\n", " 'nb_tickets_internet',\n", " 'nb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[('onehot',\n", " OneHotEncoder(handle_unknown='ignore',\n", " sparse_output=False))]),\n", " ['opt_in',\n", " 'is_email_true'])])),\n", " ('logreg',\n", " LogisticRegression(class_weight={0.0: 0.5837086520288036,\n", " 1.0: 3.486549107420539},\n", " max_iter=5000, solver='saga'))]),\n", " 'n_jobs': 
None,\n", " 'refit': True,\n", " 'cv': 3,\n", " 'verbose': 0,\n", " 'pre_dispatch': '2*n_jobs',\n", " 'error_score': nan,\n", " 'return_train_score': False,\n", " 'param_grid': {'logreg__C': array([9.765625e-04, 1.953125e-03, 3.906250e-03, 7.812500e-03,\n", " 1.562500e-02, 3.125000e-02, 6.250000e-02, 1.250000e-01,\n", " 2.500000e-01, 5.000000e-01, 1.000000e+00, 2.000000e+00,\n", " 4.000000e+00, 8.000000e+00, 1.600000e+01]),\n", " 'logreg__penalty': ['l1']},\n", " 'multimetric_': False,\n", " 'best_index_': 5,\n", " 'best_score_': 0.42160313383818665,\n", " 'best_params_': {'logreg__C': 0.03125, 'logreg__penalty': 'l1'},\n", " 'best_estimator_': Pipeline(steps=[('preprocessor',\n", " ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler',\n", " StandardScaler())]),\n", " ['nb_tickets', 'nb_purchases',\n", " 'total_amount',\n", " 'nb_suppliers',\n", " 'vente_internet_max',\n", " 'purchase_date_min',\n", " 'purchase_date_max',\n", " 'nb_tickets_internet',\n", " 'nb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[('onehot',\n", " OneHotEncoder(handle_unknown='ignore',\n", " sparse_output=False))]),\n", " ['opt_in',\n", " 'is_email_true'])])),\n", " ('logreg',\n", " LogisticRegression(C=0.03125,\n", " class_weight={0.0: 0.5837086520288036,\n", " 1.0: 3.486549107420539},\n", " max_iter=5000, penalty='l1',\n", " solver='saga'))]),\n", " 'refit_time_': 305.1356477737427,\n", " 'feature_names_in_': array(['nb_tickets', 'nb_purchases', 'total_amount', 'nb_suppliers',\n", " 'vente_internet_max', 'purchase_date_min', 'purchase_date_max',\n", " 'nb_tickets_internet', 'is_email_true', 'opt_in', 'gender_female',\n", " 'gender_male', 'nb_campaigns', 'nb_campaigns_opened'], dtype=object),\n", " 'scorer_': make_scorer(f1_score, response_method='predict'),\n", " 'cv_results_': {'mean_fit_time': array([ 11.07076669, 13.15744201, 27.35094929, 40.0343461 ,\n", " 94.58210254, 140.45846391, 159.83818332, 162.80178094,\n", " 163.94260454, 
171.08749111, 169.26621262, 166.36741408,\n", " 167.91208776, 173.06720233, 170.93666704]),\n", " 'std_fit_time': array([ 0.09462032, 1.51362591, 6.70859141, 22.68643753, 28.72690872,\n", " 70.8434823 , 85.23159321, 79.71538593, 82.70486235, 84.79706797,\n", " 86.79005212, 84.67956107, 83.94889047, 89.68716252, 89.41361431]),\n", " 'mean_score_time': array([0.11632609, 0.10857773, 0.18140252, 0.1291213 , 0.11651532,\n", " 0.07535577, 0.12481014, 0.16039928, 0.15685773, 0.07996233,\n", " 0.12988146, 0.10067987, 0.1194102 , 0.09737802, 0.09390028]),\n", " 'std_score_time': array([0.02131792, 0.03620144, 0.05853886, 0.06555575, 0.03228018,\n", " 0.01433186, 0.03501336, 0.05466042, 0.06882891, 0.01002881,\n", " 0.00495894, 0.00905774, 0.04075337, 0.03269379, 0.01990173]),\n", " 'param_logreg__C': masked_array(data=[0.0009765625, 0.001953125, 0.00390625, 0.0078125,\n", " 0.015625, 0.03125, 0.0625, 0.125, 0.25, 0.5, 1.0, 2.0,\n", " 4.0, 8.0, 16.0],\n", " mask=[False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False],\n", " fill_value='?',\n", " dtype=object),\n", " 'param_logreg__penalty': masked_array(data=['l1', 'l1', 'l1', 'l1', 'l1', 'l1', 'l1', 'l1', 'l1',\n", " 'l1', 'l1', 'l1', 'l1', 'l1', 'l1'],\n", " mask=[False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False],\n", " fill_value='?',\n", " dtype=object),\n", " 'params': [{'logreg__C': 0.0009765625, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.001953125, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.00390625, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.0078125, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.015625, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.03125, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.0625, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.125, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.25, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 0.5, 'logreg__penalty': 
'l1'},\n", " {'logreg__C': 1.0, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 2.0, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 4.0, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 8.0, 'logreg__penalty': 'l1'},\n", " {'logreg__C': 16.0, 'logreg__penalty': 'l1'}],\n", " 'split0_test_score': array([0.27289073, 0.2738913 , 0.27382853, 0.27409759, 0.27454764,\n", " 0.27661894, 0.2766145 , 0.27584723, 0.27571682, 0.27576295,\n", " 0.27580092, 0.27577943, 0.27581248, 0.27581909, 0.27581909]),\n", " 'split1_test_score': array([0.4714244 , 0.47196015, 0.48362373, 0.48891733, 0.49066854,\n", " 0.49091122, 0.49086284, 0.49065871, 0.49062783, 0.49049541,\n", " 0.49048106, 0.49045238, 0.49043804, 0.49043804, 0.4904237 ]),\n", " 'split2_test_score': array([0.50689906, 0.50092334, 0.4981377 , 0.49759178, 0.49725836,\n", " 0.49727924, 0.49708801, 0.49738305, 0.49751781, 0.49738248,\n", " 0.49738248, 0.49738248, 0.49738248, 0.49738248, 0.49738248]),\n", " 'mean_test_score': array([0.4170714 , 0.4155916 , 0.41852999, 0.42020223, 0.42082484,\n", " 0.42160313, 0.42152178, 0.42129633, 0.42128749, 0.42121361,\n", " 0.42122149, 0.42120476, 0.421211 , 0.4212132 , 0.42120842]),\n", " 'std_test_score': array([0.10297463, 0.1008925 , 0.10249081, 0.10337226, 0.10346859,\n", " 0.10255226, 0.10249644, 0.10288467, 0.10297243, 0.10288758,\n", " 0.10286646, 0.10287015, 0.10285136, 0.10284824, 0.10284503]),\n", " 'rank_test_score': array([14, 15, 13, 12, 11, 1, 2, 3, 4, 6, 5, 10, 8, 7, 9],\n", " dtype=int32)},\n", " 'n_splits_': 3}" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit_grid.__dict__" ] }, { "cell_type": "code", "execution_count": 114, "id": "3573f34e-25d5-4afb-82cc-52323e2f63c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1.34302143, 0. 
, 0.02675567, 0.45036527, -0.05004637,\n", " 0.7663532 , -1.35216757, 0.17404712, 0.13679663, 0.10249737,\n", " 0.40815146, -0.6311938 , 0.11194512, -0.33498749]])" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# coefficients trouvés pour le modèle optimal\n", "logit_grid.best_estimator_.named_steps[\"logreg\"].coef_" ] }, { "cell_type": "code", "execution_count": 125, "id": "0332a814-61fb-4b71-836a-e8ace70b1a44", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'preprocessor': ColumnTransformer(transformers=[('num',\n", " Pipeline(steps=[('scaler', StandardScaler())]),\n", " ['nb_tickets', 'nb_purchases', 'total_amount',\n", " 'nb_suppliers', 'vente_internet_max',\n", " 'purchase_date_min', 'purchase_date_max',\n", " 'nb_tickets_internet', 'nb_campaigns',\n", " 'nb_campaigns_opened']),\n", " ('cat',\n", " Pipeline(steps=[('onehot',\n", " OneHotEncoder(handle_unknown='ignore',\n", " sparse_output=False))]),\n", " ['opt_in', 'is_email_true'])]),\n", " 'logreg': LogisticRegression(C=0.03125,\n", " class_weight={0.0: 0.5837086520288036,\n", " 1.0: 3.486549107420539},\n", " max_iter=5000, penalty='l1', solver='saga')}" ] }, "execution_count": 125, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit_grid.best_estimator_.named_steps" ] }, { "cell_type": "code", "execution_count": 116, "id": "287615b9-e062-4b84-be61-26b9364b2cf4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-0.22304234])" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logit_grid.best_estimator_.named_steps[\"logreg\"].intercept_" ] }, { "cell_type": "code", "execution_count": 115, "id": "4d50899d-cc0b-4a71-9406-f8b0a277c4a6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nb_ticketsnb_purchasestotal_amountnb_suppliersvente_internet_maxpurchase_date_minpurchase_date_maxnb_tickets_internetis_email_trueopt_ingender_femalegender_malenb_campaignsnb_campaigns_opened
02.01.060.001.00.0355.268981355.2689810.0TrueFalse010.00.0
18.03.0140.001.00.0373.540289219.2622690.0TrueFalse010.00.0
22.01.050.001.00.05.2024425.2024420.0TrueFalse010.00.0
33.01.090.001.00.05.1789585.1789580.0TrueFalse010.00.0
42.01.078.001.00.05.1740395.1740390.0TrueFalse100.00.0
.............................................
2242080.00.00.000.00.0550.000000550.0000000.0TrueFalse0134.03.0
2242091.01.020.001.01.0392.501030392.5010301.0TrueFalse0123.06.0
2242100.00.00.000.00.0550.000000550.0000000.0TrueTrue018.04.0
2242111.01.097.111.01.0172.334074172.3340741.0TrueFalse0113.05.0
2242120.00.00.000.00.0550.000000550.0000000.0TrueFalse014.04.0
\n", "

224213 rows × 14 columns

\n", "
" ], "text/plain": [ " nb_tickets nb_purchases total_amount nb_suppliers \\\n", "0 2.0 1.0 60.00 1.0 \n", "1 8.0 3.0 140.00 1.0 \n", "2 2.0 1.0 50.00 1.0 \n", "3 3.0 1.0 90.00 1.0 \n", "4 2.0 1.0 78.00 1.0 \n", "... ... ... ... ... \n", "224208 0.0 0.0 0.00 0.0 \n", "224209 1.0 1.0 20.00 1.0 \n", "224210 0.0 0.0 0.00 0.0 \n", "224211 1.0 1.0 97.11 1.0 \n", "224212 0.0 0.0 0.00 0.0 \n", "\n", " vente_internet_max purchase_date_min purchase_date_max \\\n", "0 0.0 355.268981 355.268981 \n", "1 0.0 373.540289 219.262269 \n", "2 0.0 5.202442 5.202442 \n", "3 0.0 5.178958 5.178958 \n", "4 0.0 5.174039 5.174039 \n", "... ... ... ... \n", "224208 0.0 550.000000 550.000000 \n", "224209 1.0 392.501030 392.501030 \n", "224210 0.0 550.000000 550.000000 \n", "224211 1.0 172.334074 172.334074 \n", "224212 0.0 550.000000 550.000000 \n", "\n", " nb_tickets_internet is_email_true opt_in gender_female \\\n", "0 0.0 True False 0 \n", "1 0.0 True False 0 \n", "2 0.0 True False 0 \n", "3 0.0 True False 0 \n", "4 0.0 True False 1 \n", "... ... ... ... ... \n", "224208 0.0 True False 0 \n", "224209 1.0 True False 0 \n", "224210 0.0 True True 0 \n", "224211 1.0 True False 0 \n", "224212 0.0 True False 0 \n", "\n", " gender_male nb_campaigns nb_campaigns_opened \n", "0 1 0.0 0.0 \n", "1 1 0.0 0.0 \n", "2 1 0.0 0.0 \n", "3 1 0.0 0.0 \n", "4 0 0.0 0.0 \n", "... ... ... ... 
\n", "224208 1 34.0 3.0 \n", "224209 1 23.0 6.0 \n", "224210 1 8.0 4.0 \n", "224211 1 13.0 5.0 \n", "224212 1 4.0 4.0 \n", "\n", "[224213 rows x 14 columns]" ] }, "execution_count": 115, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the L1 (LASSO) penalty zeroed the 2nd coefficient, i.e. nb_purchases\n", "# (coef_ follows the preprocessor output order - numeric columns first - not X_train's column order)\n", "X_train" ] }, { "cell_type": "code", "execution_count": 122, "id": "e53b1f79-762d-4f1f-8505-91de1088af42", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "32.0" ] }, "execution_count": 122, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# best regularisation strength: alpha = 1/C = 32\n", "1/logit_grid.best_params_[\"logreg__C\"]" ] }, { "cell_type": "code", "execution_count": 127, "id": "41bcaaf6-ab58-4004-a3c5-586d77e872d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy Score: 0.7589597902097902\n", "F1 Score: 0.47078982841737305\n", "Recall Score: 0.7525931336742148\n" ] } ], "source": [ "# print results for the best model\n", "\n", "y_pred = logit_grid.predict(X_test)\n", "\n", "# accuracy, F1 and recall on the test set\n", "acc = accuracy_score(y_test, y_pred)\n", "print(f\"Accuracy Score: {acc}\")\n", "\n", "f1 = f1_score(y_test, y_pred)\n", "print(f\"F1 Score: {f1}\")\n", "\n", "recall = recall_score(y_test, y_pred)\n", "print(f\"Recall Score: {recall}\")" ] }, { "cell_type": "code", "execution_count": 128, "id": "a454bb57-76eb-4a22-9950-0733d39e449f", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# confusion matrix \n", "\n", "draw_confusion_matrix(y_test, y_pred)" ] }, { "cell_type": "code", "execution_count": 138, "id": "25ec1701-ade5-4419-8b46-8a1bb109cf84", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# ROC curve\n", "\n", "# Calcul des taux de faux positifs (FPR) et de vrais positifs (TPR)\n", "y_pred_prob = logit_grid.predict_proba(X_test)[:, 1]\n", "\n", "fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob, pos_label=1)\n", "\n", "# Calcul de l'aire sous la courbe ROC (AUC)\n", "roc_auc = auc(fpr, tpr)\n", "\n", "plt.figure(figsize = (14, 8))\n", "plt.plot(fpr, tpr, label=\"ROC curve(area = %0.3f)\" % roc_auc)\n", "plt.plot([0, 1], [0, 1], color=\"red\",label=\"Random Baseline\", linestyle=\"--\")\n", "plt.grid(color='gray', linestyle='--', linewidth=0.5)\n", "plt.xlabel('Taux de faux positifs (FPR)')\n", "plt.ylabel('Taux de vrais positifs (TPR)')\n", "plt.title('Courbe ROC : modèle logistique')\n", "plt.legend(loc=\"lower right\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }