使用Keras快速搭建深度學習模型測試最新fashion-mnist數據集
keras交流群:523412399
github地址:QuantumLiu/fashion-mnist-demo-by-Keras
fashion-mnist簡介
近日,一個名為fashion-mnist的圖像分類數據集火遍機器學習圈。截止9月1日,在短短7天內就獲得了2k+的star。
zalandoresearch/fashion-mnist
FashionMNIST是一個替代MNIST手寫數字集的圖像數據集。 它是由Zalando(一家德國的時尚科技公司)旗下的研究部門提供。其涵蓋了來自10種類別的共7萬個不同商品的正面圖片。FashionMNIST的大小、格式和訓練集/測試集劃分與原始的MNIST完全一致。60000/10000的訓練測試數據劃分,28x28的灰度圖片。你可以直接用它來測試你的機器學習和深度學習演算法性能,且不需要改動任何的代碼。
眾所周知,mnist是深度學習/圖像分類領域的入門demo和最受歡迎的標準數據集,經常用來測試網路有效性。
然而mnist數據集也有自己的缺點,fashion-mnist正是為了克服這些缺點而生。
寫給專業的機器學習研究者
我們是認真的。取代MNIST數據集的原因由如下幾個:MNIST太簡單了。 很多深度學習演算法在測試集上的準確率已經達到99.6%!不妨看看我們基於scikit-learn上對經典機器學習演算法的評測 和這段代碼: "Most pairs of MNIST digits can be distinguished pretty well by just one pixel"(翻譯:大多數MNIST只需要一個像素就可以區分開!) MNIST被用爛了。 參考:"Ian Goodfellow wants people to move away from mnist"(翻譯:Ian Goodfellow希望人們不要再用MNIST了。) MNIST數字識別的任務不代表現代機器學習。 參考:"Fran?ois Cholle: Ideas on MNIST do not transfer to real CV" (翻譯:在MNIST上看似有效的想法沒法遷移到真正的機器視覺問題上。)
使用嘗鮮
搭建模型
看到了新數據集當然忍不住要嘗鮮了~~
為了做先吃螃蟹的人,我們需要快速搭建一個圖片分類模型模型進行訓練。
我的首選當然是真愛keras
Keras是一個高層神經網路API,Keras由純Python編寫而成並基Tensorflow、Theano以及CNTK後端。Keras 為支持快速實驗而生,能夠把你的idea迅速轉換為結果,如果你有如下需求,請選擇Keras:
簡易和快速的原型設計(keras具有高度模塊化,極簡,和可擴充特性)支持CNN和RNN,或二者的結合無縫CPU和GPU切換
使用keras搭建一個類似於vgg16網路的CNN模型非常的簡潔優雅:
def vgg_fm(input_shape): input_tensor=Input(shape=input_shape) x = Conv2D(64, (3, 3), activation=relu, padding=same, name=block1_conv1)(input_tensor) x = Conv2D(64, (3, 3), activation=relu, padding=same, name=block1_conv2)(x) x = MaxPooling2D((2, 2), strides=(2, 2), name=block1_pool)(x) # Block 2 x = Conv2D(128, (3, 3), activation=relu, padding=same, name=block2_conv1)(x) x = Conv2D(128, (3, 3), activation=relu, padding=same, name=block2_conv2)(x) x = MaxPooling2D((2, 2), strides=(2, 2), name=block2_pool)(x) # Block 3 x = Conv2D(256, (3, 3), activation=relu, padding=same, name=block3_conv1)(x) x = Conv2D(256, (3, 3), activation=relu, padding=same, name=block3_conv2)(x) x = Conv2D(256, (3, 3), activation=relu, padding=same, name=block3_conv3)(x) x = MaxPooling2D((2, 2), strides=(2, 2), name=block3_pool)(x) # Block 4 x = Conv2D(512, (3, 3), activation=relu, padding=same, name=block4_conv1)(x) x = Conv2D(512, (3, 3), activation=relu, padding=same, name=block4_conv2)(x) x = Conv2D(512, (3, 3), activation=relu, padding=same, name=block4_conv3)(x) x = MaxPooling2D((2, 2), strides=(2, 2), name=block4_pool)(x) # Classification block x = Flatten(name=flatten)(x) x = Dense(4096, activation=relu, name=fc1)(x) x = Dense(4096, activation=relu, name=fc2)(x) x = Dropout(0.5)(x) x = Dense(10, activation=softmax, name=predictions)(x) return Model(inputs=[input_tensor],outputs=[x]
keras相當於是Tensorflow/theano這樣的張量計算圖框架的高層封裝,這個函數返回一個Model對象,包含一個完整的網路計算圖,同時擁有一些用於快速訓練的方法。
Model對象需要調用compile方法來完損失成函數、優化器等的指定,之後使用fit方法訓練。
model.compile(optimizer=adam,loss=sparse_categorical_crossentropy,metrics=[acc])model.fit(x=train_x,y=train_y,validation_data=(test_x,test_y),batch_size=batch_size,epochs=100)
keras在訓練時,傳入的data和label是numpy數組的實例,不需要轉換成其他數據類型,十分方便。而且做分類任務時,label支持傳入整數類型的類別ID,不一定是one-hot向量,可以節約內存。
根據fashion-mnist的官方文檔,我們可以很方便的將數據集讀取為numpy數組。
import mnist_readerX_train, y_train = mnist_reader.load_mnist(data/fashion, kind=train)X_test, y_test = mnist_reader.load_mnist(data/fashion, kind=t10k)
完整代碼
import sysimport numpy as npfrom vgg_fm import vgg_fmfrom generators import read_data,reshapefrom manager import GPUManagerfrom keras.callbacks import ModelCheckpointfrom callback import TargetStoppingif __name__==__main__: gm=GPUManager() kwargs=dict(zip([mode,version,batch_size],sys.argv[1:])) mode,version,batch_size=list(map(lambda kd:kwargs.get(kd[0],kd[1]),zip([mode,version,batch_size],[vgg,v1,256]))) batch_size=int(batch_size) model_name=mode+_+version with gm.auto_choice(): (train_x,train_y),(test_x,test_y)=read_data(train),read_data(test) train_x,test_x,train_y,test_y=reshape(train_x,False),reshape(test_x,False),np.expand_dims(train_y,-1),np.expand_dims(test_y,-1) input_shape=train_x.shape[1:] model=vgg_fm(input_shape) model.summary() model.compile(optimizer=adam,loss=sparse_categorical_crossentropy,metrics=[acc]) model.fit(x=train_x,y=train_y,validation_data=(test_x,test_y),batch_size=batch_size,epochs=100, callbacks=[TargetStopping(filepath=model_name+.h5,monitor=val_acc,mode=max,target=0.94), ModelCheckpoint(filepath=model_name+.h5,save_best_only=True,monitor=val_acc)])
實驗結果
最終測試集準確率大概是在93.5%左右,而原始mnist使用同樣的模型打到97以上應該很輕鬆,使用說fashion-mnist著實有一定的挑戰性~
推薦閱讀:
※跟香蕉君一起舞♂蹈
※《Deep Layer Aggregation》論文筆記
※圖像的信噪比是個什麼概念?怎麼算的?
※對圖像進行顏色識別時,如何解決攝像頭偏色的問題?
※圖普科技是一家怎樣的公司?
TAG:深度学习DeepLearning | Python | 图像识别 |