標籤:

AUC的理解與計算

AUC的理解與計算

來自專欄計算廣告演算法4 人贊了文章

AUC 理解與計算

AUC 的全稱是 AreaUnderRoc 即 Roc 曲線與坐標軸形成的面積,取值範圍 [0, 1].

ROC

Roc (Receiver operating characteristic) 曲線是一種二元分類模型分類效果的分析工具。首先需要知道如下定義:

Roc 空間將偽陽性率(FPR)定義為 X 軸,真陽性率(TPR)定義為 Y 軸。

  • TPR: 在所有實際為陽性的樣本中,被正確地判斷為陽性之比率 TPR = TP/P = TP/(TP+FN)
  • FPR: 在所有實際為陰性的樣本中,被錯誤地判定為陽性之比率 FPR = FP/N = FP/(FP + TN)

給定一個二元分類模型和它的閾值,就能從所有樣本的(陽性/陰性)真實值和預測值計算出一個 (X=FPR, Y=TPR) 座標點。

從 (0, 0) 到 (1,1) 的對角線將ROC空間劃分為左上/右下兩個區域,在這條線的以上的點代表了一個好的分類結果(勝過隨機分類),而在這條線以下的點代表了差的分類結果(劣於隨機分類)。

完美的預測是一個在左上角的點,在ROC空間座標 (0,1)點,X=0 代表著沒有偽陽性,Y=1 代表著沒有偽陰性(所有的陽性都是真陽性);也就是說,不管分類器輸出結果是陽性或陰性,都是100%正確。一個隨機的預測會得到位於從 (0, 0) 到 (1, 1) 對角線(也叫無識別率線)上的一個點;最直觀的隨機預測的例子就是拋硬幣。

AUC

AUC 最普遍的定義是 ROC 曲線下的面積。但其實另一種定義更常用,分別隨機從正負樣本集中抽取一個正樣本,一個負樣本,正樣本的預測值大於負樣本的概率。

後一個定義可以從前一個定義推導出來,有興趣的可以看下 Wilcoxon-Mann-Witney Test。

AUC 的計算

按照定義

分別隨機從政府樣本集中抽取一個正負樣本,正樣本的預測值大於負樣本的概率。

根據古典概率模型

分母是正負樣本總的組合數,分子是正樣本大於負樣本的組合數

// 預測值 + 標籤case class LabeledPred(predict: Double, label: Int)def auc(points: Seq[LabeledPred]) = { val posNum = points.count(_.label > 0) val negNum = points.length - posNum if (posNum == 0 || negNum == 0) { println("Error: Lables of all samples are the same.") 0.0 } else { val sorted = points.sortBy(_.predict) var negSum = 0 var posGTNeg = 0 for (p <- sorted) { if (p.label > 0) { posGTNeg = posGTNeg + negSum } else { negSum = negSum + 1 } } posGTNeg.toDouble/(posNum * negNum).toDouble }}

AUC 的局限性

AUC作為排序的評價指標本身具有一定的局限性,它衡量的是整體樣本間的排序能力,對於計算廣告領域來說,它衡量的是不同用戶對不同廣告之間的排序能力,而線上環境往往需要關注同一個用戶的不同廣告之間的排序能力。

所以,阿里在 Deep Interest Network中提到一種改進版本的 AUC 指標,用戶加權平均AUC(gAUC)更能反映線上真實環境的排序能力。

直接貼代碼

// 預測值 + 標籤 + 用戶組 case class LabeledPredDev(device: String, labeledPred: LabeledPred) /** * gAUC, 用戶加權平均 AUC * @param points * @return */ def gAuc(points: Seq[LabeledPredDev]) = { val userMap = points.groupBy(_.device) .filterNot(_._2.forall(_.labeledPred.label > 0)) // 去掉全部記錄為正例的用戶 .filterNot(_._2.forall(_.labeledPred.label < 1)) // 去掉全部記錄為負例的用戶 val userCntMap = userMap.mapValues(_.size) val userAucMap = userMap.mapValues(tmp => auc(tmp.map(_.labeledPred))) val cntAucPairs = for (key <- userCntMap.keys.toSeq) yield (userCntMap.getOrElse(key, 0), userAucMap.getOrElse(key, 0.0)) val sumImprs = cntAucPairs.map(_._1).sum val weightedAuc = cntAucPairs.foldLeft(0.0){ (sum, pair) => { pair._1 * pair._2 + sum } } weightedAuc.toDouble/sumImprs.toDouble }

推薦閱讀:

CTR點擊率預估
深度學習在CTR預估中的應用

TAG:機器學習 | ctr |