はじめに こんにちは。BASEのDataStrategyチームで機械学習を触っている竹内です。 機械学習といえばLLMやDiffusionモデルなど生成モデルの発展が目覚ましい昨今ですが、その一方で構造化データに対して特徴量エンジニアリングを行い、CVを切って、LightGBMなどの便利な決定木ベースのフレームワークに投げて、できたモデルの出力を吟味し、時には致命的なリークに気付き頭を抱えるといった王道のアプローチは相変わらず現役で、実務に関していえば当分お世話になる機会が減ることはないかなという気がしています。 今回はそういったクラス分類モデルにおける性能の評価指標の1つである、ROC曲線やAUC、PR曲線といった概念について振り返りつつ、実務上しばしば見られる不均衡データに適用する際の注意点などについて、軽いシミュレーションも交えつつ掘り下げてみようと思います。 ROC曲線 2クラス分類器の性能を評価する際によく使用されるツールとしてROC Operating Characteristic)曲線が挙げられます。 ROC曲線は縦軸に再現率(Recall)、横軸に偽陽性率(False Positive Rate)をとり、それぞれの検出閾値に対する再現率、偽陽性率をプロットすることで得られる曲線です。ROC曲線の下側の領域の面積をAUC(Area Under the Curve)と呼び、その面積の大きさによって分類器の性能の高さ、つまりその分類器がどの程度うまく正例と負例を分離できているかを評価することができます。 同じデータセットにおける異なる2つのモデルの性能を比較する際や、再現率と偽陽性率のトレードオフを考慮した適切な閾値を決定する際にROC曲線のプロットは有効な選択肢の1つとなります。 model_1よりmodel_2の方がAUCが高いため、model_2の方が性能が高いと判断できる。 不均衡データにおいてROC曲線を扱う際の注意点 実務において2クラス分類器を適用する例はいろいろありますが、例えば「ある工場の部品生産ラインにおいて、ごく稀に発生する不良品を検知する」といった不良品検知や異常検知、不正検知などの文脈においては、正例に対して負例のサイズが極端に大きい、いわゆる不均衡データを扱うことがしばしばあります。 そうした極端に不均衡なデータにおいては、ROC曲線におけるAUCの指標がうまく機能しないことがあります。 具体的にはROC曲線は適合率(Precision)の指標を含んでいないため、「正しく検出できた正例に対して、巻き込んで検出してしまった負例がどれだけ多いか」という点が考慮されず、結果として非常に多くの偽陽性が発生してしまっていてもAUCは高い値を示してしまうことがあります。 先ほどの工場の不良品検知の例で言えば「AUCの値はほぼ1であるにも関わらず、稼働させてみたら不良品の見逃しは少ないものの、大量の偽陽性が発生してしまうことで、後続の目視チェックの工程をパンクさせてしまう。」といった問題を生じさせてしまう可能性があります。 また、しばしばそうした不均衡データにおいてはある程度の性能であればどのモデルのAUCもほぼ1に近い値を取るようなROC曲線がプロットされてしまい、異なるモデルを比較することが難しいこともあります。 PR曲線 ROC曲線と同様な概念としてしばしば用いられることがあるPR(Precision-Recall)曲線は、縦軸に再現率(Recall)、横軸に適合率(Precision)をとったものです。こちらはROC曲線とは異なり、横軸に適合率を採用しているため「正しく検出できた正例に対して、巻き込んで検出してしまった負例がどれだけ多いか」を反映することができます。 そのため、ROC曲線における高いAUCが観測できていても、データに不均衡性がみられる場合はPR曲線もプロットしてみることでモデルの性能に関する示唆を得られる可能性があります。また、ROC曲線におけるAUCについては差がほとんど見られない2つのモデルの性能を比較する場合においても、PR曲線における改善がみられるケースもあります。 データに不均衡性がみられる場合はROC曲線だけでなくPR曲線のプロットも見ておくと良いかもしれません。 実験 簡単なシミュレーションを行うことで、不均衡データにおけるROC曲線とPR曲線の性質について軽く検証してみます。 まず、正例は平均3、負例は平均1で分散は共に1の正規分布にしたがう特徴量Xもつことを既知とした上で、Xが閾値以上のものを正例、閾値以下のものを負例として分類するモデルを考えます。 その上で、正例のデータは1000件に固定し、負例の方は正例の件数の1倍、10倍、100倍、1000倍の4つのパターンで生成し、このモデルの性能を評価することを考えます。 上図は生成されたデータのヒストグラムであり、負例が多くなると潰れて見えなくなっていますが、正例のデータは4パターンで全て同じものを使用しています。 緑色の点線は閾値を示しており、モデルはこれより右側のものを正例、左側のものを負例として識別することになります。 図では閾値1.5の例を示していますが、ROC曲線やPR曲線をプロットする際はこの緑色の点線を左端から右端へ移動させた際の各指標を計算することになります。 # コード例 import numpy as np import matplotlib.pyplot as plt # 正規分布に従う乱数を生成 ratios = [ 1 , 10 , 100 , 1000 ] sample_size_positives = 1000 positives = np.random.normal( 3 , 1 , sample_size_positives) dataset = {} for r in ratios: sample_size_negatives = sample_size_positives * r negatives = np.random.normal( 0 , 1 , sample_size_negatives) dataset[r] = (sample_size_negatives, positives, negatives) # ratio=1の場合 ratio = 1 _, positives, negatives = dataset[ratio] fig, ax = plt.subplots() ax.hist(positives, bins= 100 , color= "r" , alpha= 0.5 , label= "positives" ) ax.hist(negatives, bins= 100 , color= "b" , alpha= 0.5 , label= "negatives" ) ax.axvline(x=threshold_example, color= "g" , linestyle= "--" , label=f "{threshold_example=}" ) ax.legend(loc= "upper right" ) ax.title.set_text(f "{ratio=}" ) ちなみに、グラフにおける各指標のイメージは以下のようなものとなります。 再現率(Recall) = (点線右側の赤色部分の面積) / (赤色部分全体の面積) 偽陽性率(False Positive Rate) = (点線右側の青色部分の面積) / (青色部分全体の面積) 適合率(Precision) = (点線右側の赤色部分の面積) / {(点線右側の赤色部分の面積) + (点線右側の青色部分の面積)} まず、4つのパターンについて、ROC曲線をプロットしてみます。 ROC曲線はどのパターンについてもほぼ同じで、左上に張り付くような形をとっており、AUCはほぼ1に近い値を取っていることがわかります。つまり、ROC曲線はデータの不均衡性の影響をほぼ反映していないことがわかります。 上図の赤い点はそれぞれの曲線上での閾値1.5における点を示したもの # コード例 # ROC curve threshols = np.linspace(- 5 , 5 , 1000 ) threshold_example = 1.5 fig, ax = plt.subplots() ax.plot([ 0 , 1 ], [ 0 , 1 ], linestyle= "--" , color= "g" , label= "random" ) for i, r in enumerate (ratios): recalls = [] fprs = [] sample_size_negatives, positives, negatives = dataset[r] for j in range ( len (threshols)): recalls.append(np.sum(positives > threshols[j]) / sample_size_positives) fprs.append(np.sum(negatives > threshols[j]) / (sample_size_negatives)) ax.scatter(fprs, recalls, color=plot_colors[i], label=f "ratio={r}" , s= 3 ) # 閾値threshold_exampleの時のrecallとfpr ax.scatter( [np.sum(negatives > threshold_example) / sample_size_negatives], [np.sum(positives > threshold_example) / sample_size_positives], color= "r" , s= 15 ) ax.legend(loc= "lower right" ) plt.xlabel( "FPR" ) plt.ylabel( "Recall" ) 次に4つのパターンについてPR曲線をプロットしてみます。 PR曲線についてはROC曲線とは異なり、4つのパターンそれぞれで全く異なる曲線がプロットされており、ROC曲線と比較してデータの不均衡性の影響を反映しやすいことがわかります。 図の赤い点における同じ閾値1.5で見てみると再現率は全てのパターンで同じ値を取っていますが、適合率においてはデータが不均衡になるにつれて低い値を取ることがわかります。 上図の赤い点はそれぞれの曲線上での閾値1.5における点を示したもの # コード例 # PR curve fig, ax = plt.subplots() for i, r in enumerate (ratios): recalls = [] precisions = [] sample_size_negatives, positives, negatives = dataset[r] for j in range ( len (threshols)): recalls.append(np.sum(positives > threshols[j]) / sample_size_positives) precisions.append(np.sum(positives > threshols[j]) / (np.sum(positives > threshols[j]) + np.sum(negatives > threshols[j]))) ax.scatter(precisions, recalls, color=plot_colors[i], label=f "ratio={r}" , s= 3 ) # 閾値1の時の適合率と再現率 plt.scatter( [np.sum(positives > threshold_example) / (np.sum(positives > threshold_example) + np.sum(negatives > 1.5 ))], [np.sum(positives > threshold_example) / sample_size_positives], color= "r" , s= 15 ) ax.legend(loc= "lower right" ) plt.xlabel( "Precision" ) plt.ylabel( "Recall" ) 最後に 今回は2クラス分類における、データの不均衡性とROC曲線およびPR曲線の関係性について掘り下げてみました。 一般的にはモデルの性能評価にはROC曲線を用いることが多い気がしますが、PR曲線の存在も頭に入れておくと、不均衡なデータを扱う際に役に立つかもしれません。 一応適合率を扱う際の注意点を挙げておくと、検証用データにおける負例が本番のデータからダウンサンプリングされている場合、算出される適合率がそのサンプリング比率に依存してしまう点があります。 例えば負例が本番のデータから1/rの件数だけランダムサンプリングされている場合、実際の適合率へ補正する計算式としては、検証用データにおける偽陽性(False Positives)の件数をN(FP)などとして といった形が考えられます。負例が圧倒的に多いケースでは、状況に応じて分母のN(TP)を0に近似してしまって、検証用データにおける適合率に1/rをかけた値を真の適合率としてしまっても良いかもしれません。 また、実務においては偽陽性や偽陰性の数だけで評価するのではなく、その質や誤分類コストの非対称性などにも注意を払うことが重要です。不良品検知の例でいえば、偽陰性の中に深刻な欠陥を見逃している例がないか、不正決済の検知であれば大きな損害をもたらす決済を見逃している例がないか、といったサンプルごとに異なる誤分類コストについても十分吟味する必要があります。 最後となりますが、DataStragetyチームではBASEにおけるデータ分析基盤の改善を一緒に行っていくメンバーを募集しています。ご興味のある方は気軽にご応募ください! A-1.Tech_データエンジニア / BASE株式会社 明日のBASEアドベントカレンダーは @h7jin16 さんと @u_hayato13 さんの記事です。お楽しみに! 参考資料 Wikipedia,Receiver operating characteristic https://en.wikipedia.org/wiki/Receiver_operating_characteristic scikit-learn,precision_recall_curve https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.precision_recall_curve.html Google,Machine Learning Crash Course,roc-and-auc https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc?hl=en