add calibration curve

This commit is contained in:
Alexis REVELLE 2024-03-18 09:10:28 +00:00
parent 6eddec93bc
commit 5408ce677b

View File

@ -324,10 +324,24 @@
" plt.plot(fpr, tpr, label=\"ROC curve(area = %0.3f)\" % roc_auc)\n",
" plt.plot([0, 1], [0, 1], color=\"red\",label=\"Random Baseline\", linestyle=\"--\")\n",
" plt.grid(color='gray', linestyle='--', linewidth=0.5)\n",
" plt.xlabel('Taux de faux positifs (FPR)')\n",
" plt.ylabel('Taux de vrais positifs (TPR)')\n",
" plt.title('Courbe ROC : modèle logistique')\n",
" plt.xlabel(\"False Positive Rate\")\n",
" plt.ylabel(\"True Positive Rate\")\n",
" plt.title(\"ROC Curve\", size=18)\n",
" plt.legend(loc=\"lower right\")\n",
" plt.show()\n",
"\n",
"\n",
"def draw_calibration_curve(X_test, y_test):\n",
" y_pred_prob = pipeline.predict_proba(X_test)[:, 1]\n",
" frac_pos, mean_pred = calibration_curve(y_test, y_probs_bs, n_bins=10)\n",
"\n",
" # Plot the calibration curve\n",
" plt.plot(mean_pred, frac_pos, 's-', label='Logistic Regression')\n",
" plt.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')\n",
" plt.xlabel('Mean predicted value')\n",
" plt.ylabel('Fraction of positive predictions')\n",
" plt.title(\"Calibration Curve\")\n",
" plt.legend()\n",
" plt.show()"
]
},
@ -1552,6 +1566,16 @@
"draw_prob_distribution(X_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7ee0972-79ac-481e-a370-d71b085a3c27",
"metadata": {},
"outputs": [],
"source": [
"draw_calibration_curve(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"id": "ae8e9bd3-0f6a-4f82-bb4c-470cbdc8d6bb",