【Keras實踐】CNN 訓練自己的數據集

本文需要使用的Keras模塊有:

  • fit_generator:用於從Python生成器中訓練網路
  • ImageDataGenerator:用於實時數據提升

1、文件配置

數據集按照下面的形式存放(圖片名可不遵循以下規則)

data/
train/
dogs/
dog001.jpg
dog002.jpg
...
cats/
cat001/jpg
cat002.jpg
...
validation/
dogs/
dog001.jpg
dog002.jpg
...
cats/
cat001/jpg
cat002.jpg
...

2、數據預處理和數據提升

為了盡量利用我們有限的訓練數據,我們將通過一系列隨機變換堆數據進行提升,這樣我們的模型將看不到任何兩張完全相同的圖片,這有利於我們抑制過擬合,使得模型的泛化能力更好。

在Keras中,這個步驟可以通過keras.preprocessing.image.ImageGenerator來實現,這個類使你可以:

  • 在訓練過程中,設置要施行的隨機變換
  • 通過.flow.flow_from_directory(directory)方法實例化一個針對圖像batch的生成器,這些生成器可以被用作keras模型相關方法的輸入,如fit_generatorevaluate_generatorpredict_generator

現在讓我們看個例子:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode=nearest)

上面顯示的只是一部分選項,請閱讀文檔的相關部分來查看全部可用的選項。我們來快速的瀏覽一下這些選項的含義:

  • rotation_range是一個0~180的度數,用來指定隨機選擇圖片的角度。
  • width_shiftheight_shift用來指定水平和豎直方向隨機移動的程度,這是兩個0~1之間的比例。
  • rescale值將在執行其他處理前乘到整個圖像上,我們的圖像在RGB通道都是0~255的整數,這樣的操作可能使圖像的值過高或過低,所以我們將這個值定為0~1之間的數。
  • shear_range是用來進行剪切變換的程度。
  • zoom_range用來進行隨機的放大。
  • horizontal_flip隨機的對圖片進行水平翻轉,這個參數適用於水平翻轉不影響圖片語義的時候。
  • fill_mode用來指定當需要進行像素填充,如旋轉,水平和豎直位移時,如何填充新出現的像素。

3、構建model

from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense

model = Sequential()
model.add(Convolution2D(32, 3, 3, input_shape=(3, 150, 150)))
model.add(Activation(relu))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Convolution2D(32, 3, 3))
model.add(Activation(relu))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Convolution2D(64, 3, 3))
model.add(Activation(relu))
model.add(MaxPooling2D(pool_size=(2, 2)))

# the model so far outputs 3D feature maps (height, width, features)

model.add(Flatten()) # this converts our 3D feature maps to 1D feature vectors
model.add(Dense(64))
model.add(Activation(relu))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation(sigmoid))

model.compile(loss=binary_crossentropy,
optimizer=adam,
metrics=[accuracy])

然後我們開始準備數據,使用.flow_from_directory()來從我們的jpgs圖片中直接產生數據和標籤(該類會根據文件夾名稱自動one-hot編碼,真的是太方便了)。

# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1./255)

# this is a generator that will read pictures found in
# subfolers of data/train, and indefinitely generate
# batches of augmented image data
train_generator = train_datagen.flow_from_directory(
data/train, # this is the target directory
target_size=(150, 150), # all images will be resized to 150x150
batch_size=32,
class_mode=binary) # since we use binary_crossentropy loss, we need binary labels

# this is a similar generator, for validation data
validation_generator = test_datagen.flow_from_directory(
data/validation,
target_size=(150, 150),
batch_size=32,
class_mode=binary)

這個生成器來訓練網路

model.fit_generator(
train_generator,
samples_per_epoch=2000,
nb_epoch=50,
validation_data=validation_generator,
nb_val_samples=800)
model.save_weights(first_try.h5) # always save your weights after training or during training

推薦閱讀:

TAG:Keras | 卷積神經網路(CNN) | 深度學習(DeepLearning) |