學習筆記TF061:分散式TensorFlow,分散式原理、最佳實踐
分散式TensorFlow由高性能gRPC庫底層技術支持。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。
分散式原理。分散式集群 由多個伺服器進程、客戶端進程組成。部署方式,單機多卡、分散式(多機多卡)。多機多卡TensorFlow分散式。
單機多卡,單台伺服器多塊GPU。訓練過程:在單機單GPU訓練,數據一個批次(batch)一個批次訓練。單機多GPU,一次處理多個批次數據,每個GPU處理一個批次數據計算。變數參數保存在CPU,數據由CPU分發給多個GPU,GPU計算每個批次更新梯度。CPU收集完多個GPU更新梯度,計算平均梯度,更新參數。繼續計算更新梯度。處理速度取決最慢GPU速度。
分散式,訓練在多個工作節點(worker)。工作節點,實現計算單元。計算伺服器單卡,指伺服器。計算伺服器多卡,多個GPU劃分多個工作節點。數據量大,超過一台機器處理能力,須用分散式。
分散式TensorFlow底層通信,gRPC(google remote procedure call)。gRPC,谷歌開源高性能、跨語言RPC框架。RPC協議,遠程過程調用協議,網路從遠程計算機程度請求服務。
分散式部署方式。分散式運行,多個計算單元(工作節點),後端伺服器部署單工作節點、多工作節點。
單工作節點部署。每台伺服器運行一個工作節點,伺服器多個GPU,一個工作節點可以訪問多塊GPU卡。代碼tf.device()指定運行操作設備。優勢,單機多GPU間通信,效率高。劣勢,手動代碼指定設備。
多工作節點部署。一台伺服器運行多個工作節點。
設置CUDA_VISIBLE_DEVICES環境變數,限制各個工作節點只可見一個GPU,啟動進程添加環境變數。用tf.device()指定特定GPU。多工作節點部署優勢,代碼簡單,提高GPU使用率。劣勢,工作節點通信,需部署多個工作節點。tobegit3hub/tensorflow_examples 。
CUDA_VISIBLE_DEVICES= python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
CUDA_VISIBLE_DEVICES= python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
CUDA_VISIBLE_DEVICES=0 python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
CUDA_VISIBLE_DEVICES=1 python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1
分散式架構。https://www.tensorflow.org/extend/architecture 。客戶端(client)、服務端(server),服務端包括主節點(master)、工作節點(worker)組成。
客戶端、主節點、工作節點關係。TensorFlow,客戶端會話聯繫主節點,實際工作由工作節點實現,每個工作節點佔一台設備(TensorFlow具體計算硬體抽象,CPU或GPU)。單機模式,客戶端、主節點、工作節點在同一台伺服器。分布模式,可不同伺服器。客戶端->主節點->工作節點/job:worker/task:0->/job:ps/task:0。
客戶端。建立TensorFlow計算圖,建立與集群交互會話層。代碼包含Session()。一個客戶端可同時與多個服務端相連,一具服務端也可與多個客戶端相連。
服務端。運行tf.train.Server實例進程,TensroFlow執行任務集群(cluster)一部分。有主節點服務(Master service)和工作節點服務(Worker service)。運行中,一個主節點進程和數個工作節點進程,主節點進程和工作接點進程通過介面通信。單機多卡和分散式結構相同,只需要更改通信介面實現切換。
主節點服務。實現tensorflow::Session介面。通過RPC服務程序連接工作節點,與工作節點服務進程工作任務通信。TensorFlow服務端,task_index為0作業(job)。
工作節點服務。實現worker_service.proto介面,本地設備計算部分圖。TensorFlow服務端,所有工作節點包含工作節點服務邏輯。每個工作節點負責管理一個或多個設備。工作節點可以是本地不同埠不同進程,或多台服務多個進程。運行TensorFlow分散式執行任務集,一個或多個作業(job)。每個作業,一個或多個相同目的任務(task)。每個任務,一個工作進程執行。作業是任務集合,集群是作業集合。
分散式機器學習框架,作業分參數作業(parameter job)和工作節點作業(worker job)。參數作業運行伺服器為參數伺服器(parameter server,PS),管理參數存儲、更新。工作節點作業,管理無狀態主要從事計算任務。模型越大,參數越多,模型參數更新超過一台機器性能,需要把參數分開到不同機器存儲更新。參數服務,多台機器組成集群,類似分散式存儲架構,涉及數據同步、一致性,參數存儲為鍵值對(key-value)。分散式鍵值內存資料庫,加參數更新操作。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/file/ps.pdf 。
參數存儲更新在參數作業進行,模型計算在工作節點作業進行。TensorFlow分散式實現作業間數據傳輸,參數作業到工作節點作業前向傳播,工作節點作業到參數作業反向傳播。
任務。特定TensorFlow伺服器獨立進程,在作業中擁有對應序號。一個任務對應一個工作節點。集群->作業->任務->工作節點。
客戶端、主節點、工作節點交互過程。單機多卡交互,客戶端->會話運行->主節點->執行子圖->工作節點->GPU0?GPU1。分散式交互,客戶端->會話運行->主節點進程->執行子圖1->工作節點進程1->GPU0?GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》Large-Scale Machine Learning on Heterogeneous Distributed Systems 。
分散式模式。
數據並行。https://www.tensorflow.org/tutorials/deep_cnn 。CPU負責梯度平均、參數更新,不同GPU訓練模型副本(model replica)。基於訓練樣例子集訓練,模型有獨立性。
步驟:不同GPU分別定義模型網路結構。單個GPU從數據管道讀取不同數據塊,前向傳播,計算損失,計算當前變數梯度。所有GPU輸出梯度數據轉移到CPU,梯度求平均操作,模型變數更新。重複,直到模型變數收斂。
數據並行,提高SGD效率。SGD mini-batch樣本,切成多份,模型複製多份,在多個模型上同時計算。多個模型計算速度不一致,CPU更新變數有同步、非同步兩個方案。
同步更新、非同步更新。分散式隨機梯度下降法,模型參數分散式存儲在不同參數服務上,工作節點並行訓練數據,和參數伺服器通信獲取模型參數。
同步隨機梯度下降法(Sync-SGD,同步更新、同步訓練),訓練時,每個節點上工作任務讀入共享參數,執行並行梯度計算,同步需要等待所有工作節點把局部梯度處好,將所有共享參數合併、累加,再一次性更新到模型參數,下一批次,所有工作節點用模型更新後參數訓練。優勢,每個訓練批次考慮所有工作節點訓練情部,損失下降穩定。劣勢,性能瓶頸在最慢工作節點。異楹設備,工作節點性能不同,劣勢明顯。
非同步隨機梯度下降法(Async-SGD,非同步更新、非同步訓練),每個工作節點任務獨立計算局部梯度,非同步更新到模型參數,不需執行協調、等待操作。優勢,性能不存在瓶頸。劣勢,每個工作節點計算梯度值發磅回參數伺服器有參數更新衝突,影響演算法收劍速度,損失下降過程抖動較大。
同步更新、非同步更新實現區別於更新參數伺服器參數策略。數據量小,各節點計算能力較均衡,用同步模型。數據量大,各機器計算性能參差不齊,用非同步模式。
帶備份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz論文《Revisiting Distributed Synchronous SGD》[1604.00981] Revisiting Distributed Synchronous SGD 。增加工作節點,解決部分工作節點計算慢問題。工作節點總數n+n*5%,n為集群工作節點數。非同步更新設定接受到n個工作節點參數直接更新參數伺服器模型參數,進入下一批次模型訓練。計算較慢節點訓練參數直接丟棄。
同步更新、非同步更新有圖內模式(in-graph pattern)和圖間模式(between-graph pattern),獨立於圖內(in-graph)、圖間(between-graph)概念。
圖內複製(in-grasph replication),所有操作(operation)在同一個圖中,用一個客戶端來生成圖,把所有操作分配到集群所有參數伺服器和工作節點上。國內複製和單機多卡類似,擴展到多機多卡,數據分發還是在客戶端一個節點上。優勢,計算節點只需要調用join()函數等待任務,客戶端隨時提交數據就可以訓練。劣勢,訓練數據分發在一個節點上,要分發給不同工作節點,嚴重影響並發訓練速度。
圖間複製(between-graph replication),每一個工作節點創建一個圖,訓練參數保存在參數伺服器,數據不分發,各個工作節點獨立計算,計算完成把要更新參數告訴參數伺服器,參數伺服器更新參數。優勢,不需要數據分發,各個工作節點都創建圖和讀取數據訓練。劣勢,工作節點既是圖創建者又是計算任務執行者,某個工作節點宕機影響集群工作。大數據相關深度學習推薦使用圖間模式。
模型並行。切分模型,模型不同部分執行在不同設備上,一個批次樣本可以在不同設備同時執行。TensorFlow盡量讓相鄰計算在同一台設備上完成節省網路開銷。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》Large-Scale Machine Learning on Heterogeneous Distributed Systems 。
模型並行、數據並行,TensorFlow中,計算可以分離,參數可以分離。可以在每個設備上分配計算節點,讓對應參數也在該設備上,計算參數放一起。
分散式API。https://www.tensorflow.org/deploy/distributed 。
創建集群,每個任務(task)啟動一個服務(工作節點服務或主節點服務)。任務可以分布不同機器,可以同一台機器啟動多個任務,用不同GPU運行。每個任務完成工作:創建一個tf.train.ClusterSpec,對集群所有任務進行描述,描述內容對所有任務相同。創建一個tf.train.Server,創建一個服務,運行相應作業計算任務。
TensorFlow分散式開發API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。創建TensorFlow集群描述信息,ps、worker為作業名稱,ps_phsts、worker_hosts為作業任務所在節點地址信息。tf.train.ClusterSpec傳入參數,作業和任務間關係映射,映射關係任務通過IP地址、埠號表示。
結構 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
可用任務 /job:local/task:0?/job:local/task:1。
結構 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
可用任務 /job:worker/task:0? /job:worker/task:1? /job:worker/task:2? /job:ps/task:0? /job:ps/task:1
tf.train.Server(cluster,job_name,task_index)。創建服務(主節點服務或工作節點服務),運行作業計算任務,運行任務在task_index指定機器啟動。
#任務0
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server = tr.train.Server(cluster,job_name="local",task_index=0)
#任務1
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server = tr.train.Server(cluster,job_name="local",task_index=1)。
自動化管理節點、監控節點工具。集群管理工具Kubernetes。
tf.device(device_name_or_function)。設定指定設備執行張量運算,批定代碼運行CPU、GPU。
#指定在task0所在機器執行Tensor操作運算
with tf.device("/job:ps/task:0"):
weights_1 = tf.Variable(…)
biases_1 = tf.Variable(…)
分散式訓練代碼框架。創建TensorFlow伺服器集群,在該集群分散式計算數據流圖。tensorflow/tensorflow 。
import argparse
import sys
import tensorflow as tf
FLAGS = None
def main(_):
# 第1步:命令行參數解析,獲取集群信息ps_hosts、worker_hosts
# 當前節點角色信息job_name、task_index
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
# 第2步:創建當前任務節點伺服器
# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# Create and start a server for the local task.
server = tf.train.Server(cluster,
job_name=FLAGS.job_name,
task_index=FLAGS.task_index)
# 第3步:如果當前節點是參數伺服器,調用server.join()無休止等待;如果是工作節點,執行第4步
if FLAGS.job_name == "ps":
server.join()
# 第4步:構建要訓練模型,構建計算圖
elif FLAGS.job_name == "worker":
# Assigns ops to the local worker by default.
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % FLAGS.task_index,
cluster=cluster)):
# Build model...
loss = ...
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps.
# 第5步管理模型訓練過程
hooks=[tf.train.StopAtStepHook(last_step=1000000)]
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs.
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(FLAGS.task_index == 0),
checkpoint_dir="/tmp/train_logs",
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
# Run a training step asynchronously.
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
# mon_sess.run handles AbortedError in case of preempted PS.
# 訓練模型
mon_sess.run(train_op)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
# Flags for defining the tf.train.ClusterSpec
parser.add_argument(
"--ps_hosts",
type=str,
default="",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--worker_hosts",
type=str,
default="",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--job_name",
type=str,
default="",
help="One of ps, worker"
)
# Flags for defining the tf.train.Server
parser.add_argument(
"--task_index",
type=int,
default=0,
help="Index of task within the job"
)
FLAGS, unparsed = parser.parse_known_args()
app.run-正在西部數碼(www.west.cn)進行交易(main=main, argv=[sys.argv[0]] + unparsed)
分散式最佳實踐。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py 。
MNIST數據集分散式訓練。開設3個埠作分散式工作節點部署,2222埠參數伺服器,2223埠工作節點0,2224埠工作節點1。參數伺服器執行參數更新任務,工作節點0?工作節點1執行圖模型訓練計算任務。參數伺服器/job:ps/task:0 cocalhost:2222,工作節點/job:worker/task:0 cocalhost:2223,工作節點/job:worker/task:1 cocalhost:2224。
運行代碼。
python mnist_replica.py --job_name="ps" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=1
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 定義常量,用於創建數據流圖
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data")
# 只下載數據,不做其他操作
flags.DEFINE_boolean("download_only", False,
"Only perform downloading of data; Do not proceed to "
"session preparation, model definition or training")
# task_index從0開始。0代表用來初始化變數的第一個任務
flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
# 每台機器GPU個數,機器沒有GPU為0
flags.DEFINE_integer("num_gpus", 1,
"Total number of gpus for each machine."
"If you dont use GPU, please set it to 0")
# 同步訓練模型下,設置收集工作節點數量。默認工作節點總數
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
flags.DEFINE_integer("hidden_units", 100,
"Number of units in the hidden layer of the NN")
# 訓練次數
flags.DEFINE_integer("train_steps", 200,
"Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
# 使用同步訓練、非同步訓練
flags.DEFINE_boolean("sync_replicas", False,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workers are aggregated "
"before applied to avoid stale gradients")
# 如果伺服器已經存在,採用gRPC協議通信;如果不存在,採用進程間通信
flags.DEFINE_boolean(
"existing_servers", False, "Whether servers already exists. If True, "
"will use the worker hosts via their GRPC URLs (one client process "
"per worker host). Otherwise, will create an in-process TensorFlow "
"server.")
# 參數伺服器主機
flags.DEFINE_string("ps_hosts","localhost:2222",
"Comma-separated list of hostname:port pairs")
# 工作節點主機
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
"Comma-separated list of hostname:port pairs")
# 本作業是工作節點還是參數伺服器
flags.DEFINE_string("job_name", None,"job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def main(unused_argv):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
if FLAGS.download_only:
sys.exit(0)
if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index =="":
raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
#Construct the cluster and start the server
# 讀取集群描述信息
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
# Get the number of workers.
num_workers = len(worker_spec)
# 創建TensorFlow集群描述對象
cluster = tf.train.ClusterSpec({
"ps": ps_spec,
"worker": worker_spec})
# 為本地執行任務創建TensorFlow Server對象。
if not FLAGS.existing_servers:
# Not using existing servers. Create an in-process server.
# 創建本地Sever對象,從tf.train.Server這個定義開始,每個節點開始不同
# 根據執行的命令的參數(作業名字)不同,決定這個任務是哪個任務
# 如果作業名字是ps,進程就加入這裡,作為參數更新的服務,等待其他工作節點給它提交參數更新的數據
# 如果作業名字是worker,就執行後面的計算任務
server = tf.train.Server(
cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
# 如果是參數伺服器,直接啟動即可。這裡,進程就會阻塞在這裡
# 下面的tf.train.replica_device_setter代碼會將參數批定給ps_server保管
if FLAGS.job_name == "ps":
server.join()
# 處理工作節點
# 找出worker的主節點,即task_index為0的點
is_chief = (FLAGS.task_index == 0)
# 如果使用gpu
if FLAGS.num_gpus > 0:
# Avoid gpu allocation conflict: now allocate task_num -> #gpu
# for each worker in the corresponding machine
gpu = (FLAGS.task_index % FLAGS.num_gpus)
# 分配worker到指定gpu上運行
worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
# 如果使用cpu
elif FLAGS.num_gpus == 0:
# Just allocate the CPU to worker server
# 把cpu分配給worker
cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
# The device setter will automatically place Variables ops on separate
# parameter servers (ps). The non-Variable ops will be placed on the workers.
# The ps use CPU and workers use corresponding GPU
# 用tf.train.replica_device_setter將涉及變數操作分配到參數伺服器上,使用CPU。將涉及非變數操作分配到工作節點上,使用上一步worker_device值。
# 在這個with語句之下定義的參數,會自動分配到參數伺服器上去定義。如果有多個參數伺服器,就輪流循環分配
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/cpu:0",
cluster=cluster)):
# 定義全局步長,默認值為0
global_step = tf.Variable(0, name="global_step", trainable=False)
# Variables of the hidden layer
# 定義隱藏層參數變數,這裡是全連接神經網路隱藏層
hid_w = tf.Variable(
tf.truncated_normal(
[IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
stddev=1.0 / IMAGE_PIXELS),
name="hid_w")
hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
# Variables of the softmax layer
# 定義Softmax 回歸層參數變數
sm_w = tf.Variable(
tf.truncated_normal(
[FLAGS.hidden_units, 10],
stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
# Ops: located on the worker specified with FLAGS.task_index
# 定義模型輸入數據變數
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10])
# 構建隱藏層
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)
# 構建損失函數和優化器
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
# 非同步訓練模式:自己計算完成梯度就去更新參數,不同副本之間不會去協調進度
opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
# 同步訓練模式
if FLAGS.sync_replicas:
if FLAGS.replicas_to_aggregate is None:
replicas_to_aggregate = num_workers
else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
# 使用SyncReplicasOptimizer作優化器,並且是在圖間複製情況下
# 在圖內複製情況下將所有梯度平均
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=replicas_to_aggregate,
total_num_replicas=num_workers,
name="mnist_sync_replicas")
train_step = opt.minimize(cross_entropy, global_step=global_step)
if FLAGS.sync_replicas:
local_init_op = opt.local_step_init_op
if is_chief:
# 所有進行計算工作節點裡一個主工作節點(chief)
# 主節點負責初始化參數、模型保存、概要保存
local_init_op = opt.chief_init_op
ready_for_local_init_op = opt.ready_for_local_init_op
# Initial token and chief queue runners required by the sync_replicas mode
# 同步訓練模式所需初始令牌、主隊列
chief_queue_runner = opt.get_chief_queue_runner()
sync_init_op = opt.get_init_tokens_op()
init_op = tf.global_variables_initializer()
train_dir = tempfile.mkdtemp()
if FLAGS.sync_replicas:
# 創建一個監管程序,用於統計訓練模型過程中的信息
# lodger 是保存和載入模型路徑
# 啟動就會去這個logdir目錄看是否有檢查點文件,有的話就自動載入
# 沒有就用init_op指定初始化參數
# 主工作節點(chief)負責模型參數初始化工作
# 過程中,其他工作節點等待主節眯完成初始化工作,初始化完成後,一起開始訓練數據
# global_step值是所有計算節點共享的
# 在執行損失函數最小值時自動加1,通過global_step知道所有計算節點一共計算多少步
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
local_init_op=local_init_op,
ready_for_local_init_op=ready_for_local_init_op,
recovery_wait_secs=1,
global_step=global_step)
else:
sv = tf.train.Supervisor(
is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
recovery_wait_secs=1,
global_step=global_step)
# 創建會話,設置屬性allow_soft_placement為True
# 所有操作默認使用被指定設置,如GPU
# 如果該操作函數沒有GPU實現,自動使用CPU設備
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
# The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
# 主工作節點(chief),task_index為0節點初始化會話
# 其餘工作節點等待會話被初始化後進行計算
if is_chief:
print("Worker %d: Initializing session..." % FLAGS.task_index)
else:
print("Worker %d: Waiting for session to be initialized..." %
FLAGS.task_index)
if FLAGS.existing_servers:
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url)
# 創建TensorFlow會話對象,用於執行TensorFlow圖計算
# prepare_or_wait_for_session需要參數初始化完成且主節點準備好後,才開始訓練
sess = sv.prepare_or_wait_for_session(server_grpc_url,
config=sess_config)
else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
if FLAGS.sync_replicas and is_chief:
# Chief worker will start the chief queue runner and call the init op.
sess.run(sync_init_op)
sv.start_queue_runners(sess, [chief_queue_runner])
# Perform training
# 執行分散式模型訓練
time_begin = time.time()
print("Training begins @ %f" % time_begin)
local_step = 0
while True:
# Training feed
# 讀入MNIST訓練數據,默認每批次100張圖片
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
train_feed = {x: batch_xs, y_: batch_ys}
_, step = sess.run([train_step, global_step], feed_dict=train_feed)
local_step += 1
now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" %
(now, FLAGS.task_index, local_step, step))
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation feed
# 讀入MNIST驗證數據,計算驗證的交叉熵
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
if __name__ == "__main__":
tf.app.run()
參考資料:
《TensorFlow技術解析與實戰》
歡迎推薦上海機器學習工作機會,我的微信:qingxingfengzit
推薦閱讀:
※理解《Deep Forest: Towards An Alternative to Deep Neural Network》
※機務?演算法工程師!轉職筆記(五)
※《Machine Learning:Clustering & Retrieval》課程第2章之LSH
TAG:TensorFlow | 机器学习 | 深度学习DeepLearning |