TF里幾種loss和注意事項
昨天複習幾種常見loss的時候想起在tensorflow里使用常見loss需要注意的地方,主要是三個方法:
tf.nn.sigmoid_cross_entropy_with_logitstf.nn.softmax_cross_entropy_with_logitstf.nn.sparse_softmax_cross_entropy_with_logits打不開的話可能需要科學上網hhhhhh
準備1、先說一下什麼是logit,logit函數定義為:
是一種將取值範圍在[0,1]內的概率映射到實數域[-inf,inf]的函數,如果p=0.5,函數值為0;p<0.5,函數值為負;p>0.5,函數值為正。
相對地,softmax和sigmoid則都是將[-inf,inf]映射到[0,1]的函數。
在tensorflow里的"logits"指的其實是,該方法是在logit數值上使用softmax或者sigmoid來進行normalization的,也暗示用戶不要將網路輸出進行sigmoid或者softmax,這些過程可以在函數內部更高效地計算。
準備2、獨立和互斥
有事件A和B
獨立:P(AnB) = P(A) * P(B)
互斥:P(AUB) = P(A) + P(B), P(AnB) = 0
準備3、cross entropy loss + softmax + sigmoid
請看之前的文章,複習:常見的損失函數
1、tf.nn.sigmoid_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits( _sentinel=None, labels=None, logits=None, name=None)
計算網路輸出logits和標籤labels的sigmoid cross entropy loss,衡量獨立不互斥離散分類任務的誤差,說獨立不互斥離散分類任務是因為,在這些任務中類與類之間是獨立但是不互斥的,拿多分類任務中的多目標檢測來舉例子,一張圖中可以有各種instance,比如有一隻狗和一隻貓。對於一個總共有五類的多目標檢測任務,假如網路的輸出層有5個節點,label的形式是[1,1,0,0,1]這種,1表示該圖片有某種instance,0表示沒有。那麼,每個instance在這張圖中有沒有這顯然是獨立事件,但是多個instance可以存在一張圖中,這就說明事件們並不是互斥的。所以我們可以直接將網路的輸出用作該方法的logits輸入,從而進行輸出與label的cross entropy loss。
更加直白的來說,這種網路的輸入不需要進行one hot處理,網路輸出即是函數logits參數的輸入。
剖開函數內部,因為labels和logits的形狀都是[batch_size, num_classes],那麼如何計算他們的交叉熵呢,畢竟它們都不是有效的概率分布(一個batch內輸出結果經過sigmoid後和不為1)。其實loss的計算是element-wise的,方法返回的loss的形狀和labels是相同的,也是[batch_size, num_classes],再調用reduce_mean方法計算batch內的平均loss。所以這裡的cross entropy其實是一種class-wise的cross entropy,每一個class是否存在都是一個事件,對每一個事件都求cross entropy loss,再對所有的求平均,作為最終的loss。
2、tf.nn.softmax_cross_entropy_with_logits
softmax_cross_entropy_with_logits( _sentinel=None, labels=None, logits=None, dim=-1, name=None)
計算網路輸出logits和標籤labels的softmax cross entropy loss,衡量獨立互斥離散分類任務的誤差,說獨立互斥離散分類任務是因為,在這些任務中類與類之間是獨立而且互斥的,比如VOC classification、Imagenet、CIFAR-10甚至MNIST,這些都是多分類任務,但是一張圖就對應著一個類,class在圖片中是否存在是獨立的,並且一張圖中只能有一個class,所以是獨立且互斥事件。
該函數要求每一個label都是一個有效的概率分布,對於Imagenet中的ILSVRC2012這種任務,那麼label應該就對應一個one hot編碼,ILSVRC2012提供的數據集中一共有1000個類,那麼label就應該是一個1x1000的vector,形式為[0,0,...,1,0,....0],1000個元素中有且只有一個元素是1,其餘都是0。
這樣要求的原因很簡單,因為網路的輸出要進行softmax,得到的就是一個有效的概率分布,這裡不同與sigmoid,因為sigmoid並沒有保證網路所有的輸出經過sigmoid後和為1,不是一個有效的概率分布。
有了labels和softmax後的logits,就可以計算交叉熵損失了,最後得到的是形狀為[batch_size, 1]的loss。
3、tf.nn.sparse_softmax_cross_entropy_with_logits
sparse_softmax_cross_entropy_with_logits( _sentinel=None, labels=None, logits=None, name=None)
這個版本是tf.nn.softmax_cross_entropy_with_logits的易用版本,這個版本的logits的形狀依然是[batch_size, num_classes],但是labels的形狀是[batch_size, 1],每個label的取值是從[0, num_classes)的離散值,這也更加符合我們的使用習慣,是哪一類就標哪個類對應的label。
如果已經對label進行了one hot編碼,則可以直接使用tf.nn.softmax_cross_entropy_with_logits。
4、總結:
到底是用sigmoid版本的cross entropy還是softmax版本的cross entropy主要取決於我們模型的目的,以及label的組織方式,這個需要大家在使用的時候去揣摩,到底使用哪一種loss比較合理。
在我最近訓練的segmentation模型中,使用的就是sparse softmax cross entropy,使用的思路就是將輸出的結果從NHWC(這裡C=1,表示該pixel所屬的class),進行一次reshape,形狀變為[N*H*W, 1],label也是如此,傳入函數中進行計算,從而產生loss。從模型訓練的結果來看,這種使用方法沒有錯誤。
如有錯誤還望指正。
推薦閱讀:
※一個模型庫學習所有:谷歌開源模塊化深度學習系統Tensor2Tensor
※手把手教你用TensorFlow實現看圖說話|教程+代碼
※作為深度學習最強框架的TensorFlow如何進行時序預測!
※tf.get_variable
※TensorFlow 教程 #04 - 保存 & 恢復
TAG:深度學習DeepLearning | TensorFlow |