圖像檢測faster rcnn(3.4,get_anchor_gt函數)
源碼地址:keras版本faster rcnn
想了解這篇文章的前後內容出門左拐:faster rcnn代碼理解-keras(目錄)
視頻目錄:深度學習一行一行敲faster rcnn-keras版(視頻目錄)
這章是關於--data_generators.py的get_anchor_gt函數
該函數的作用是得到rpn網路的訓練數據
函數輸入:
def get_anchor_gt(all_img_data, class_count, C, img_length_calc_function, backend, mode = train):
- 圖片信息
- 類別統計信息
- 訓練信息類
- 計算輸出特徵圖大小的函數
- keras用的什麼內核
- 是否為訓練
函數輸出:
yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug
- 圖片
- 數據對象:第一個是是否包含對象,第二個是回歸梯度
- 增強後的圖片信息
代碼分析:
-------------------------------①------------------------------
sample_selector = SampleSelector(class_count)
這個類是確定是否要跳過該圖片,以達到類平衡
if C.balanced_classes and sample_selector.skip_sample_for_balanced_class(img_data): continue
決定是否要跳過該圖片。continue是結束本次循環,開始下一次循環
if mode == train: img_data_aug, x_img = data_augment.augment(img_data, C, augment=True) else: img_data_aug, x_img = data_augment.augment(img_data, C, augment=False)
augment是用來增強圖片的,主要是圖片的選中操作
-------------------------------②------------------------------
(width, height) = (img_data_aug[width], img_data_aug[height]) (rows, clos, _) = x_img.shape assert clos == width assert rows == height (resized_width, resized_height) = get_new_img_size(width, height, C.im_size) x_img = cv2.resize(x_img, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC)
- 顯示判斷經過增強後的圖片邏輯上不存在問題,即圖片寬高與實際一致
- get_new_img_size:faster_rcnn要把圖片的最短邊規整到600(可以設置成其它).這個是得到新圖片的尺寸
try: y_rpn_cls, y_rpn_regr = calc_rpn(C, img_data_aug, width, height, resized_width, resized_height, img_length_calc_function) except: continue
得到每一張圖片的每一個點的兩個特性以供RPN網路訓練【深度學習一行一行敲faster rcnn-keras版(3.3,calc_rpn函數)】
- y_rpn_cls:是否包含物體
- y_rpn_regr:回歸梯度為多少
-------------------------------③------------------------------
x_img = x_img[:,:,(2,1,0)] x_img = x_img.astype(np.float32) x_img[:,:,0] -= C.img_channel_mean[0] x_img[:,:,1] -= C.img_channel_mean[1] x_img[:,:,2] -= C.img_channel_mean[2] x_img /= C.img_scaling_factor x_img = np.transpose(x_img, (2,0,1)) x_img = np.expand_dims(x_img, axis=0) y_rpn_regr[:, y_rpn_regr.shape[1]//2:,:,:] *= C.std_scaling if backend == tf: x_img = np.transpose(x_img, (0,2,3,1)) y_rpn_cls = np.transpose(y_rpn_cls, (0,2,3,1)) y_rpn_regr = np.transpose(y_rpn_regr, (0,2,3,1))
- 將BGR圖片變為RGB,因為公開訓練好的VGG模型是按照這個訓練的
- 減去均值理由同上
- 將深度變為第一個維度
- 給圖片增加一個維度
- 給回歸梯度除上一個規整因子
- 如果用的是tf內核,還是要把深度調到最後一位了
yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug
歡迎關注公眾號:huangxiaobai880
https://www.zhihu.com/video/919238973438697472推薦閱讀:
※LevelDB源碼解析7. 日誌格式
※ROS導航包源碼學習4 --- 全局規劃
※TiDB 源碼閱讀系列文章(三)SQL 的一生
※Retrofit原理解析最簡潔的思路
※LevelDB源碼解析8. 讀取日誌
TAG:深度學習DeepLearning | 機器學習 | 源碼閱讀 |