
Computing ROC AUC with Cross Validation

rongxian 2021. 2. 14. 20:44

Generate the cross-validation splits with StratifiedKFold

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import cross_val_score, StratifiedKFold

cv = StratifiedKFold(n_splits=5, shuffle=False)
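As a quick sanity check, the point of StratifiedKFold (versus a plain KFold) is that every fold keeps the class ratio of y. A minimal sketch with hypothetical toy data (80 negatives, 20 positives, so a 20% positive rate):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# hypothetical toy data: 80 negatives, 20 positives (20% positive rate)
y = np.array([0] * 80 + [1] * 20)
X = np.arange(100).reshape(-1, 1)

cv = StratifiedKFold(n_splits=5, shuffle=False)
for train_idx, test_idx in cv.split(X, y):
    # each test fold gets 16 negatives and 4 positives
    print(y[test_idx].mean())  # → 0.2 for every fold
```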

For each CV fold, compute fpr (False Positive Rate), tpr (True Positive Rate), and AUC (Area Under the ROC Curve), then average them across folds

logit = LogisticRegression(fit_intercept=True)

tprs = []
aucs = []
mean_fpr_lr = np.linspace(0, 1, 100)
for train, test in cv.split(X, y):
    prediction = logit.fit(X.iloc[train], y.iloc[train]).predict_proba(X.iloc[test])
    fpr, tpr, thresholds = roc_curve(y.iloc[test], prediction[:, 1])
    tprs.append(np.interp(mean_fpr_lr, fpr, tpr))  # interpolate onto a common fpr grid
    tprs[-1][0] = 0.0  # anchor the interpolated curve at the origin
    aucs.append(auc(fpr, tpr))  # per-fold AUC
mean_tpr_lr = np.mean(tprs, axis=0)
mean_auc_lr = auc(mean_fpr_lr, mean_tpr_lr)
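If you only need the per-fold AUC values and not the averaged curve itself, the `cross_val_score` imported above does this in one call with `scoring='roc_auc'`. A minimal sketch, using synthetic `make_classification` data in place of the post's X and y:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# synthetic stand-in for the X, y used in the post
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=False)

# one AUC per fold, computed from predict_proba internally
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                         cv=cv, scoring='roc_auc')
print(scores.mean())
```

Note that the mean of the fold-wise AUCs is not in general identical to the AUC of the averaged ROC curve computed above; they are two slightly different summaries.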

... build the other models the same way, then plot them all together.
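Since the fold loop is repeated verbatim for each model, it can be factored into a helper. This is a sketch of my own, not from the post: `mean_roc` is a hypothetical name, and it assumes plain NumPy arrays (the post's version uses `.iloc` because X and y are pandas objects).

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

def mean_roc(clf, X, y, cv, n_points=100):
    """Fit clf on each fold; return (mean_fpr, mean_tpr, mean_auc)."""
    mean_fpr = np.linspace(0, 1, n_points)
    tprs = []
    for train, test in cv.split(X, y):
        proba = clf.fit(X[train], y[train]).predict_proba(X[test])[:, 1]
        fpr, tpr, _ = roc_curve(y[test], proba)
        tprs.append(np.interp(mean_fpr, fpr, tpr))
        tprs[-1][0] = 0.0  # pin the interpolated curve to the origin
    mean_tpr = np.mean(tprs, axis=0)
    return mean_fpr, mean_tpr, auc(mean_fpr, mean_tpr)
```

Each model then needs only one line, e.g. `mean_fpr_lr, mean_tpr_lr, mean_auc_lr = mean_roc(LogisticRegression(), X, y, cv)`.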

plt.figure(figsize=(10,10))
plt.plot(mean_fpr_lr, mean_tpr_lr, label='Mean ROC from Logistic Regression (AUC = %0.5f)' % mean_auc_lr, lw=2, alpha=1)
plt.plot(mean_fpr_dt, mean_tpr_dt, label='Mean ROC from Decision Tree (AUC = %0.5f)' % mean_auc_dt, lw=2, alpha=1)
plt.plot(mean_fpr_rm, mean_tpr_rm, label='Mean ROC from Random Forest (AUC = %0.5f)' % mean_auc_rm, lw=2, alpha=1)
plt.plot(mean_fpr_lgb_0, mean_tpr_lgb_0, label='Mean ROC from LightGBM (AUC = %0.5f)' % mean_auc_lgb_0, lw=2, alpha=1)
plt.plot(mean_fpr_lgb_ft, mean_tpr_lgb_ft, label='Mean ROC from Fine Tuned LightGBM (AUC = %0.5f)' % mean_auc_lgb_ft, lw=2, alpha=1)


plt.xlabel('False Positive Rate', fontsize=15)
plt.ylabel('True Positive Rate', fontsize=15)
plt.title("Receiver Operating Characteristic", fontsize=25)
plt.plot([0,1],[0,1], 'k--')
plt.legend(loc="lower right", fontsize=15);


Reference: www.kaggle.com/kanncaa1/roc-curve-with-k-fold-cv