學習筆記TF061:分散式TensorFlow,分散式原理、最佳實踐

分散式TensorFlow由高性能gRPC庫底層技術支持。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。

分散式原理。分散式集群 由多個伺服器進程、客戶端進程組成。部署方式,單機多卡、分散式(多機多卡)。多機多卡TensorFlow分散式。

單機多卡,單台伺服器多塊GPU。訓練過程:在單機單GPU訓練,數據一個批次(batch)一個批次訓練。單機多GPU,一次處理多個批次數據,每個GPU處理一個批次數據計算。變數參數保存在CPU,數據由CPU分發給多個GPU,GPU計算每個批次更新梯度。CPU收集完多個GPU更新梯度,計算平均梯度,更新參數。繼續計算更新梯度。處理速度取決最慢GPU速度。

分散式,訓練在多個工作節點(worker)。工作節點,實現計算單元。計算伺服器單卡,指伺服器。計算伺服器多卡,多個GPU劃分多個工作節點。數據量大,超過一台機器處理能力,須用分散式。

分散式TensorFlow底層通信,gRPC(google remote procedure call)。gRPC,谷歌開源高性能、跨語言RPC框架。RPC協議,遠程過程調用協議,網路從遠程計算機程度請求服務。

分散式部署方式。分散式運行,多個計算單元(工作節點),後端伺服器部署單工作節點、多工作節點。

單工作節點部署。每台伺服器運行一個工作節點,伺服器多個GPU,一個工作節點可以訪問多塊GPU卡。代碼tf.device()指定運行操作設備。優勢,單機多GPU間通信,效率高。劣勢,手動代碼指定設備。

多工作節點部署。一台伺服器運行多個工作節點。

設置CUDA_VISIBLE_DEVICES環境變數,限制各個工作節點只可見一個GPU,啟動進程添加環境變數。用tf.device()指定特定GPU。多工作節點部署優勢,代碼簡單,提高GPU使用率。劣勢,工作節點通信,需部署多個工作節點。tobegit3hub/tensorflow_examples 。

CUDA_VISIBLE_DEVICES= python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0

CUDA_VISIBLE_DEVICES= python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1

CUDA_VISIBLE_DEVICES=0 python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0

CUDA_VISIBLE_DEVICES=1 python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1

分散式架構。tensorflow.org/extend/a 。客戶端(client)、服務端(server),服務端包括主節點(master)、工作節點(worker)組成。

客戶端、主節點、工作節點關係。TensorFlow,客戶端會話聯繫主節點,實際工作由工作節點實現,每個工作節點佔一台設備(TensorFlow具體計算硬體抽象,CPU或GPU)。單機模式,客戶端、主節點、工作節點在同一台伺服器。分布模式,可不同伺服器。客戶端->主節點->工作節點/job:worker/task:0->/job:ps/task:0。

客戶端。建立TensorFlow計算圖,建立與集群交互會話層。代碼包含Session()。一個客戶端可同時與多個服務端相連,一具服務端也可與多個客戶端相連。

服務端。運行tf.train.Server實例進程,TensroFlow執行任務集群(cluster)一部分。有主節點服務(Master service)和工作節點服務(Worker service)。運行中,一個主節點進程和數個工作節點進程,主節點進程和工作接點進程通過介面通信。單機多卡和分散式結構相同,只需要更改通信介面實現切換。

主節點服務。實現tensorflow::Session介面。通過RPC服務程序連接工作節點,與工作節點服務進程工作任務通信。TensorFlow服務端,task_index為0作業(job)。

工作節點服務。實現worker_service.proto介面,本地設備計算部分圖。TensorFlow服務端,所有工作節點包含工作節點服務邏輯。每個工作節點負責管理一個或多個設備。工作節點可以是本地不同埠不同進程,或多台服務多個進程。運行TensorFlow分散式執行任務集,一個或多個作業(job)。每個作業,一個或多個相同目的任務(task)。每個任務,一個工作進程執行。作業是任務集合,集群是作業集合。

分散式機器學習框架,作業分參數作業(parameter job)和工作節點作業(worker job)。參數作業運行伺服器為參數伺服器(parameter server,PS),管理參數存儲、更新。工作節點作業,管理無狀態主要從事計算任務。模型越大,參數越多,模型參數更新超過一台機器性能,需要把參數分開到不同機器存儲更新。參數服務,多台機器組成集群,類似分散式存儲架構,涉及數據同步、一致性,參數存儲為鍵值對(key-value)。分散式鍵值內存資料庫,加參數更新操作。李沐《Parameter Server for Distributed Machine Learning》cs.cmu.edu/~muli/file/p

參數存儲更新在參數作業進行,模型計算在工作節點作業進行。TensorFlow分散式實現作業間數據傳輸,參數作業到工作節點作業前向傳播,工作節點作業到參數作業反向傳播。

任務。特定TensorFlow伺服器獨立進程,在作業中擁有對應序號。一個任務對應一個工作節點。集群->作業->任務->工作節點。

客戶端、主節點、工作節點交互過程。單機多卡交互,客戶端->會話運行->主節點->執行子圖->工作節點->GPU0?GPU1。分散式交互,客戶端->會話運行->主節點進程->執行子圖1->工作節點進程1->GPU0?GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》Large-Scale Machine Learning on Heterogeneous Distributed Systems 。

分散式模式。

數據並行。tensorflow.org/tutorial 。CPU負責梯度平均、參數更新,不同GPU訓練模型副本(model replica)。基於訓練樣例子集訓練,模型有獨立性。

步驟:不同GPU分別定義模型網路結構。單個GPU從數據管道讀取不同數據塊,前向傳播,計算損失,計算當前變數梯度。所有GPU輸出梯度數據轉移到CPU,梯度求平均操作,模型變數更新。重複,直到模型變數收斂。

數據並行,提高SGD效率。SGD mini-batch樣本,切成多份,模型複製多份,在多個模型上同時計算。多個模型計算速度不一致,CPU更新變數有同步、非同步兩個方案。

同步更新、非同步更新。分散式隨機梯度下降法,模型參數分散式存儲在不同參數服務上,工作節點並行訓練數據,和參數伺服器通信獲取模型參數。

同步隨機梯度下降法(Sync-SGD,同步更新、同步訓練),訓練時,每個節點上工作任務讀入共享參數,執行並行梯度計算,同步需要等待所有工作節點把局部梯度處好,將所有共享參數合併、累加,再一次性更新到模型參數,下一批次,所有工作節點用模型更新後參數訓練。優勢,每個訓練批次考慮所有工作節點訓練情部,損失下降穩定。劣勢,性能瓶頸在最慢工作節點。異楹設備,工作節點性能不同,劣勢明顯。

非同步隨機梯度下降法(Async-SGD,非同步更新、非同步訓練),每個工作節點任務獨立計算局部梯度,非同步更新到模型參數,不需執行協調、等待操作。優勢,性能不存在瓶頸。劣勢,每個工作節點計算梯度值發磅回參數伺服器有參數更新衝突,影響演算法收劍速度,損失下降過程抖動較大。

同步更新、非同步更新實現區別於更新參數伺服器參數策略。數據量小,各節點計算能力較均衡,用同步模型。數據量大,各機器計算性能參差不齊,用非同步模式。

帶備份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz論文《Revisiting Distributed Synchronous SGD》[1604.00981] Revisiting Distributed Synchronous SGD 。增加工作節點,解決部分工作節點計算慢問題。工作節點總數n+n*5%,n為集群工作節點數。非同步更新設定接受到n個工作節點參數直接更新參數伺服器模型參數,進入下一批次模型訓練。計算較慢節點訓練參數直接丟棄。

同步更新、非同步更新有圖內模式(in-graph pattern)和圖間模式(between-graph pattern),獨立於圖內(in-graph)、圖間(between-graph)概念。

圖內複製(in-grasph replication),所有操作(operation)在同一個圖中,用一個客戶端來生成圖,把所有操作分配到集群所有參數伺服器和工作節點上。國內複製和單機多卡類似,擴展到多機多卡,數據分發還是在客戶端一個節點上。優勢,計算節點只需要調用join()函數等待任務,客戶端隨時提交數據就可以訓練。劣勢,訓練數據分發在一個節點上,要分發給不同工作節點,嚴重影響並發訓練速度。

圖間複製(between-graph replication),每一個工作節點創建一個圖,訓練參數保存在參數伺服器,數據不分發,各個工作節點獨立計算,計算完成把要更新參數告訴參數伺服器,參數伺服器更新參數。優勢,不需要數據分發,各個工作節點都創建圖和讀取數據訓練。劣勢,工作節點既是圖創建者又是計算任務執行者,某個工作節點宕機影響集群工作。大數據相關深度學習推薦使用圖間模式。

模型並行。切分模型,模型不同部分執行在不同設備上,一個批次樣本可以在不同設備同時執行。TensorFlow盡量讓相鄰計算在同一台設備上完成節省網路開銷。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》Large-Scale Machine Learning on Heterogeneous Distributed Systems 。

模型並行、數據並行,TensorFlow中,計算可以分離,參數可以分離。可以在每個設備上分配計算節點,讓對應參數也在該設備上,計算參數放一起。

分散式API。tensorflow.org/deploy/d

創建集群,每個任務(task)啟動一個服務(工作節點服務或主節點服務)。任務可以分布不同機器,可以同一台機器啟動多個任務,用不同GPU運行。每個任務完成工作:創建一個tf.train.ClusterSpec,對集群所有任務進行描述,描述內容對所有任務相同。創建一個tf.train.Server,創建一個服務,運行相應作業計算任務。

TensorFlow分散式開發API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。創建TensorFlow集群描述信息,ps、worker為作業名稱,ps_phsts、worker_hosts為作業任務所在節點地址信息。tf.train.ClusterSpec傳入參數,作業和任務間關係映射,映射關係任務通過IP地址、埠號表示。

結構 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})

可用任務 /job:local/task:0?/job:local/task:1。

結構 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})

可用任務 /job:worker/task:0? /job:worker/task:1? /job:worker/task:2? /job:ps/task:0? /job:ps/task:1

tf.train.Server(cluster,job_name,task_index)。創建服務(主節點服務或工作節點服務),運行作業計算任務,運行任務在task_index指定機器啟動。

#任務0

cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})

server = tr.train.Server(cluster,job_name="local",task_index=0)

#任務1

cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})

server = tr.train.Server(cluster,job_name="local",task_index=1)。

自動化管理節點、監控節點工具。集群管理工具Kubernetes。

tf.device(device_name_or_function)。設定指定設備執行張量運算,批定代碼運行CPU、GPU。

#指定在task0所在機器執行Tensor操作運算

with tf.device("/job:ps/task:0"):

weights_1 = tf.Variable(…)

biases_1 = tf.Variable(…)

分散式訓練代碼框架。創建TensorFlow伺服器集群,在該集群分散式計算數據流圖。tensorflow/tensorflow 。

import argparse

import sys

import tensorflow as tf

FLAGS = None

def main(_):

# 第1步:命令行參數解析,獲取集群信息ps_hosts、worker_hosts

# 當前節點角色信息job_name、task_index

ps_hosts = FLAGS.ps_hosts.split(",")

worker_hosts = FLAGS.worker_hosts.split(",")

# 第2步:創建當前任務節點伺服器

# Create a cluster from the parameter server and worker hosts.

cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

# Create and start a server for the local task.

server = tf.train.Server(cluster,

job_name=FLAGS.job_name,

task_index=FLAGS.task_index)

# 第3步:如果當前節點是參數伺服器,調用server.join()無休止等待;如果是工作節點,執行第4步

if FLAGS.job_name == "ps":

server.join()

# 第4步:構建要訓練模型,構建計算圖

elif FLAGS.job_name == "worker":

# Assigns ops to the local worker by default.

with tf.device(tf.train.replica_device_setter(

worker_device="/job:worker/task:%d" % FLAGS.task_index,

cluster=cluster)):

# Build model...

loss = ...

global_step = tf.contrib.framework.get_or_create_global_step()

train_op = tf.train.AdagradOptimizer(0.01).minimize(

loss, global_step=global_step)

# The StopAtStepHook handles stopping after running given steps.

# 第5步管理模型訓練過程

hooks=[tf.train.StopAtStepHook(last_step=1000000)]

# The MonitoredTrainingSession takes care of session initialization,

# restoring from a checkpoint, saving to a checkpoint, and closing when done

# or an error occurs.

with tf.train.MonitoredTrainingSession(master=server.target,

is_chief=(FLAGS.task_index == 0),

checkpoint_dir="/tmp/train_logs",

hooks=hooks) as mon_sess:

while not mon_sess.should_stop():

# Run a training step asynchronously.

# See `tf.train.SyncReplicasOptimizer` for additional details on how to

# perform *synchronous* training.

# mon_sess.run handles AbortedError in case of preempted PS.

# 訓練模型

mon_sess.run(train_op)

if __name__ == "__main__":

parser = argparse.ArgumentParser()

parser.register("type", "bool", lambda v: v.lower() == "true")

# Flags for defining the tf.train.ClusterSpec

parser.add_argument(

"--ps_hosts",

type=str,

default="",

help="Comma-separated list of hostname:port pairs"

)

parser.add_argument(

"--worker_hosts",

type=str,

default="",

help="Comma-separated list of hostname:port pairs"

)

parser.add_argument(

"--job_name",

type=str,

default="",

help="One of ps, worker"

)

# Flags for defining the tf.train.Server

parser.add_argument(

"--task_index",

type=int,

default=0,

help="Index of task within the job"

)

FLAGS, unparsed = parser.parse_known_args()

app.run-正在西部數碼(www.west.cn)進行交易(main=main, argv=[sys.argv[0]] + unparsed)

分散式最佳實踐。github.com/tensorflow/t

MNIST數據集分散式訓練。開設3個埠作分散式工作節點部署,2222埠參數伺服器,2223埠工作節點0,2224埠工作節點1。參數伺服器執行參數更新任務,工作節點0?工作節點1執行圖模型訓練計算任務。參數伺服器/job:ps/task:0 cocalhost:2222,工作節點/job:worker/task:0 cocalhost:2223,工作節點/job:worker/task:1 cocalhost:2224。

運行代碼。

python mnist_replica.py --job_name="ps" --task_index=0

python mnist_replica.py --job_name="worker" --task_index=0

python mnist_replica.py --job_name="worker" --task_index=1

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import math

import sys

import tempfile

import time

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 定義常量,用於創建數據流圖

flags = tf.app.flags

flags.DEFINE_string("data_dir", "/tmp/mnist-data",

"Directory for storing mnist data")

# 只下載數據,不做其他操作

flags.DEFINE_boolean("download_only", False,

"Only perform downloading of data; Do not proceed to "

"session preparation, model definition or training")

# task_index從0開始。0代表用來初始化變數的第一個任務

flags.DEFINE_integer("task_index", None,

"Worker task index, should be >= 0. task_index=0 is "

"the master worker task the performs the variable "

"initialization ")

# 每台機器GPU個數,機器沒有GPU為0

flags.DEFINE_integer("num_gpus", 1,

"Total number of gpus for each machine."

"If you dont use GPU, please set it to 0")

# 同步訓練模型下,設置收集工作節點數量。默認工作節點總數

flags.DEFINE_integer("replicas_to_aggregate", None,

"Number of replicas to aggregate before parameter update"

"is applied (For sync_replicas mode only; default: "

"num_workers)")

flags.DEFINE_integer("hidden_units", 100,

"Number of units in the hidden layer of the NN")

# 訓練次數

flags.DEFINE_integer("train_steps", 200,

"Number of (global) training steps to perform")

flags.DEFINE_integer("batch_size", 100, "Training batch size")

flags.DEFINE_float("learning_rate", 0.01, "Learning rate")

# 使用同步訓練、非同步訓練

flags.DEFINE_boolean("sync_replicas", False,

"Use the sync_replicas (synchronized replicas) mode, "

"wherein the parameter updates from workers are aggregated "

"before applied to avoid stale gradients")

# 如果伺服器已經存在,採用gRPC協議通信;如果不存在,採用進程間通信

flags.DEFINE_boolean(

"existing_servers", False, "Whether servers already exists. If True, "

"will use the worker hosts via their GRPC URLs (one client process "

"per worker host). Otherwise, will create an in-process TensorFlow "

"server.")

# 參數伺服器主機

flags.DEFINE_string("ps_hosts","localhost:2222",

"Comma-separated list of hostname:port pairs")

# 工作節點主機

flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",

"Comma-separated list of hostname:port pairs")

# 本作業是工作節點還是參數伺服器

flags.DEFINE_string("job_name", None,"job name: worker or ps")

FLAGS = flags.FLAGS

IMAGE_PIXELS = 28

def main(unused_argv):

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

if FLAGS.download_only:

sys.exit(0)

if FLAGS.job_name is None or FLAGS.job_name == "":

raise ValueError("Must specify an explicit `job_name`")

if FLAGS.task_index is None or FLAGS.task_index =="":

raise ValueError("Must specify an explicit `task_index`")

print("job name = %s" % FLAGS.job_name)

print("task index = %d" % FLAGS.task_index)

#Construct the cluster and start the server

# 讀取集群描述信息

ps_spec = FLAGS.ps_hosts.split(",")

worker_spec = FLAGS.worker_hosts.split(",")

# Get the number of workers.

num_workers = len(worker_spec)

# 創建TensorFlow集群描述對象

cluster = tf.train.ClusterSpec({

"ps": ps_spec,

"worker": worker_spec})

# 為本地執行任務創建TensorFlow Server對象。

if not FLAGS.existing_servers:

# Not using existing servers. Create an in-process server.

# 創建本地Sever對象,從tf.train.Server這個定義開始,每個節點開始不同

# 根據執行的命令的參數(作業名字)不同,決定這個任務是哪個任務

# 如果作業名字是ps,進程就加入這裡,作為參數更新的服務,等待其他工作節點給它提交參數更新的數據

# 如果作業名字是worker,就執行後面的計算任務

server = tf.train.Server(

cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

# 如果是參數伺服器,直接啟動即可。這裡,進程就會阻塞在這裡

# 下面的tf.train.replica_device_setter代碼會將參數批定給ps_server保管

if FLAGS.job_name == "ps":

server.join()

# 處理工作節點

# 找出worker的主節點,即task_index為0的點

is_chief = (FLAGS.task_index == 0)

# 如果使用gpu

if FLAGS.num_gpus > 0:

# Avoid gpu allocation conflict: now allocate task_num -> #gpu

# for each worker in the corresponding machine

gpu = (FLAGS.task_index % FLAGS.num_gpus)

# 分配worker到指定gpu上運行

worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)

# 如果使用cpu

elif FLAGS.num_gpus == 0:

# Just allocate the CPU to worker server

# 把cpu分配給worker

cpu = 0

worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)

# The device setter will automatically place Variables ops on separate

# parameter servers (ps). The non-Variable ops will be placed on the workers.

# The ps use CPU and workers use corresponding GPU

# 用tf.train.replica_device_setter將涉及變數操作分配到參數伺服器上,使用CPU。將涉及非變數操作分配到工作節點上,使用上一步worker_device值。

# 在這個with語句之下定義的參數,會自動分配到參數伺服器上去定義。如果有多個參數伺服器,就輪流循環分配

with tf.device(

tf.train.replica_device_setter(

worker_device=worker_device,

ps_device="/job:ps/cpu:0",

cluster=cluster)):

# 定義全局步長,默認值為0

global_step = tf.Variable(0, name="global_step", trainable=False)

# Variables of the hidden layer

# 定義隱藏層參數變數,這裡是全連接神經網路隱藏層

hid_w = tf.Variable(

tf.truncated_normal(

[IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],

stddev=1.0 / IMAGE_PIXELS),

name="hid_w")

hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

# Variables of the softmax layer

# 定義Softmax 回歸層參數變數

sm_w = tf.Variable(

tf.truncated_normal(

[FLAGS.hidden_units, 10],

stddev=1.0 / math.sqrt(FLAGS.hidden_units)),

name="sm_w")

sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

# Ops: located on the worker specified with FLAGS.task_index

# 定義模型輸入數據變數

x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])

y_ = tf.placeholder(tf.float32, [None, 10])

# 構建隱藏層

hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)

hid = tf.nn.relu(hid_lin)

# 構建損失函數和優化器

y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

# 非同步訓練模式:自己計算完成梯度就去更新參數,不同副本之間不會去協調進度

opt = tf.train.AdamOptimizer(FLAGS.learning_rate)

# 同步訓練模式

if FLAGS.sync_replicas:

if FLAGS.replicas_to_aggregate is None:

replicas_to_aggregate = num_workers

else:

replicas_to_aggregate = FLAGS.replicas_to_aggregate

# 使用SyncReplicasOptimizer作優化器,並且是在圖間複製情況下

# 在圖內複製情況下將所有梯度平均

opt = tf.train.SyncReplicasOptimizer(

opt,

replicas_to_aggregate=replicas_to_aggregate,

total_num_replicas=num_workers,

name="mnist_sync_replicas")

train_step = opt.minimize(cross_entropy, global_step=global_step)

if FLAGS.sync_replicas:

local_init_op = opt.local_step_init_op

if is_chief:

# 所有進行計算工作節點裡一個主工作節點(chief)

# 主節點負責初始化參數、模型保存、概要保存

local_init_op = opt.chief_init_op

ready_for_local_init_op = opt.ready_for_local_init_op

# Initial token and chief queue runners required by the sync_replicas mode

# 同步訓練模式所需初始令牌、主隊列

chief_queue_runner = opt.get_chief_queue_runner()

sync_init_op = opt.get_init_tokens_op()

init_op = tf.global_variables_initializer()

train_dir = tempfile.mkdtemp()

if FLAGS.sync_replicas:

# 創建一個監管程序,用於統計訓練模型過程中的信息

# lodger 是保存和載入模型路徑

# 啟動就會去這個logdir目錄看是否有檢查點文件,有的話就自動載入

# 沒有就用init_op指定初始化參數

# 主工作節點(chief)負責模型參數初始化工作

# 過程中,其他工作節點等待主節眯完成初始化工作,初始化完成後,一起開始訓練數據

# global_step值是所有計算節點共享的

# 在執行損失函數最小值時自動加1,通過global_step知道所有計算節點一共計算多少步

sv = tf.train.Supervisor(

is_chief=is_chief,

logdir=train_dir,

init_op=init_op,

local_init_op=local_init_op,

ready_for_local_init_op=ready_for_local_init_op,

recovery_wait_secs=1,

global_step=global_step)

else:

sv = tf.train.Supervisor(

is_chief=is_chief,

logdir=train_dir,

init_op=init_op,

recovery_wait_secs=1,

global_step=global_step)

# 創建會話,設置屬性allow_soft_placement為True

# 所有操作默認使用被指定設置,如GPU

# 如果該操作函數沒有GPU實現,自動使用CPU設備

sess_config = tf.ConfigProto(

allow_soft_placement=True,

log_device_placement=False,

device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

# The chief worker (task_index==0) session will prepare the session,

# while the remaining workers will wait for the preparation to complete.

# 主工作節點(chief),task_index為0節點初始化會話

# 其餘工作節點等待會話被初始化後進行計算

if is_chief:

print("Worker %d: Initializing session..." % FLAGS.task_index)

else:

print("Worker %d: Waiting for session to be initialized..." %

FLAGS.task_index)

if FLAGS.existing_servers:

server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]

print("Using existing server at: %s" % server_grpc_url)

# 創建TensorFlow會話對象,用於執行TensorFlow圖計算

# prepare_or_wait_for_session需要參數初始化完成且主節點準備好後,才開始訓練

sess = sv.prepare_or_wait_for_session(server_grpc_url,

config=sess_config)

else:

sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

print("Worker %d: Session initialization complete." % FLAGS.task_index)

if FLAGS.sync_replicas and is_chief:

# Chief worker will start the chief queue runner and call the init op.

sess.run(sync_init_op)

sv.start_queue_runners(sess, [chief_queue_runner])

# Perform training

# 執行分散式模型訓練

time_begin = time.time()

print("Training begins @ %f" % time_begin)

local_step = 0

while True:

# Training feed

# 讀入MNIST訓練數據,默認每批次100張圖片

batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)

train_feed = {x: batch_xs, y_: batch_ys}

_, step = sess.run([train_step, global_step], feed_dict=train_feed)

local_step += 1

now = time.time()

print("%f: Worker %d: training step %d done (global step: %d)" %

(now, FLAGS.task_index, local_step, step))

if step >= FLAGS.train_steps:

break

time_end = time.time()

print("Training ends @ %f" % time_end)

training_time = time_end - time_begin

print("Training elapsed time: %f s" % training_time)

# Validation feed

# 讀入MNIST驗證數據,計算驗證的交叉熵

val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

val_xent = sess.run(cross_entropy, feed_dict=val_feed)

print("After %d training step(s), validation cross entropy = %g" %

(FLAGS.train_steps, val_xent))

if __name__ == "__main__":

tf.app.run()

參考資料:

《TensorFlow技術解析與實戰》

歡迎推薦上海機器學習工作機會,我的微信:qingxingfengzit


推薦閱讀:

理解《Deep Forest: Towards An Alternative to Deep Neural Network》
機務?演算法工程師!轉職筆記(五)
《Machine Learning:Clustering & Retrieval》課程第2章之LSH

TAG:TensorFlow | 机器学习 | 深度学习DeepLearning |