圖像檢測faster rcnn(3.4,get_anchor_gt函數)

源碼地址:keras版本faster rcnn

想了解這篇文章的前後內容出門左拐:faster rcnn代碼理解-keras(目錄)

視頻目錄:深度學習一行一行敲faster rcnn-keras版(視頻目錄)

這章是關於--data_generators.py的get_anchor_gt函數

該函數的作用是得到rpn網路的訓練數據

本章代碼流程

函數輸入:

def get_anchor_gt(all_img_data, class_count, C, img_length_calc_function, backend, mode = train):

  1. 圖片信息
  2. 類別統計信息
  3. 訓練信息類
  4. 計算輸出特徵圖大小的函數
  5. keras用的什麼內核
  6. 是否為訓練

函數輸出:

yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug

  1. 圖片
  2. 數據對象:第一個是是否包含對象,第二個是回歸梯度
  3. 增強後的圖片信息

代碼分析:

-------------------------------①------------------------------

sample_selector = SampleSelector(class_count)

這個類是確定是否要跳過該圖片,以達到類平衡

if C.balanced_classes and sample_selector.skip_sample_for_balanced_class(img_data): continue

決定是否要跳過該圖片。continue是結束本次循環,開始下一次循環

if mode == train: img_data_aug, x_img = data_augment.augment(img_data, C, augment=True) else: img_data_aug, x_img = data_augment.augment(img_data, C, augment=False)

augment是用來增強圖片的,主要是圖片的選中操作


-------------------------------②------------------------------

(width, height) = (img_data_aug[width], img_data_aug[height]) (rows, clos, _) = x_img.shape assert clos == width assert rows == height (resized_width, resized_height) = get_new_img_size(width, height, C.im_size) x_img = cv2.resize(x_img, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC)

  1. 顯示判斷經過增強後的圖片邏輯上不存在問題,即圖片寬高與實際一致
  2. get_new_img_size:faster_rcnn要把圖片的最短邊規整到600(可以設置成其它).這個是得到新圖片的尺寸

try: y_rpn_cls, y_rpn_regr = calc_rpn(C, img_data_aug, width, height, resized_width, resized_height, img_length_calc_function) except: continue

得到每一張圖片的每一個點的兩個特性以供RPN網路訓練【深度學習一行一行敲faster rcnn-keras版(3.3,calc_rpn函數)】

  1. y_rpn_cls:是否包含物體
  2. y_rpn_regr:回歸梯度為多少

-------------------------------③------------------------------

x_img = x_img[:,:,(2,1,0)] x_img = x_img.astype(np.float32) x_img[:,:,0] -= C.img_channel_mean[0] x_img[:,:,1] -= C.img_channel_mean[1] x_img[:,:,2] -= C.img_channel_mean[2] x_img /= C.img_scaling_factor x_img = np.transpose(x_img, (2,0,1)) x_img = np.expand_dims(x_img, axis=0) y_rpn_regr[:, y_rpn_regr.shape[1]//2:,:,:] *= C.std_scaling if backend == tf: x_img = np.transpose(x_img, (0,2,3,1)) y_rpn_cls = np.transpose(y_rpn_cls, (0,2,3,1)) y_rpn_regr = np.transpose(y_rpn_regr, (0,2,3,1))

  1. 將BGR圖片變為RGB,因為公開訓練好的VGG模型是按照這個訓練的
  2. 減去均值理由同上
  3. 將深度變為第一個維度
  4. 給圖片增加一個維度
  5. 給回歸梯度除上一個規整因子
  6. 如果用的是tf內核,還是要把深度調到最後一位了

yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug

歡迎關注公眾號:huangxiaobai880

https://www.zhihu.com/video/919238973438697472
推薦閱讀:

LevelDB源碼解析7. 日誌格式
ROS導航包源碼學習4 --- 全局規劃
TiDB 源碼閱讀系列文章(三)SQL 的一生
Retrofit原理解析最簡潔的思路
LevelDB源碼解析8. 讀取日誌

TAG:深度學習DeepLearning | 機器學習 | 源碼閱讀 |