學習筆記TF023:下載、緩存、屬性字典、惰性屬性、覆蓋數據流圖、資源

確保目錄結構存在。每次創建文件,確保父目錄已經存在。確保指定路徑全部或部分目錄已經存在。創建沿指定路徑上不存在目錄。

下載函數,如果文件名未指定,從URL解析。下載文件,返回本地文件系統文件名。如果文件存在,不下載。如果文件未指定,從URL解析,返回filepath 。實際下載前,檢查下載位置是否有目標名稱文件。是,跳過下載。下載文件,返迴路徑。重複下載,把文件從文件系統刪除。

import os

import shutil

import errno

from lxml import etree

from urllib.request import urlopen

def ensure_directory(directory):

directory = os.path.expanduser(directory)

try:

os.makedirs(directory)

except OSError as e:

if e.errno != errno.EEXIST:

raise e

def download(url, directory, filename=None):

if not filename:

_, filename = os.path.split(url)

directory = os.path.expanduser(directory)

ensure_directory(directory)

filepath = os.path.join(directory, filename)

if os.path.isfile(filepath):

return filepath

print(Download, filepath)

with urlopen(url) as response, open(filepath, wb) as file_:

shutil.copyfileobj(response, file_)

return filepath

磁碟緩存修飾器,較大規模數據集處理中間結果保存磁碟公共位置,緩存載入函數修飾器。Python pickle功能實現函數返回值序列化、反序列化。只適合能納入主存數據集。@disk_cache修飾器,函數實參傳給被修飾函數。函數參數確定參數組合是否有緩存。散列映射為文件名數字。如果是method,跳過第一參數,緩存filepath,directory/basename-hash.pickle。方法method=False參數通知修飾器是否忽略第一個參數。

import functools

import os

import pickle

def disk_cache(basename, directory, method=False):

directory = os.path.expanduser(directory)

ensure_directory(directory)

def wrapper(func):

@functools.wraps(func)

def wrapped(*args, **kwargs):

key = (tuple(args), tuple(kwargs.items()))

if method and key:

key = key[1:]

filename = {}-{}.pickle.format(basename, hash(key))

filepath = os.path.join(directory, filename)

if os.path.isfile(filepath):

with open(filepath, rb) as handle:

return pickle.load(handle)

result = func(*args, **kwargs)

with open(filepath, wb) as handle:

pickle.dump(result, handle)

return result

return wrapped

return wrapper

@disk_cache(dataset, /home/user/dataset/)

def get_dataset(one_hot=True):

dataset = Dataset(example.com/dataset.bz2)

dataset = Tokenize(dataset)

if one_hot:

dataset = OneHotEncoding(dataset)

return dataset

屬性字典。繼承自內置dict類,可用屬性語法訪問悠已有元素。傳入標準字典(鍵值對)。內置函數locals,返回作用域所有局部變數名值映射。

class AttrDict(dict):

def __getattr__(self, key):

if key not in self:

raise AttributeError

return self[key]

def __setattr__(self, key, value):

if key not in self:

raise AttributeError

self[key] = value

惰性屬性修飾器。外部使用。訪問model.optimze,數據流圖創建新計算路徑。調用model.prediction,創建新權值和偏置。定義只計算一次屬性。結果保存到帶有某些前綴的函數調用。惰性屬性,TensorFlow模型結構化、分類。

import functools

def lazy_property(function):

attribute = _lazy_ + function.__name__

@property

@functools.wraps(function)

def wrapper(self):

if not hasattr(self, attribute):

setattr(self, attribute, function(self))

return getattr(self, attribute)

return wrapper

class Model:

def __init__(self, data, target):

self.data = data

self.target = target

self.prediction

self.optimize

self.error

@lazy_property

def prediction(self):

data_size = int(self.data.get_shape()[1])

target_size = int(self.target.get_shape()[1])

weight = tf.Variable(tf.truncated_normal([data_size, target_size]))

bias = tf.Variable(tf.constant(0.1, shape=[target_size]))

incoming = tf.matmul(self.data, weight) + bias

return tf.nn.softmax(incoming)

@lazy_property

def optimize(self):

cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))

optimizer = tf.train.RMSPropOptimizer(0.03)

return optimizer.minimize(cross_entropy)

@lazy_property

def error(self):

mistakes = tf.not_equal(

tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))

return tf.reduce_mean(tf.cast(mistakes, tf.float32))

覆蓋數據流圖修飾器。未明確指定使用期他數據流圖,TensorFlow使用默認。Jupyter Notebook,解釋器狀態在不同一單元執行期間保持。初始默認數據流圖始終存在。執行再次定義數據流圖運算單元,添加到已存在數據流圖。根據菜單選項重新啟動kernel,再次運行所有單元。

創建定製數據流圖,設置默認。所有運算添加到該數據流圖,再次運行單元,創建新數據流圖。舊數據流圖自動清理。

修飾器中創建數據流圖,修飾主函數。主函數定義完整數據流圖,定義佔位符,調用函數創建模型。

import functools

import tensorflow as tf

def overwrite_graph(function):

@functools.wraps(function)

def wrapper(*args, **kwargs):

with tf.Graph().as_default():

return function(*args, **kwargs)

return wrapper

@overwrite_graph

def main():

data = tf.placeholder(...)

target = tf.placeholder(...)

model = Model()

main()

API文檔,編寫代碼時參考:

tensorflow.org/versions

Github庫,跟蹤TensorFlow最新功能特性,閱讀拉拽請求(pull request)、問題(issues)、發行記錄(release note):

github.com/tensorflow/t

分散式 TensorFlow:

tensorflow.org/versions

構建新TensorFlow功能:

tensorflow.org/master/h

郵件列表:

groups.google.com/a/ten

StackOverflow:

stackoverflow.com/quest

代碼:

github.com/backstopmedi

參考資料:

《面向機器智能的TensorFlow實踐》

歡迎付費諮詢(150元每小時),我的微信:qingxingfengzi


推薦閱讀:

factorization machine和logistic regression的區別?

TAG:TensorFlow | 机器学习 | 深度学习DeepLearning |