Python · 神經網路(六)· 拓展

(這裡是最終成品的 GitHub 地址)

(這裡是本章用到的 GitHub 地址)

這一章主要是將上一章提到的一些拓展實現出來。由於這些和神經網路的演算法本身沒有太大關係而更多只是編程能力的考驗,所以我不會說得太詳細而只會說一個大概思路

  • 分 batch 訓練

    • 把訓練數據集分成 batch 來訓練的原因有許多,我在這裡只說其中兩個比較直觀的原因:

      • 同一個 epoch 中的更新次數降低,收斂變慢
      • 內存扛不住

      極端的栗子就是把所有數據集扔進網路裡面訓練(這也是我們目前實現的方法)。這樣的話一個 epoch 裡面就只能更新一次、而且數據集大了之後內存就會爆炸(……)

    • 實現的方法很簡單,這裡就只放代碼了。分兩步走:
      • 訓練前判斷是否有必要分 batch (如果數據集比 batch_size 還小就沒必要),同時算出一個 epoch 裡面要訓練多少個 batch

        train_len = len(x_train)batch_size = min(batch_size, train_len)do_random_batch = train_len >= batch_sizetrain_repeat = int(train_len / batch_size) + 1

      • 訓練時依 batch 訓練

        for _i in range(train_repeat): if do_random_batch: batch = np.random.choice(train_len, batch_size) x_batch, y_batch = x_train[batch], y_train[batch] else: x_batch, y_batch = x_train, y_train self._train_step.run( feed_dict={self._tfx: x_batch, self._tfy: y_batch})

  • 預測時分批預測

    • 還是因為內存問題。實現的方法有兩種:一種是比較常見的按個數分組,一種是我採取的按數據大小分組。換句話說:

      • 常見的做法是每批預測 k 個數據
      • 我的做法是每批預測 m 個數據,這 m 個數據一共包含一百萬左右個數字

      常見做法有一個顯而易見的缺點:如果單個數據很龐大的話、這樣做可能還是會爆內存

    • 實現稍微有些繁複、主要是要考慮 tensorflow 的 session 的生存周期(不知道能不能這樣說……)、但其實思路是平凡的。下面姑且貼一下核心代碼,但建議跳過(喂)

      def _get_prediction(self, x, batch_size=1e6, out_of_sess=False): single_batch = int(batch_size / np.prod(x.shape[1:])) if not single_batch: single_batch = 1 if single_batch >= len(x): if not out_of_sess: return self._y_pred.eval(feed_dict={self._tfx: x}) with self._sess.as_default(): x = x.astype(np.float32) return self.get_rs(x).eval(feed_dict={self._tfx: x}) if not out_of_sess: rs = [self._y_pred.eval( feed_dict={self._tfx: x[:single_batch]})] else: rs = [self.get_rs(x[:single_batch])] count = single_batch while count < len(x): count += single_batch if count >= len(x): if not out_of_sess: rs.append(self._y_pred.eval(feed_dict={ self._tfx: x[count - single_batch:]})) else: rs.append(self.get_rs(x[count - single_batch:])) else: if not out_of_sess: rs.append(self._y_pred.eval(feed_dict={ self._tfx: x[count - single_batch:count]})) else: rs.append(self.get_rs(x[count - single_batch:count])) if out_of_sess: with self._sess.as_default(): rs = [_rs.eval() for _rs in rs] return np.vstack(rs)

  • 交叉驗證

    • 交叉驗證大概有三種:K 折交叉驗證、留一驗證、和 Holdout 驗證。我們要採用的就是最後這一種、因為它最簡單(喂
    • 實現起來也很直觀:只需要在一開始把數據隨機打亂然後按比例分割即可

      if train_rate is not None: train_rate = float(train_rate) train_len = int(len(x) * train_rate) shuffle_suffix = np.random.permutation(int(len(x))) x, y = x[shuffle_suffix], y[shuffle_suffix] x_train, y_train = x[:train_len], y[:train_len] x_test, y_test = x[train_len:], y[train_len:]else: x_train = x_test = x y_train = y_test = y

      用到了一點 Numpy 的技巧、不過(應該)不是唯一的做法

  • 實時記錄結果

    • 這一塊的東西挺多的、可能只能大概看看思路:
      • 定義一個屬性 self._logs 以存儲我們的記錄
      • 這個屬性是一個字典,key 為 train 的 value 對應著訓練數據集的記錄,key 為 test 的 value 則對應測試數據集的
      • 記錄的東西通常是對模型的某種評估,常用的評估有三種:損失(loss)、準確率(acc)和 F1-score。最後這個評估是針對二類問題的,具體數學含義可以看這裡
      • 定義三個函數,一個拿來記錄這些評估,一個拿來 print 出最新的評估,一個拿來可視化評估

    • 實現的話不難但繁,需要定義好幾個東西。這裡就不貼代碼了否則總有灌水的感覺(……),感興趣的觀眾老爺們可以直接看源代碼 ( σ"ω")σ

以上大概把上一章留下來的坑填了一遍。至此一個還算能用的神經網路模型就已經做好了,它誠然還有相當大可優化的空間、不過作為基礎來說業已足夠

從下一章開始就是比較重頭戲的 CNN 了;本來從 NN 到 CNN 的拓展其實並不太平凡,可是 tensorflow 這個外掛愣是把這件事情變得相當容易……想起當初用純 Numpy 好不容易才把 CNN 整合進這個框架結果看到 tensorflow 之後直接傻眼……唉不說了挺心疼自己的(嘆

最後貼一個比較普適的、不限於 NN 的二維數據可視化函數作為結束吧:

def visualize_2d(self, x, y, plot_scale=2, plot_precision=0.01): plot_num = int(1 / plot_precision) xf = np.linspace( np.min(x) * plot_scale, np.max(x) * plot_scale, plot_num) yf = np.linspace( np.min(x) * plot_scale, np.max(x) * plot_scale, plot_num) input_x, input_y = np.meshgrid(xf, yf) input_xs = np.c_[input_x.ravel(), input_y.ravel()] output_ys_2d = np.argmax(self.predict(input_xs), axis=1).reshape( len(xf), len(yf)) plt.contourf(input_x, input_y, output_ys_2d, cmap=plt.cm.Spectral) plt.scatter(x[:, 0], x[:, 1], c=np.argmax(y, axis=1), s=40, cmap=plt.cm.Spectral) plt.axis("off") plt.show()

把 predict 函數換成任一個有預測功能的函數都行。在我們這個 NN 里,它的效果大概如下:

希望觀眾老爺們能夠喜歡~

(猛戳我進入下一章!( σ"ω")σ )


推薦閱讀:

One-Page AlphaGo -- 10分鐘看懂AlphaGo的核心演算法
跟我學做菜吧!sklearn快速上手!用樸素貝葉斯/SVM分析新聞主題

TAG:Python | 机器学习 | 神经网络 |