  • Computing ROC AUC in Cross Validation
    Library/Scikit-learn 2021. 2. 14. 20:44

    Create the cross-validation splits with StratifiedKFold.

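    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import auc, roc_curve
    from sklearn.model_selection import cross_val_score, StratifiedKFold
    cv = StratifiedKFold(n_splits=5, shuffle=False)  # 5 folds, each keeping the class ratio of y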
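    Why stratify: with an imbalanced target, plain KFold can produce test folds that contain only one class, and ROC AUC cannot be computed for such a fold. The sketch below uses synthetic, sorted labels (90 negatives followed by 10 positives, an assumption for illustration) and compares the share of positives in each test fold under StratifiedKFold and KFold.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Synthetic labels, sorted and imbalanced (an assumption for illustration):
# 90 negatives followed by 10 positives.
y_demo = np.array([0] * 90 + [1] * 10)
X_demo = np.zeros((len(y_demo), 1))  # features play no role in how folds are split

skf = StratifiedKFold(n_splits=5, shuffle=False)
kf = KFold(n_splits=5, shuffle=False)

# Fraction of positives in each test fold
strat_rates = [float(y_demo[test].mean()) for _, test in skf.split(X_demo, y_demo)]
plain_rates = [float(y_demo[test].mean()) for _, test in kf.split(X_demo)]

print(strat_rates)  # [0.1, 0.1, 0.1, 0.1, 0.1]
print(plain_rates)  # [0.0, 0.0, 0.0, 0.0, 0.5]
```

    Without shuffling, KFold puts all ten positives in the last fold, so the first four folds have no positives and no ROC curve. StratifiedKFold keeps the overall 10% positive rate in every fold.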

    For each CV fold, compute the FPR (False Positive Rate), TPR (True Positive Rate), and AUC (Area Under the ROC Curve), then average them across the folds.
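    The averaging in the loop that follows is "vertical" averaging: each fold's ROC curve is linearly interpolated onto a common FPR grid, and the TPR values are averaged point by point. Written out for K folds and the 100-point grid from np.linspace(0, 1, 100), with \widetilde{TPR}_k denoting fold k's interpolated TPR, and using the fact that sklearn.metrics.auc applies the trapezoidal rule:

```latex
f_j = \frac{j}{99}, \qquad j = 0, \dots, 99

\overline{\mathrm{TPR}}(f_j) = \frac{1}{K} \sum_{k=1}^{K} \widetilde{\mathrm{TPR}}_k(f_j)

\mathrm{AUC}_{\text{mean}} = \sum_{j=0}^{98} \left(f_{j+1} - f_j\right)
  \frac{\overline{\mathrm{TPR}}(f_j) + \overline{\mathrm{TPR}}(f_{j+1})}{2}
```

    Because the trapezoidal sum is linear in the TPR values, AUC_mean equals the average of the K interpolated fold AUCs. It can differ slightly from the mean of the per-fold AUCs computed on each fold's exact ROC points, and only because of the resampling onto the grid.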

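    # Assumes X is a pandas DataFrame and y a pandas Series (hence .iloc)
    logit = LogisticRegression(fit_intercept=True)
    
    tprs = []
    aucs = []
    mean_fpr_lr = np.linspace(0, 1, 100)  # common FPR grid shared by all folds
    for train, test in cv.split(X, y):
        prediction = logit.fit(X.iloc[train], y.iloc[train]).predict_proba(X.iloc[test])
        fpr, tpr, _ = roc_curve(y.iloc[test], prediction[:, 1])  # thresholds unused
        # Linearly interpolate this fold's ROC curve onto the common grid
        interp_tpr = np.interp(mean_fpr_lr, fpr, tpr)
        interp_tpr[0] = 0.0  # roc_curve may return several points at FPR = 0; start at (0, 0)
        tprs.append(interp_tpr)
        aucs.append(auc(fpr, tpr))  # AUC of this fold
    mean_tpr_lr = np.mean(tprs, axis=0)  # point-by-point mean TPR over the folds
    mean_auc_lr = auc(mean_fpr_lr, mean_tpr_lr)  # AUC of the mean ROC curve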
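    The imports above include cross_val_score, which the loop never uses. One use for it is a cross-check: with scoring="roc_auc" it refits a clone of the model on the same folds and should reproduce the per-fold AUCs from the loop. The scorer may rank samples by decision_function instead of predict_proba, but AUC depends only on the ranking, so the numbers agree. A minimal sketch, assuming synthetic data from make_classification in place of the post's X and y:

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic data stands in for the post's X and y (an assumption); it is wrapped
# in pandas so the .iloc indexing from the loop above works unchanged.
X_arr, y_arr = make_classification(n_samples=500, n_features=10, random_state=0)
X = pd.DataFrame(X_arr)
y = pd.Series(y_arr)

cv = StratifiedKFold(n_splits=5, shuffle=False)  # deterministic folds
logit = LogisticRegression(fit_intercept=True)

# Per-fold AUCs computed by hand, as in the loop above
aucs = []
for train, test in cv.split(X, y):
    proba = logit.fit(X.iloc[train], y.iloc[train]).predict_proba(X.iloc[test])
    fpr, tpr, _ = roc_curve(y.iloc[test], proba[:, 1])
    aucs.append(auc(fpr, tpr))

# The same folds scored by scikit-learn's built-in "roc_auc" scorer
cv_aucs = cross_val_score(logit, X, y, cv=cv, scoring="roc_auc")

print(np.round(aucs, 4))
print(np.round(cv_aucs, 4))
print("mean fold AUC:", np.mean(aucs))
```

    Passing the same cv object to both guarantees identical train/test indices; with shuffle=True you would also need a fixed random_state for the comparison to hold.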

    ... Build the other models the same way, then plot all of their mean ROC curves together.

    plt.figure(figsize=(10, 10))
    plt.plot(mean_fpr_lr, mean_tpr_lr, label='Mean ROC from Logistic Regression (AUC = %0.5f)' % mean_auc_lr, lw=2, alpha=1)
    plt.plot(mean_fpr_dt, mean_tpr_dt, label='Mean ROC from Decision Tree (AUC = %0.5f)' % mean_auc_dt, lw=2, alpha=1)
    plt.plot(mean_fpr_rm, mean_tpr_rm, label='Mean ROC from Random Forest (AUC = %0.5f)' % mean_auc_rm, lw=2, alpha=1)
    plt.plot(mean_fpr_lgb_0, mean_tpr_lgb_0, label='Mean ROC from LightGBM (AUC = %0.5f)' % mean_auc_lgb_0, lw=2, alpha=1)
    plt.plot(mean_fpr_lgb_ft, mean_tpr_lgb_ft, label='Mean ROC from Fine Tuned LightGBM (AUC = %0.5f)' % mean_auc_lgb_ft, lw=2, alpha=1)
    
    plt.xlabel('False Positive Rate', fontsize=15)
    plt.ylabel('True Positive Rate', fontsize=15)
    plt.title("Receiver Operating Characteristic", fontsize=25)
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    plt.legend(loc="lower right", fontsize=15)
    plt.show()
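    The per-model variables (mean_fpr_dt, mean_tpr_dt, ...) and the repeated plt.plot calls can be replaced by one helper and a loop over a dict of models. In this sketch mean_roc is a hypothetical helper (not part of scikit-learn), the data come from make_classification, and LightGBM is left out so that only scikit-learn and matplotlib are needed; an LGBMClassifier also provides predict_proba and could be added to the dict the same way.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier


def mean_roc(model, X, y, cv, grid_size=100):
    """Hypothetical helper: vertically averaged ROC curve of `model` over the folds of `cv`.

    X and y are NumPy arrays here; use .iloc instead for pandas objects.
    """
    mean_fpr = np.linspace(0, 1, grid_size)  # common FPR grid
    tprs = []
    for train, test in cv.split(X, y):
        proba = model.fit(X[train], y[train]).predict_proba(X[test])[:, 1]
        fpr, tpr, _ = roc_curve(y[test], proba)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)  # resample onto the grid
        interp_tpr[0] = 0.0  # start every fold's curve at (0, 0)
        tprs.append(interp_tpr)
    mean_tpr = np.mean(tprs, axis=0)
    return mean_fpr, mean_tpr, auc(mean_fpr, mean_tpr)


# Synthetic data (an assumption) in place of the post's dataset
X, y = make_classification(n_samples=500, n_features=10, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=False)

models = {
    "Logistic Regression": LogisticRegression(),
    "Decision Tree": DecisionTreeClassifier(random_state=0),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=0),
}
results = {name: mean_roc(model, X, y, cv) for name, model in models.items()}

plt.figure(figsize=(10, 10))
for name, (mean_fpr, mean_tpr, mean_auc) in results.items():
    plt.plot(mean_fpr, mean_tpr, lw=2, label=f"Mean ROC from {name} (AUC = {mean_auc:.5f})")
plt.plot([0, 1], [0, 1], "k--")  # chance-level diagonal
plt.xlabel("False Positive Rate", fontsize=15)
plt.ylabel("True Positive Rate", fontsize=15)
plt.title("Receiver Operating Characteristic", fontsize=25)
plt.legend(loc="lower right", fontsize=15)
plt.savefig("roc_cv.png")  # output filename is an arbitrary choice
```

    Keeping each label next to its model means that adding a model is one more dict entry rather than another set of mean_fpr_* / mean_tpr_* / mean_auc_* variables and another plt.plot line.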

     

    Reference: www.kaggle.com/kanncaa1/roc-curve-with-k-fold-cv
