用Keras和直方圖均衡化進行深度學習的圖像增強
閱讀本技術文章前,請操滿一口河南話,你就能看下去了
一、俺,遇到了啥子問題撒~?
我現在寫的文章都是因為遇到問題了,然後把解決過程給大家呈現出來
那麼,現在我遇到了一個醫學圖像處理問題。最近在處理醫學圖像的問題,發現DataSet一共只有400張圖像,還是分為四類。
那怎麼辦呢?
可能你會說:這還不簡單,遷移學習啊soga,小夥子可以啊,不過今天我們不講它(因為我還沒實踐過)
在這篇文章中,我們將討論並解決此問題:
二、俺,怎麼解決的嘞?
- 圖像增強:它是什麼?它為什麼如此重要?
- Keras:如何將它用於基本的圖像增強。
- 直方圖均衡化:這是什麼?它有什麼用處?
- 實現直方圖均衡技術:修改keras.preprocessing image.py文件的一種方法。
三、俺,怎麼做的嘞?
接下來我會從這四方面來討論解決數據不足的問題
1.圖像增強:它是什麼?它為什麼如此重要?
深度神經網路,尤其是卷積神經網路(CNN),尤其擅長圖像分類任務。最先進的CNN甚至已經被證明超過了人類在圖像識別方面的表現。
image source:https://www.eff.org/ai/metrics
如果想克服收集數以千計的訓練圖像的高昂費用,圖像增強則就是從現有數據集生成訓練數據。
圖像增強是將已經存在於訓練數據集中的圖像進行處理,並對其進行處理以創建相同圖像的許多改變的版本。這既提供了更多的圖像來訓練,也可以幫助我們的分類器暴露在更廣泛的倆個都和色彩情況下,從而使我們的分類器更具有魯棒性,以下是imgaug庫中不同增強的一些示例
source image:https://github.com/aleju/imgaug
2.使用Keras進行基本圖像增強
有很多方法來預處理圖像,在這篇文章中,我借鑒使用keras深度學習庫為增強圖像提供的一些最常用的開箱即用方法,然後演示如何修改keras.preprocessing image.py文件以啟用直方圖均衡化方法。
我們將使用keras自帶的cifar10數據集。但是,我們只會使用數據集中的貓和狗的圖像,以便保持足夠小的任務在CPU上執行。
- 載入 和 格式化數據
我們要做的第一件事就是載入cifar10數據集並格式化圖像,為CNN做準備。
我們還會仔細查看一些圖像,以確保數據已正確載入
先偷看一下長什麼樣?
from __future__ import print_functionimport kerasfrom keras.datasets import cifar10from keras import backend as Kimport matplotlibfrom matplotlib import pyplot as pltimport numpy as np# input image dimensionsimg_rows, img_cols = 32, 32 # the data, shuffled and split between train and test sets(x_train, y_train), (x_test, y_test) = cifar10.load_data() # Only look at cats [=3] and dogs [=5]train_picks = np.ravel(np.logical_or(y_train==3,y_train==5)) test_picks = np.ravel(np.logical_or(y_test==3,y_test==5)) y_train = np.array(y_train[train_picks]==5,dtype=int)y_test = np.array(y_test[test_picks]==5,dtype=int)x_train = x_train[train_picks]x_test = x_test[test_picks].........images = range(0,9)for i in images:plt.subplot(330 + 1 + i)plt.imshow(x_train[i], cmap=pyplot.get_cmap("gray"))# show the plotplt.show()
此處代碼參考鏈接地址:
https://github.com/ryanleeallred/Image_Augmentation/blob/master/Histogram_Modification.ipynb
cifar10圖像只有32 x 32像素,所以在這裡放大時看起來有顆粒感,但是CNN並不知道它有顆粒感,只能看到數據, 嗯,還是人類牛逼。
- 從ImageDataGenerator()創建一個圖像生成器
用keras增強 圖像數據 非常簡單。 Jason Brownlee 對此提供了一個很好的教程。
首先,我們需要通過調用ImageDataGenerator()函數來創建一個圖像生成器,並將它傳遞給我們想要在圖像上執行的變化的參數列表。然後,我們將調用fit()我們的圖像生成器的功能,這將逐批地應用到圖像的變化。默認情況下,這些修改將被隨機應用,所以並不是每一個圖像都會被改變。大家也可以使用keras.preprocessing導出增強的圖像文件到一個文件夾,以便建立一個巨大的數據集的改變圖像,如果你想這樣做,可以參考keras文檔。
- 隨機旋轉圖像
# Rotate images by 90 degreesdatagen = ImageDataGenerator(rotation_range=90)# fit parameters from datadatagen.fit(x_train)# Configure batch size and retrieve one batch of imagesfor X_batch, y_batch in datagen.flow(x_train, y_train, batch_size=9):# Show 9 imagesfor i in range(0, 9): pyplot.subplot(330 + 1 + i) pyplot.imshow(X_batch[i].reshape(img_rows, img_cols, 3))# show the plotpyplot.show()break
- 垂直翻轉圖像
# Flip images verticallydatagen = ImageDataGenerator(vertical_flip=True)# fit parameters from datadatagen.fit(x_train)# Configure batch size and retrieve one batch of imagesfor X_batch, y_batch in datagen.flow(x_train, y_train, batch_size=9):# Show 9 imagesfor i in range(0, 9): pyplot.subplot(330 + 1 + i) pyplot.imshow(X_batch[i].reshape(img_rows, img_cols, 3))# show the plotpyplot.show()break
備註:我感覺這裡需要針對數據集,因為很少有人把狗翻過來看,或者拍照(hahhhh)
- 將圖像垂直或水平移動20%
# Shift images vertically or horizontally # Fill missing pixels with the color of the nearest pixeldatagen = ImageDataGenerator(width_shift_range=.2, height_shift_range=.2, fill_mode="nearest")# fit parameters from datadatagen.fit(x_train)# Configure batch size and retrieve one batch of imagesfor X_batch, y_batch in datagen.flow(x_train, y_train, batch_size=9):# Show 9 imagesfor i in range(0, 9): pyplot.subplot(330 + 1 + i) pyplot.imshow(X_batch[i].reshape(img_rows, img_cols, 3))# show the plotpyplot.show()break
3.直方圖均衡技術
直方圖均衡化是指對比度較低的圖像,並增加圖像相對高低的對比度,以便在陰影中產生細微的差異,並創建較高的對比度圖像。結果可能是驚人的,特別是對於灰度圖像,如圖
使用圖像增強技術來提高圖像的對比度,此方法有時也被稱為「 直方圖拉伸
」,因為它們採用像素強度的分布和拉伸分布來適應更寬範圍的值,從而增加圖像的最亮部分和最暗部分之間的對比度水平。
直方圖均衡
直方圖均衡通過檢測圖像中像素密度的分布並將這些像素密度繪製在直方圖上來增加圖像的對比度。然後分析該直方圖的分布,並且如果存在當前未被使用的像素亮度範圍,則直方圖被「拉伸」以覆蓋這些範圍,然後被「 反投影 」到圖像上以增加總體形象的對比
自適應均衡
自適應均衡與常規直方圖均衡的不同之處在於計算幾個不同的直方圖,每個直方圖對應於圖像的不同部分; 然而,在其他無趣的部分有過度放大雜訊的傾向。
下面的代碼來自於sci-kit圖像庫的文檔,並且已經被修改為在我們的cifar10數據集的第一個圖像上執行上述三個增強。
首先,我們將從sci-kit圖像(skimage)庫中導入必要的模塊,然後修改sci-kit圖像文檔中的代碼以查看數據集第一幅圖像上的增強
# Import skimage modulesfrom skimage import data, img_as_floatfrom skimage import exposure# Lets try augmenting a cifar10 image using these techniquesfrom skimage import data, img_as_floatfrom skimage import exposure# Load an example image from cifar10 datasetimg = images[0]# Set font size for imagesmatplotlib.rcParams["font.size"] = 8# Contrast stretchingp2, p98 = np.percentile(img, (2, 98))img_rescale = exposure.rescale_intensity(img, in_range=(p2, p98))# Histogram Equalizationimg_eq = exposure.equalize_hist(img)# Adaptive Equalizationimg_adapteq = exposure.equalize_adapthist(img, clip_limit=0.03)#### Everything below here is just to create the plot/graphs ##### Display resultsfig = plt.figure(figsize=(8, 5)) axes = np.zeros((2, 4), dtype=np.object)axes[0, 0] = fig.add_subplot(2, 4, 1)for i in range(1, 4): axes[0, i] = fig.add_subplot(2, 4, 1+i, sharex=axes[0,0], sharey=axes[0,0])for i in range(0, 4): axes[1, i] = fig.add_subplot(2, 4, 5+i)ax_img, ax_hist, ax_cdf = plot_img_and_hist(img, axes[:, 0])ax_img.set_title("Low contrast image")y_min, y_max = ax_hist.get_ylim()ax_hist.set_ylabel("Number of pixels")ax_hist.set_yticks(np.linspace(0, y_max, 5))ax_img, ax_hist, ax_cdf = plot_img_and_hist(img_rescale, axes[:, 1])ax_img.set_title("Contrast stretching")ax_img, ax_hist, ax_cdf = plot_img_and_hist(img_eq, axes[:, 2])ax_img.set_title("Histogram equalization")ax_img, ax_hist, ax_cdf = plot_img_and_hist(img_adapteq, axes[:, 3])ax_img.set_title("Adaptive equalization")ax_cdf.set_ylabel("Fraction of total intensity")ax_cdf.set_yticks(np.linspace(0, 1, 5))# prevent overlap of y-axis labelsfig.tight_layout()plt.show()
4.修改keras.preprocessing以啟用直方圖均衡技術。
現在我們已經成功地從cifar10數據集中修改了一個圖像,我們將演示如何修改keras.preprocessing
image.py文件,以執行這些不同的直方圖修改技術,就像我們開箱即可使用的keras增強使用ImageDataGenerator()。
以下是我們將要執行此功能的一般步驟:
- 在你自己的機器上找到keras.preprocessing image.py文件。
- 將image.py文件複製到您的文件或筆記本中。
- 為每個均衡技術添加一個屬性到DataImageGenerator()init函數。
- 將IF語句子句添加到random_transform方法,以便在我們調用時實現增強datagen.fit()。
對keras.preprocessing
運行
image.py
文件進行修改的最簡單方法之一就是將其內容複製並粘貼到我們的代碼中。這將刪除需要導入它。為了確保您抓取的是之前導入的文件的相同版本,最好抓取image.py您計算機上已有的文件。print(keras.__file__)
將列印出機器上keras庫的路徑。路徑(對於mac用戶)可能如下所示:
/usr/local/lib/python3.5/dist-packages/keras/__init__.pyc
這給了我們在本地機器上keras的路徑。
繼續前進,在那裡導航,
然後進入preprocessing文件夾
。在裡面preprocessing你會看到image
.py文件。然後您可以將其內容複製到您的代碼中。該文件很長,但對於初學者來說,這可能是最簡單的方法之一。
編輯 image.py
在image.py的頂部,你可以注釋掉這行:from ..import backend as K如果你已經包含在上面。
此時,請仔細檢查以確保您正在導入必要的scikit-image模塊,以便複製的模塊image.py可以看到它們。
from skimage import data, img_as_floatfrom skimage import exposure
我們現在需要在ImageDataGenerator類的 __ init __
方法中添加六行代碼,以便它具有三個代表我們要添加的增強類型的屬性。下面的代碼是從我目前的image.py中複製的。與#####側面的線是我已經添加的線
def __init__(self, contrast_stretching=False, ##### histogram_equalization=False,##### adaptive_equalization=False, ##### featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, samplewise_std_normalization=False, zca_whitening=False, rotation_range=0., width_shift_range=0., height_shift_range=0., shear_range=0., zoom_range=0., channel_shift_range=0., fill_mode=』nearest』, cval=0., horizontal_flip=False, vertical_flip=False, rescale=None, preprocessing_function=None, data_format=None): if data_format is None: data_format = K.image_data_format() self.counter = 0 self.contrast_stretching = contrast_stretching, ##### self.adaptive_equalization = adaptive_equalization ##### self.histogram_equalization = histogram_equalization ##### self.featurewise_center = featurewise_center self.samplewise_center = samplewise_center self.featurewise_std_normalization = featurewise_std_normalization self.samplewise_std_normalization = samplewise_std_normalization self.zca_whitening = zca_whitening self.rotation_range = rotation_range self.width_shift_range = width_shift_range self.height_shift_range = height_shift_range self.shear_range = shear_range self.zoom_range = zoom_range self.channel_shift_range = channel_shift_range self.fill_mode = fill_mode self.cval = cval self.horizontal_flip = horizontal_flip self.vertical_flip = vertical_flip self.rescale = rescale self.preprocessing_function = preprocessing_function
該random_transform()(下)函數來響應我們一直傳遞到的參數ImageDataGenerator()功能。
如果我們已經設置了contrast_stretching,adaptive_equalization或者histogram_equalization參數True,當我們調用ImageDataGenerator()時(就像我們對其他圖像增強一樣)random_transform()將會應用所需的圖像增強。
def random_transform(self, x): img_row_axis = self.row_axis - 1 img_col_axis = self.col_axis - 1 img_channel_axis = self.channel_axis - 1# use composition of homographies# to generate final transform that needs to be applied if self.rotation_range: theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range) else: theta = 0 if self.height_shift_range: tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * x.shape[img_row_axis] else: tx = 0 if self.width_shift_range: ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * x.shape[img_col_axis] else: ty = 0 if self.shear_range: shear = np.random.uniform(-self.shear_range, self.shear_range) else: shear = 0 if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: zx, zy = 1, 1 else: zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)transform_matrix = None if theta != 0: rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) transform_matrix = rotation_matrix if tx != 0 or ty != 0: shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) transform_matrix = shift_matrix if transform_matrix is None else np.dot(transform_matrix, shift_matrix) if shear != 0: shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]) transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix) if zx != 1 or zy != 1: zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, zoom_matrix) if transform_matrix is not None: h, w = x.shape[img_row_axis], x.shape[img_col_axis] transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) x = apply_transform(x, transform_matrix, img_channel_axis, fill_mode=self.fill_mode, cval=self.cval) if self.channel_shift_range != 0: x = random_channel_shift(x, self.channel_shift_range, img_channel_axis) if self.horizontal_flip: if np.random.random() < 0.5: x = flip_axis(x, img_col_axis) if self.vertical_flip: if np.random.random() < 0.5: x = flip_axis(x, img_row_axis) if self.contrast_stretching: ##### if np.random.random() < 0.5: ##### p2, p98 = np.percentile(x, (2, 98)) ##### x = exposure.rescale_intensity(x, in_range=(p2, p98)) ##### if self.adaptive_equalization: ##### if np.random.random() < 0.5: ##### x = exposure.equalize_adapthist(x, clip_limit=0.03) ##### if self.histogram_equalization: ##### if np.random.random() < 0.5: ##### x = exposure.equalize_hist(x) ##### return x
現在我們擁有所有必要的代碼,並且可以調用ImageDataGenerator()來執行我們的直方圖修改技術。如果我們將所有三個值都設置為,則這是幾張圖片的樣子True
# Initialize Generatordatagen = ImageDataGenerator(contrast_stretching=True, adaptive_equalization=True, histogram_equalization=True)# fit parameters from datadatagen.fit(x_train)# Configure batch size and retrieve one batch of imagesfor x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=9):# Show the first 9 imagesfor i in range(0, 9): pyplot.subplot(330 + 1 + i) pyplot.imshow(x_batch[i].reshape(img_rows, img_cols, 3))# show the plotpyplot.show()break
- 培訓並驗證您的Keras CNN
最後一步是訓練CNN並驗證模型model.fit_generator(),以便在增強圖像上訓練和驗證我們的神經網路.
from keras.models import Sequentialfrom keras.layers import Dense, Dropout, Flattenfrom keras.layers import Conv2D, MaxPooling2Dbatch_size = 64num_classes = 2epochs = 10model = Sequential()model.add(Conv2D(4, kernel_size=(3, 3),activation="relu",input_shape=input_shape))model.add(Conv2D(8, (3, 3), activation="relu"))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.25))model.add(Flatten())model.add(Dense(16, activation="relu"))model.add(Dropout(0.5))model.add(Dense(2, activation="softmax"))model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adadelta(), metrics=["accuracy"])datagen.fit(x_train)history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size), steps_per_epoch=x_train.shape[0] // batch_size, epochs=20, validation_data=(x_test, y_test))
End 總結
我在這裡展現了一張圖片的增強結果,下圖是我最後的增強結果
左上、增強測試圖片 右上、增強結果
左下、原始數據標籤 右下、原始數據
大家自己嘗試一下哈,誰還不會喊個加油啊--->加油,加油,加油!
轉載和疑問聲明
如果你有什麼疑問或者想要轉載,沒有允許是不能轉載的哈
讚賞一下能不能轉?哈哈,聯繫我啊,我告訴你呢 ~~ 歡迎聯繫我哈,我會給大家慢慢解答啦~~~怎麼聯繫我? 笨啊~ ~~ 你留言也行 你關注微信公眾號1.聽朕給你說:2.tzgns666,3.或者掃那個二維碼,後台聯繫我也行啦!http://weixin.qq.com/r/MjpGQo3EvEXAKckhb2_2 (二維碼自動識別)
終於寫完了,大家可以關注機器學習演算法全棧工程師!同名微信公眾號。
碼字不易啊,如果你覺得本文有幫助,三毛也是愛哦!
推薦閱讀: