基於Tensorflow高階API構建大規模分散式深度學習模型系列:基於Dataset API處理Input pipeline

基於Tensorflow高階API構建大規模分散式深度學習模型系列:基於Dataset API處理Input pipeline

來自專欄演算法工程師的自我修養16 人贊了文章

在TensorFlow 1.3版本之前,讀取數據一般有兩種方法:

  • 使用placeholder + feed_dict讀內存中的數據
  • 使用文件名隊列(string_input_producer)與內存隊列(reader)讀硬碟中的數據

Dataset API同時支持從內存和硬碟的數據讀取,相比之前的兩種方法在語法上更加簡潔易懂。Dataset API可以更方便地與其他高階API配合,快速搭建網路模型。此外,如果想要用到TensorFlow新出的Eager模式,就必須要使用Dataset API來讀取數據。

Dataset可以看作是相同類型「元素」的有序列表。在實際使用時,單個「元素」可以是向量,也可以是字元串、圖片,甚至是tuple或者dict。

一、從內存中讀取數據

用tf.data.Dataset.from_tensor_slices創建了一個最簡單的Dataset:

import tensorflow as tfimport numpy as npdataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

如何將這個dataset中的元素取出呢?方法是從Dataset中實例化一個Iterator,然後對Iterator進行迭代。

iterator = dataset.make_one_shot_iterator()one_element = iterator.get_next()with tf.Session() as sess: for i in range(5): print(sess.run(one_element))

由於Tensorflow採用了符號式編程(symbolic style programs)模式,而非常見的命令式編程(imperative style programs)模式,因此必須創建一個Session對象才能運行程序。上述代碼中,one_element只是一個Tensor,並不是一個實際的值。調用sess.run(one_element)後,才能真正地取出一個值。如果一個dataset中元素被讀取完了,再嘗試sess.run(one_element)的話,就會拋出tf.errors.OutOfRangeError異常,這個行為與使用隊列方式讀取數據的行為是一致的。

其實,tf.data.Dataset.from_tensor_slices的功能不止如此,它的真正作用是切分傳入Tensor的第一個維度,生成相應的dataset。例如:

dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))

傳入的數值是一個矩陣,它的形狀為(5, 2),tf.data.Dataset.from_tensor_slices就會切分它形狀上的第一個維度,最後生成的dataset中一個含有5個元素,每個元素的形狀是(2, ),即每個元素是矩陣的一行。

下面我們來看看如何從Dict中構建dataset:

dataset = tf.data.Dataset.from_tensor_slices( { "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), "b": np.random.uniform(size=(5, 2)) })

這時函數會分別切分"a"中的數值以及"b"中的數值,最終dataset中的一個元素就是類似於{"a": 1.0, "b": [0.9, 0.1]}的形式。

二、從文件中讀取數據

在實際應用中,模型的訓練和評估數據總是以文件的形式存在文件系統中,目前Dataset API提供了三種從文件讀取數據並創建Dataset的方式,分別用來讀取不同存儲格式的文件。

DataSet類結構

  • tf.data.TextLineDataset():這個函數的輸入是一個文件的列表,輸出是一個dataset。dataset中的每一個元素就對應了文件中的一行。可以使用這個函數來讀入CSV文件。
  • tf.data.FixedLengthRecordDataset():這個函數的輸入是一個文件的列表和一個record_bytes,之後dataset的每一個元素就是文件中固定位元組數record_bytes的內容。通常用來讀取以二進位形式保存的文件,如CIFAR10數據集就是這種形式。
  • tf.data.TFRecordDataset():顧名思義,這個函數是用來讀TFRecord文件的,dataset中的每一個元素就是一個TFExample。

需要說明的是,這三種讀取文件數據創建dataset的方法,不僅能讀取本地文件系統中的文件,還能讀取分散式文件系統(如HDFS)中的文件,這為模型的分散式訓練創造了良好的條件。

三、Dataset的常用Transformation操作

一個Dataset通過數據變換操作可以生成一個新的Dataset。下面介紹數據格式變換、過濾、數據打亂、生產batch和epoch等常用Transformation操作。

(1)map

map接收一個函數,Dataset中的每個元素都會被當作這個函數的輸入,並將函數返回值作為新的Dataset,如我們可以對dataset中每個元素的值取平方:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))dataset = dataset.map(lambda x: x * x) # 1.0, 4.0, 9.0, 16.0, 25.0

(2)filter

filter操作可以過濾掉dataset不滿足條件的元素,它接受一個布爾函數作為參數,dataset中的每個元素都作為該布爾函數的參數,布爾函數返回True的元素保留下來,布爾函數返回False的元素則被過濾掉。

dataset = dataset.filter(filter_func)

(3)shuffle

shuffle功能為打亂dataset中的元素,它有一個參數buffer_size,表示打亂時使用的buffer的大小:

dataset = dataset.shuffle(buffer_size=10000)

(4)repeat

repeat的功能就是將整個序列重複多次,主要用來處理機器學習中的epoch,假設原先的數據是一個epoch,使用repeat(5)就可以將之變成5個epoch:

dataset = dataset.repeat(5)

如果直接調用repeat()的話,生成的序列就會無限重複下去,沒有結束,因此也不會拋出tf.errors.OutOfRangeError異常。

(5)batch

batch就是將多個元素組合成batch,如下面的程序將dataset中的每個元素組成了大小為32的batch:

dataset = dataset.batch(32)

需要注意的是,必須要保證dataset中每個元素擁有相同的shape才能調用batch方法,否則會拋出異常。在調用map方法轉換元素格式的時候尤其要注意這一點。

四、Dataset元素變換案例

1. 解析CSV文件

假設我們有一個Tab分隔4個欄位的文件,則可用如下的代碼解析並生成dataset。

_CSV_COLUMNS = [field1, field2, field3, field4]_CSV_COLUMN_DEFAULTS=[[], [], [0.0], [0.0]]def input_fn(data_file, shuffle, batch_size): def parse_csv(value): columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS, field_delim= ) features = dict(zip(_CSV_COLUMNS, columns)) labels = features.pop(ctr_flag) return features, tf.equal(labels, 1.0) # Extract lines from input files using the Dataset API. dataset = tf.data.TextLineDataset(data_file) if shuffle: dataset = dataset.shuffle(buffer_size=100000) dataset = dataset.map(parse_csv, num_parallel_calls=100) # We call repeat after shuffling, rather than before, to prevent separate # epochs from blending together. dataset = dataset.repeat() dataset = dataset.batch(batch_size) return dataset

上述代碼主要利用tf.decode_csv函數來把CSV文件記錄轉換為Tensors列表,每一列對應一個Tensor。

2. 解析特殊格式的文本文件

有時候我們的訓練數據可能有特殊的格式,比如CVS文件其中某些欄位是JSON格式的字元串,我們要把JSON字元串的內容也解析出來,這個時候tf.decode_csv函數就不夠用了。

是時候請萬能函數tf.py_func上場了,tf.py_func函數能夠把一個任意的python函數封裝成tensorflow的op,提供了極大的靈活性,其定義如下:

tf.py_func( func, inp, Tout, stateful=True, name=None)

tf.py_func的核心是一個func函數(由用戶自己定義),該函數被封裝成graph中的一個節點(op)。第二個參數inp是一個由Tensor組成的list,在執行時,inp的各個Tensor的值被取出來傳給func作為參數。func的返回值會被tf.py_func轉換為Tensors,這些Tensors的類型由Tout指定。當func只有一個返回值時,Tout是一個單獨的tensorflow數據類型;當func函數有多個返回值時,Tout是一個tensorflow數據類型組成的元組或列表。參數stateful表示func函數是否有狀態(產生副作用)。

在使用過程中,有幾個需要注意的地方:

  • func函數的返回值類型一定要和Tout指定的tensor類型一致。
  • tf.py_func中的func是脫離Graph的,在func中不能定義可訓練的參數參與網路訓練(反傳)。
  • tf.py_func操作只能在CPU上運行;如果使用分散式TensorFlow,tf.py_func操作必須放在與客戶端相同進程的CPU設備上。
  • tf.py_func操作返回的tensors是沒有定義形狀(shape)的,必須調用set_shape方法為各個返回值設置shape,才能參與後續的計算。

先來看一個簡單的示例,func函數接受單個參數併產生單個返回值的情況。

def filter_func(line): fields = line.decode().split(" ") if len(fields) < 8: return False for field in fields: if not field: return False return Truedataset = dataset.filter(lambda x: tf.py_func(filter_func, [x], tf.bool, False))

再來看一個稍微複雜一點的例子,該例子解析一個帶有json格式欄位的CSV文件,json欄位被平鋪開來和其他欄位並列作為返回值。

import jsonimport numpy as npimport tensorflow as tfdef parse_line(line): _COLUMNS = ["sellerId", "brandId", "cateId"] _INT_COLUMNS = ["click", "productId", "matchType", "position", "hour"] _FLOAT_COLUMNS = ["matchScore", "popScore", "brandPrefer", "catePrefer"] _STRING_COLUMNS = ["phoneResolution", "phoneBrand", "phoneOs"] _SEQ_COLUMNS = ["behaviorC1ids", "behaviorBids", "behaviorCids", "behaviorPids"] def get_content(record): import datetime fields = record.decode().split(" ") if len(fields) < 8: raise ValueError("invalid record %s" % record) for field in fields: if not field: raise ValueError("invalid record %s" % record) fea = json.loads(fields[1]) if fea["time"]: dt = datetime.datetime.fromtimestamp(fea["time"]) fea["hour"] = dt.hour else: fea["hour"] = 0 seq_len = 10 for x in _SEQ_COLUMNS: sequence = fea.setdefault(x, []) n = len(sequence) if n < seq_len: sequence.extend([-1] * (seq_len - n)) elif n > seq_len: fea[x] = sequence[:seq_len] seq_len = 20 elems = [np.int64(fields[2]), np.int64(fields[3]), np.int64(fields[4]), np.int64(fields[6]), fields[7]] elems += [np.int64(fea.get(x, 0)) for x in _INT_COLUMNS] elems += [np.float32(fea.get(x, 0.0)) for x in _FLOAT_COLUMNS] elems += [fea.get(x, "") for x in _STRING_COLUMNS] elems += [np.int64(fea[x]) for x in _SEQ_COLUMNS] return elems out_type = [tf.int64] * 4 + [tf.string] + [tf.int64] * len(_INT_COLUMNS) + [tf.float32] * len(_FLOAT_COLUMNS) + [ tf.string] * len(_STRING_COLUMNS) + [tf.int64] * len(_SEQ_COLUMNS) result = tf.py_func(get_content, [line], out_type) n = len(result) - len(_SEQ_COLUMNS) for i in range(n): result[i].set_shape([]) result[n].set_shape([10]) for i in range(n + 1, len(result)): result[i].set_shape([20]) columns = _COLUMNS + _INT_COLUMNS + _FLOAT_COLUMNS + _STRING_COLUMNS + _SEQ_COLUMNS features = dict(zip(columns, result)) labels = features.pop(click) return features, labelsdef my_input_fn(filenames, batch_size, shuffle_buffer_size): dataset = tf.data.TextLineDataset(filenames) dataset = dataset.filter(lambda x: tf.py_func(filter_func, [x], tf.bool, False)) dataset = dataset.map(parse_line, num_parallel_calls=100) # Shuffle, repeat, and batch the examples. if shuffle_buffer_size > 0: dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.repeat().batch(batch_size) return dataset

3. 解析TFRECORD文件

Tfrecord是tensorflow官方推薦的訓練數據存儲格式,它更容易與網路應用架構相匹配。

Tfrecord本質上是二進位的Protobuf數據,因而其讀取、傳輸的速度更快。Tfrecord文件的每一條記錄都是一個tf.train.Example的實例。tf.train.Example的proto格式的定義如下:

message Example { Features features = 1;};message Features { map<string, Feature> feature = 1;};message Feature { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; }};

使用tfrecord文件格式的另一個好處是數據結構統一,屏蔽了底層的數據結構。在類似於圖像分類的任務中,原始數據是各個圖片以單獨的小文件的形式存在,label又以文件夾的形式存在,處理這樣的數據比較麻煩,比如隨機打亂,分batch等操作;而所有原始數據轉換為一個或幾個單獨的tfrecord文件後處理起來就會比較方便。

來看看tensorflow讀取tfrecord文件並轉化為訓練features和labels的代碼:

def parse_exmp(serial_exmp): features = { "click": tf.FixedLenFeature([], tf.int64), "behaviorBids": tf.FixedLenFeature([20], tf.int64), "behaviorCids": tf.FixedLenFeature([20], tf.int64), "behaviorC1ids": tf.FixedLenFeature([10], tf.int64), "behaviorSids": tf.FixedLenFeature([20], tf.int64), "behaviorPids": tf.FixedLenFeature([20], tf.int64), "productId": tf.FixedLenFeature([], tf.int64), "sellerId": tf.FixedLenFeature([], tf.int64), "brandId": tf.FixedLenFeature([], tf.int64), "cate1Id": tf.FixedLenFeature([], tf.int64), "cateId": tf.FixedLenFeature([], tf.int64), "tab": tf.FixedLenFeature([], tf.string), "matchType": tf.FixedLenFeature([], tf.int64) } feats = tf.parse_single_example(serial_exmp, features=features) labels = feats.pop(click) return feats, labelsdef train_input_fn(filenames, batch_size, shuffle_buffer_size): dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(parse_exmp, num_parallel_calls=100) # Shuffle, repeat, and batch the examples. if shuffle_buffer_size > 0: dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.repeat().batch(batch_size) return dataset

這裡我們再說說如何把原始數據轉換為tfrecord文件格式,請參考下面的代碼片段:

# 建立tfrecorder writerwriter = tf.python_io.TFRecordWriter(csv_train.tfrecords)for i in xrange(train_values.shape[0]): image_raw = train_values[i].tostring() # build example protobuf example = tf.train.Example( features=tf.train.Features(feature={ image_raw: tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])), label: tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]])) })) writer.write(record=example.SerializeToString())writer.close()

然而,大規模的訓練數據用這種方式轉換格式會比較低效,更好的實踐是用hadoop或者spark這種分散式計算平台,並行實現數據轉換任務。這裡給出一個用Hadoop MapReduce編程模式轉換為tfrecord文件格式的開源實現:Hadoop MapReduce InputFormat/OutputFormat for TFRecords。由於該實現指定了protobuf的版本,因而可能會跟自己真正使用的hadoop平台自己的protobuf版本不一致,hadoop在默認情況下總是優先使用HADOOP_HOME/lib下的jar包,從而導致運行時錯誤,遇到這種情況時,只需要設置mapreduce.task.classpath.user.precedence=true參數,優先使用自己指定版本的jar包即可。

參考資料

  1. zhuanlan.zhihu.com/p/30
  2. skcript.com/svr/why-eve

推薦閱讀:

淺入淺出TensorFlow 9 - 代碼框架解析
卷積神經網路圖文介紹
TensorFlow 1.9.0正式發布:新手指南更友好
TensorFlow學習筆記之五——源碼分析之最近演算法
淺入淺出TensorFlow 8 - 行人分割

TAG:深度學習DeepLearning | TensorFlow | 人工智慧演算法 |