2017年8月25日 星期五

深度學習(6)--使用Tensorflow實現類AlexNet model 訓練Cifar10數據集



    繼上一篇我們使用Alexnet 模型來訓練MNIST 數據集,這次我們改用Cifar10數據集來做訓練及預測。Cifar10的數據集可以從以下網址下載:
https://www.cs.toronto.edu/~kriz/cifar.html

它總共包含60000張32x32 RGB彩色的圖案,前五個檔案各有10000張圖檔為訓練數據集,最後10000張為測試用數據集。該數據集分成10個類別,以數字編號0-9分別是

airplane        
automobile   
bird              
cat                
deer              
dog               
frog              
horse            
ship              
truck            


      首先,照慣例我會先練習如何讀進Cifar10數據集從原始的二位元檔案至Numpy的矩陣格式,以便後續處理,這樣也可以用matplotlib套件秀出原圖,這樣做的意義是,就像之前所做的手寫數字辨識練習,可以加入自行創建或收集的圖片再進行預測或訓練,就不會只限於原始的數據集。也就是說如果懂得如何將圖片轉成Numpy的矩陣格式,那麼我就可以使用這個圖片做訓練或預測。例如常見的OpenCV套件不僅可以做到將圖片轉成Numpy矩陣格式,並且可以做很多的效果預處理,如裁減、灰化、調整對比或大小等等。

     在Cifar10的數據集裡每一Row 包含的第一個byte為類別標籤,接下來的1024,1024,1024共3072(32x32x3)byte分別代表該圖R,G,B的數據。

     底下一範例程式,便是可以將原始的Cifar10二進位檔讀出,轉成Numpy矩陣資料格式後,
再用Matplotlib套件,便可還原出原始圖檔:
https://github.com/Ashing00/Cifar10/blob/master/read_cifar.py

效果如下圖所示,上面的數字為其對應的類別標籤。





     有關這次使用的Alexnet 模型跟上一篇MNIST數據即使用的模型差異不大,主要差別在於
這次直接使用32x32x3的大小。
有關Alexnet 模型訓練MNIST數據集,可先參考底下連結:
http://arbu00.blogspot.tw/2017/07/5-tensorflowalexnet.html

所以須將圖檔大小,reshape成32x32x3。

x = tf.reshape(x_, shape=[-1, 32, 32, 3])

reshaped_xs = np.reshape(batch_xs, (
BATCH_SIZE,
32,
32,
3))

在此我也不先做剪裁動作,有些範例會剪裁成28x28x3大小,在這裡為了方便起見,就直接使用原本的尺寸大小。並且也無預作調整對比或是翻轉的預處裡,有興趣的人可以自行練習看看,訓練結果是否會更好。


整個範例晚整的程式碼放在底下連結:
https://github.com/Ashing00/Cifar10

cifar_train.py 為主要訓練的代碼,並儲存模型變數
cifar_inference.py 為建立的Alexnet model
cifar_eval.py 是用來讀取儲存模型的變數,並進行測試數據集的評估

下圖為訓練的過程及結果



下圖是評估測試數據的結果:


從結果來看,訓練數據可達90%精確率,但是預測測試數據集只有70%精確率,有相當的overfitting的現象,表示該模型仍有調整空間又或許在數據的預處裡方面可以多做一些處理。



加入阿布拉機的3D列印與機器人的FB專頁
https://www.facebook.com/arbu00/


<參考資料>
[1]書名:Tensorflow 實戰 作者:黃文堅 唐源
[2]書名:TensorFlow 技術解析與實戰 作者:李嘉璇
[3]Alexnet(ImageNet Classification with Deep Convolutional Neural Networks)


<其他相關文章>
人工神經網路(1)--使用Python實作perceptron(感知器)
人工神經網路(2)--使用Python實作後向傳遞神經網路演算法(Backprogation artificial neature network)
深度學習(1)-如何在windows安裝Theano +Keras +Tensorflow並使用GPU加速訓練神經網路
深度學習(2)--使用Tensorflow實作卷積神經網路(Convolutional neural network,CNN)
深度學習(3)--循環神經網絡(RNN, Recurrent Neural Networks)
深度學習(4)--使用Tensorflow實現類Lenet5手寫數字辨識
深度學習(5)--使用Tensorflow實現類AlexNet手寫數字辨識
機器學習(1)--使用OPENCV KNN實作手寫辨識
機器學習(2)--使用OPENCV SVM實作手寫辨識
演算法(1)--蒙地卡羅法求圓周率及橢圓面積(Monte carlo)
機器學習(3)--適應線性神經元與梯度下降法(Adaline neuron and Gradient descent)
機器學習(4)--資料標準常態化與隨機梯度下降法( standardization & Stochastic Gradient descent)
機器學習(5)--邏輯斯迴歸,過度適合與正規化( Logistic regression,overfitting and regularization)
機器學習(6)--主成分分析(Principal component analysis,PCA)
機器學習(7)--利用核主成分分析(Kernel PCA)處理非線性對應
機器學習(8)--實作多層感知器(Multilayer Perceptron,MLP)手寫數字辨識






沒有留言:

張貼留言