TF里幾種loss和注意事項

昨天複習幾種常見loss的時候想起在tensorflow里使用常見loss需要注意的地方,主要是三個方法:

tf.nn.sigmoid_cross_entropy_with_logitswww.tensorflow.org

tf.nn.softmax_cross_entropy_with_logitswww.tensorflow.org

tf.nn.sparse_softmax_cross_entropy_with_logitswww.tensorflow.org

打不開的話可能需要科學上網hhhhhh

準備1、先說一下什麼是logit,logit函數定義為:

L(p)=lnfrac{p}{1-p}

是一種將取值範圍在[0,1]內的概率映射到實數域[-inf,inf]的函數,如果p=0.5,函數值為0;p<0.5,函數值為負;p>0.5,函數值為正。

相對地,softmax和sigmoid則都是將[-inf,inf]映射到[0,1]的函數。

在tensorflow里的"logits"指的其實是,該方法是在logit數值上使用softmax或者sigmoid來進行normalization的,也暗示用戶不要將網路輸出進行sigmoid或者softmax,這些過程可以在函數內部更高效地計算。

準備2、獨立和互斥

有事件A和B

獨立:P(AnB) = P(A) * P(B)

互斥:P(AUB) = P(A) + P(B), P(AnB) = 0

準備3、cross entropy loss + softmax + sigmoid

請看之前的文章,複習:常見的損失函數

1、tf.nn.sigmoid_cross_entropy_with_logits

sigmoid_cross_entropy_with_logits( _sentinel=None, labels=None, logits=None, name=None)

計算網路輸出logits和標籤labels的sigmoid cross entropy loss,衡量獨立不互斥離散分類任務的誤差,說獨立不互斥離散分類任務是因為,在這些任務中類與類之間是獨立但是不互斥的,拿多分類任務中的多目標檢測來舉例子,一張圖中可以有各種instance,比如有一隻狗和一隻貓。對於一個總共有五類的多目標檢測任務,假如網路的輸出層有5個節點,label的形式是[1,1,0,0,1]這種,1表示該圖片有某種instance,0表示沒有。那麼,每個instance在這張圖中有沒有這顯然是獨立事件,但是多個instance可以存在一張圖中,這就說明事件們並不是互斥的。所以我們可以直接將網路的輸出用作該方法的logits輸入,從而進行輸出與label的cross entropy loss。

更加直白的來說,這種網路的輸入不需要進行one hot處理,網路輸出即是函數logits參數的輸入。

剖開函數內部,因為labels和logits的形狀都是[batch_size, num_classes],那麼如何計算他們的交叉熵呢,畢竟它們都不是有效的概率分布(一個batch內輸出結果經過sigmoid後和不為1)。其實loss的計算是element-wise的,方法返回的loss的形狀和labels是相同的,也是[batch_size, num_classes],再調用reduce_mean方法計算batch內的平均loss。所以這裡的cross entropy其實是一種class-wise的cross entropy,每一個class是否存在都是一個事件,對每一個事件都求cross entropy loss,再對所有的求平均,作為最終的loss。

2、tf.nn.softmax_cross_entropy_with_logits

softmax_cross_entropy_with_logits( _sentinel=None, labels=None, logits=None, dim=-1, name=None)

計算網路輸出logits和標籤labels的softmax cross entropy loss,衡量獨立互斥離散分類任務的誤差,說獨立互斥離散分類任務是因為,在這些任務中類與類之間是獨立而且互斥的,比如VOC classification、Imagenet、CIFAR-10甚至MNIST,這些都是多分類任務,但是一張圖就對應著一個類,class在圖片中是否存在是獨立的,並且一張圖中只能有一個class,所以是獨立且互斥事件。

該函數要求每一個label都是一個有效的概率分布,對於Imagenet中的ILSVRC2012這種任務,那麼label應該就對應一個one hot編碼,ILSVRC2012提供的數據集中一共有1000個類,那麼label就應該是一個1x1000的vector,形式為[0,0,...,1,0,....0],1000個元素中有且只有一個元素是1,其餘都是0。

這樣要求的原因很簡單,因為網路的輸出要進行softmax,得到的就是一個有效的概率分布,這裡不同與sigmoid,因為sigmoid並沒有保證網路所有的輸出經過sigmoid後和為1,不是一個有效的概率分布。

有了labels和softmax後的logits,就可以計算交叉熵損失了,最後得到的是形狀為[batch_size, 1]的loss。

3、tf.nn.sparse_softmax_cross_entropy_with_logits

sparse_softmax_cross_entropy_with_logits( _sentinel=None, labels=None, logits=None, name=None)

這個版本是tf.nn.softmax_cross_entropy_with_logits的易用版本,這個版本的logits的形狀依然是[batch_size, num_classes],但是labels的形狀是[batch_size, 1],每個label的取值是從[0, num_classes)的離散值,這也更加符合我們的使用習慣,是哪一類就標哪個類對應的label。

如果已經對label進行了one hot編碼,則可以直接使用tf.nn.softmax_cross_entropy_with_logits。

4、總結:

到底是用sigmoid版本的cross entropy還是softmax版本的cross entropy主要取決於我們模型的目的,以及label的組織方式,這個需要大家在使用的時候去揣摩,到底使用哪一種loss比較合理。

在我最近訓練的segmentation模型中,使用的就是sparse softmax cross entropy,使用的思路就是將輸出的結果從NHWC(這裡C=1,表示該pixel所屬的class),進行一次reshape,形狀變為[N*H*W, 1],label也是如此,傳入函數中進行計算,從而產生loss。從模型訓練的結果來看,這種使用方法沒有錯誤。

如有錯誤還望指正。


推薦閱讀:

一個模型庫學習所有:谷歌開源模塊化深度學習系統Tensor2Tensor
手把手教你用TensorFlow實現看圖說話|教程+代碼
作為深度學習最強框架的TensorFlow如何進行時序預測!
tf.get_variable
TensorFlow 教程 #04 - 保存 &amp; 恢復

TAG:深度學習DeepLearning | TensorFlow |