10 examples of 'plot roc curve sklearn' in Python

def plot_roc(y_test, y_pred, label=''):
    """Draw one ROC curve with its area under the curve in the legend.

    Args:
        y_test: Ground-truth labels, forwarded to ``roc_curve`` as its
            first argument.
        y_pred: Predicted scores, forwarded to ``roc_curve`` as its second
            argument.
        label: Text appended verbatim (no separator is inserted) to the
            plot title.

    Opens a new figure and shows it with ``plt.show()``; nothing is returned.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_test, y_pred)
    area = auc(false_pos_rate, true_pos_rate)
    curve_label = 'ROC curve (area = %0.2f)' % area
    full_title = 'Receiver operating characteristic' + label

    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate, label=curve_label)
    # Dashed black diagonal from (0, 0) to (1, 1) as a visual reference.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    # Upper y-limit is slightly above 1 so the top of the curve is not clipped.
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(full_title)
    plt.legend(loc="lower right")
    plt.show()
def ROC_plot(features, X_, y_, pred_, title):
    """Plot a ROC curve, save it, and report the AUC and a cut-off threshold.

    The cut-off is the threshold at which ``tpr - fpr`` is largest; see
    https://stackoverflow.com/questions/28719067/roc-curve-and-cut-off-point-python

    Args:
        features: Collection of features. Its length is shown in the plot
            title and its string form is embedded in the saved file name.
        X_: Object whose ``shape[0]`` is shown in the plot title.
        y_: Ground-truth labels, forwarded to ``roc_curve``.
        pred_: Predicted scores, forwarded to ``roc_curve``.
        title: Name used in the printed summary and in the legend entry.

    Returns:
        Tuple ``(auc_, optimal_threshold)``.
    """
    fpr_, tpr_, thresholds = roc_curve(y_, pred_)
    optimal_idx = np.argmax(tpr_ - fpr_)
    optimal_threshold = thresholds[optimal_idx]
    auc_ = auc(fpr_, tpr_)

    # `title` is used unchanged: appending " auc=" to it before formatting
    # produced "X auc= auc=..." in the printout and "X auc=:..." in the legend.
    print("{} auc={} OT={:.4g}".format(title, auc_, optimal_threshold))
    plt.plot(fpr_, tpr_, label="{} auc={:.4g}".format(title, auc_))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('SMPLEs={} Features={} OT={:.4g}'.format(
        X_.shape[0], len(features), optimal_threshold))
    plt.legend(loc='best')
    # The file name embeds str(features), so it changes with the feature set.
    plt.savefig("./_auc_[{}].jpg".format(features))
    plt.show()
    return auc_, optimal_threshold
def compute_and_plot_roc_scores(y_test, y_test_probas, num_class, path=None):
    """Compute ROC statistics and, when a path is given, plot the ROC curves.

    The statistics come from ``_compute_roc_stats``; plotting is delegated to
    ``_create_roc_plot``.

    Arguments:
        y_test: [int]
            list of test class labels as integer indices
        y_test_probas: np.ndarray, float
            array of predicted probabilities with shape
            (num_sample, num_class)
        num_class: int
            number of classes
        path: string
            filepath where the ROC curve plot is saved; when None no plot
            is produced

    Returns:
        roc_auc_dict: {int: float}
            dictionary mapping classes to ROC AUC scores
        fpr_dict: {string: np.ndarray}
            dictionary mapping names of classes or an averaging method to
            arrays of increasing false positive rates
        tpr_dict: {string: np.ndarray}
            dictionary mapping names of classes or an averaging method to
            arrays of increasing true positive rates
    """
    auc_by_class, fpr_by_name, tpr_by_name = _compute_roc_stats(
        y_test, y_test_probas, num_class)

    wants_plot = path is not None
    if wants_plot:
        _create_roc_plot(auc_by_class, fpr_by_name, tpr_by_name,
                         num_class, path)

    return auc_by_class, fpr_by_name, tpr_by_name
def precision_recall_curve(clf, x_test, y_test):
    """Plot precision-recall curves for the first two ``predict_proba`` columns.

    Args:
        clf: Fitted classifier exposing ``predict_proba``.
        x_test: Samples passed to ``clf.predict_proba``.
        y_test: Ground-truth labels passed to scikit-learn's
            ``precision_recall_curve``.

    Both curves are drawn in blue on the current axes and shown with
    ``plt.show()``; nothing is returned.
    """
    # This function shares its name with the scikit-learn helper, so the
    # helper is imported locally under an alias to keep the two apart.
    from sklearn.metrics import precision_recall_curve as _sk_pr_curve

    # Predict once; the probabilities are the same for every column.
    probabilities = clf.predict_proba(x_test)

    plt.title('Precision Recall Curve')
    for column in range(2):
        column_scores = [row[column] for row in probabilities]
        # NOTE(review): every column is scored against the same y_test;
        # confirm that is intended for column 0.
        precision, recall, _ = _sk_pr_curve(y_test, column_scores)
        plt.plot(recall, precision, 'b')

    plt.show()
def compute_ROC_lrw_multiclass(correct_word_idx,
                               val_confidences, test_confidences,
                               savePlot=False, showPlot=False,
                               plot_title='ROC curve'):
    """Compute multiclass ROC statistics for val and test, optionally plotting.

    Args:
        correct_word_idx: True class indices; binarized with
            ``label_binarize`` against ``np.arange(confidences.shape[1])``.
        val_confidences: Validation confidences; ``shape[1]`` gives the
            number of classes.
        test_confidences: Test confidences; ``shape[1]`` gives the number
            of classes.
        savePlot: When true, the figure is saved to ``'a.png'``.
        showPlot: When true, the figure is shown with ``plt.show()``.
        plot_title: Title of the figure.

    Returns:
        Tuple ``(val_fpr, val_tpr, val_roc_auc, val_OP_fpr, val_OP_tpr,
        test_fpr, test_tpr, test_roc_auc, test_OP_fpr, test_OP_tpr)``.
        The four operating-point entries are ``None`` when neither
        ``savePlot`` nor ``showPlot`` is set, because they are only computed
        for plotting.
    """
    make_plot = showPlot or savePlot
    total_steps = 4 if make_plot else 2

    # P => argmax conf is highest confidence
    print("1/{0:01d} Computing val ROC...".format(total_steps))
    val_fpr, val_tpr, val_roc_auc = compute_ROC_multiclass(
        label_binarize(correct_word_idx,
                       classes=np.arange(val_confidences.shape[1])),
        val_confidences)

    # NOTE(review): the same correct_word_idx is used for val and test;
    # confirm both splits really share their labels.
    print("2/{0:01d} Computing test ROC...".format(total_steps))
    test_fpr, test_tpr, test_roc_auc = compute_ROC_multiclass(
        label_binarize(correct_word_idx,
                       classes=np.arange(test_confidences.shape[1])),
        test_confidences)

    # Defined up front so the return statement works when nothing is plotted.
    val_OP_fpr = val_OP_tpr = None
    test_OP_fpr = test_OP_tpr = None

    if make_plot:
        print("3/{0:01d} Computing val ROC operating point...".format(total_steps))
        val_OP_fpr, val_OP_tpr = get_multiclass_ROC_operating_point(
            correct_word_idx, val_confidences)
        print("4/{0:01d} Computing test ROC operating point...".format(total_steps))
        test_OP_fpr, test_OP_tpr = get_multiclass_ROC_operating_point(
            correct_word_idx, test_confidences)

        # Validation curves use colour C0, test curves use C1; the 'x'
        # marker is the operating point of each split.
        plt.plot(val_fpr['micro'], val_tpr['micro'], color='C0',
                 linestyle=':', linewidth=3,
                 label='val_micro; AUC={0:0.4f}'.format(val_roc_auc['micro']))
        plt.plot(val_fpr['macro'], val_tpr['macro'], color='C0',
                 linestyle='--',
                 label='val_macro; AUC={0:0.4f}'.format(val_roc_auc['macro']))
        plt.plot(val_OP_fpr, val_OP_tpr, color='C0', marker='x', markersize=10)
        plt.plot(test_fpr['micro'], test_tpr['micro'], color='C1',
                 linestyle=':', linewidth=3,
                 label='test_micro; AUC={0:0.4f}'.format(test_roc_auc['micro']))
        plt.plot(test_fpr['macro'], test_tpr['macro'], color='C1',
                 linestyle='--',
                 label='test_macro; AUC={0:0.4f}'.format(test_roc_auc['macro']))
        plt.plot(test_OP_fpr, test_OP_tpr, color='C1', marker='x', markersize=10)
        plt.legend(loc='lower right')
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.title(plot_title)

        if savePlot:
            plt.savefig('a.png')
        if showPlot:
            plt.show()
        plt.close()

    return (val_fpr, val_tpr, val_roc_auc, val_OP_fpr, val_OP_tpr,
            test_fpr, test_tpr, test_roc_auc, test_OP_fpr, test_OP_tpr)
def plot_roc(score_list, save_dir, plot_name):
    """Plot a ROC curve from scored samples and annotate EER, AUC and threshold.

    Every score found in ``score_list`` is used as a threshold; the rates
    for each threshold come from ``cal_rate``.

    Args:
        score_list: Iterable of 2-item entries whose first item is the
            score. The whole list is also handed to ``cal_rate``.
        save_dir: Directory the figure is written to.
        plot_name: File name without extension; ``".jpg"`` is appended.

    The figure is saved to ``save_dir/plot_name.jpg`` and then shown;
    nothing is returned.
    """
    save_path = os.path.join(save_dir, plot_name + ".jpg")
    # Sort by score so the thresholds are visited in ascending order.
    threshold_value = sorted(score for score, _ in score_list)

    threshold_num = len(threshold_value)
    TPR_array = np.zeros(threshold_num)
    FNR_array = np.zeros(threshold_num)
    FPR_array = np.zeros(threshold_num)

    # Only TPR, FNR and FPR are needed; the other values returned by
    # cal_rate (accuracy, precision, TNR) are not used here.
    for idx, thres in enumerate(threshold_value):
        _accuracy, _precision, TPR, _TNR, FNR, FPR = cal_rate(score_list, thres)
        TPR_array[idx] = TPR
        FNR_array[idx] = FNR
        FPR_array[idx] = FPR

    AUC = np.trapz(TPR_array, FPR_array)
    # Equal error rate: the threshold where FNR and FPR are closest.
    threshold = np.argmin(abs(FNR_array - FPR_array))
    EER = (FNR_array[threshold] + FPR_array[threshold]) / 2
    plt.plot(FPR_array, TPR_array)

    plt.title('ROC')
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    # The integral is negated for display; presumably FPR falls as the
    # threshold rises, which makes the trapezoid sum negative. Confirm
    # against cal_rate.
    plt.text(0.2, 0,
             s="EER :{} AUC :{} Threshold:{}".format(
                 round(EER, 4), round(-AUC, 4),
                 round(threshold_value[threshold], 4)),
             fontsize=10)
    plt.legend()
    plt.savefig(save_path)
    plt.show()
def get_roc_curve(model, data, thread_count=-1, plot=False):
    """
    Build points of ROC curve.

    Parameters
    ----------
    model : catboost.CatBoost
        The trained model.

    data : catboost.Pool or list of catboost.Pool
        A set of samples to build ROC curve with.

    thread_count : int (default=-1)
        Number of threads to work with.
        If -1, then the number of threads is set to the number of CPU cores.

    plot : bool, optional (default=False)
        If True, draw curve.

    Returns
    -------
    curve points : tuple of three arrays (fpr, tpr, thresholds)

    Raises
    ------
    CatBoostError
        If ``data`` is neither a Pool nor a list, or if any list item is
        not a Pool.
    """
    # isinstance (rather than an exact type comparison) so that a single
    # Pool subclass instance is accepted, matching the per-item check below.
    if isinstance(data, Pool):
        data = [data]
    if not isinstance(data, list):
        raise CatBoostError('data must be a catboost.Pool or list of pools.')
    for pool in data:
        if not isinstance(pool, Pool):
            raise CatBoostError('one of data pools is not catboost.Pool')

    roc_curve = _get_roc_curve(model._object, data, thread_count)

    if plot:
        with _import_matplotlib() as plt:
            # The first two elements of the result are drawn as x and y.
            _draw(plt, roc_curve[0], roc_curve[1],
                  'False Positive Rate', 'True Positive Rate', 'ROC Curve')

    return roc_curve
def test_roc(self):
    """Pass the result of ``max_lr_f1(True)`` on to ``lr_roc_curve``."""
    f1_result = self.max_lr_f1(True)
    self.lr_roc_curve(f1_result)
def plot_roc_auc_per_class(self):
    """
    Save a bar plot of the per-class ROC AUC, highest value first.

    ``self.per_class_metrics_list[0]`` is replaced by a copy sorted by
    descending ``'ROC_auc'``. The figure is written into ``self.plot_path``
    and then closed.
    """
    def _auc_of(entry):
        return float(entry['ROC_auc'])

    # Negated key gives descending order.
    ranked = sorted(self.per_class_metrics_list[0],
                    key=lambda entry: -_auc_of(entry))
    self.per_class_metrics_list[0] = ranked

    bar_positions = list(range(len(ranked)))
    bar_heights = [_auc_of(entry) for entry in ranked]

    fig, ax = plt.subplots()
    ax.bar(
        x=bar_positions,
        height=bar_heights,
        width=1,
        color=colors['blue'],
        alpha=0.7,
    )
    ax.set_ylabel('ROC AUC')
    ax.set_xlabel('Class')
    ax.set_title('ROC per Class')

    file_name = 'roc_auc_per_class_{}{}.png'.format(self.step, self.early)
    plt.savefig(os.path.join(self.plot_path, file_name))
    plt.close()
def plot_auc(self):
    """Display one ROC curve per class for a two-class classifier.

    When ``self.n_classes`` is not 2, a "not yet implemented" message is
    displayed and ``None`` is returned. Otherwise each curve is added to a
    new figure through ``self._plot_auc_label`` and the figure is displayed
    under an "AUC Plot" heading.
    """
    if self.n_classes != 2:
        display("plot_auc() not yet implemented for multiclass classifiers")
        return None

    # TODO: move the binarization into the classifier.
    truth = label_binarize(self.y_true, classes=self.classes)
    scores = 1 - self.y_pred_proba

    # Append the complement so that each of the two classes has a column.
    truth = np.hstack((truth, 1 - truth))
    scores = np.hstack((scores, 1 - scores))

    fig = plt.figure()

    for class_idx in range(self.n_classes):
        false_pos, true_pos, _ = sklearn.metrics.roc_curve(
            truth[:, class_idx], scores[:, class_idx]
        )
        area = sklearn.metrics.auc(false_pos, true_pos)
        self._plot_auc_label(fig, false_pos, true_pos, area, class_idx)

    display(HTML("<h2>AUC Plot</h2>"))
    display(fig)