tensorflow訓練自己製作的tfrecords文件時出現ValueError?
問題的詳細描述及代碼在此鏈接里:
簡單描述一下,就是我直接用dataset讀取csv文件拿去訓練的話不出現問題;但是我將csv文件製作成tfrecords文件的話將會出現ValueError: features should be a dictionary of `Tensor`s. Given type: &
錯誤。 谷歌了好幾天,也查了好幾天依然沒有解決。
求助各位大神幫忙看一下!
人類身份驗證 - SegmentFault?segmentfault.com
初看了下,tfrecord文件本身應該沒有問題,錯誤應該是在你的input_fn上。
即下面這段代碼
def my_input_fn(is_shuffle=False, repeat_count=1):
#略
return features, labels
正如錯誤提示:
ValueError: features should be a dictionary of `Tensor`s. Given type: &
簡單的改成
def my_input_fn(is_shuffle=False, repeat_count=1):
#略
return {features:features}, labels
如果還有問題。就按下面幾步進行排查:
- 檢查tfrecord文件是否生成正確。將features, labels = iterator.get_next()
features, labels = iterator.get_next()
sess = tf.InteractiveSession()
print(features.eval())
# 可以看到具體的features值就沒錯
2. 這裡你既然使用了feature_columns,就要確保feature_columns中的key和你input_fn中返回的features里的key值相符
使用了
def my_input_fn(is_shuffle=False, repeat_count=1):
#略
return {features:features}, labels
後,錯誤變成了ValueError: Feature PetalLength is not in features dictionary。
這是第二個錯誤。
原因就是上面說的第二點,input_fn中返回的features里的key值和feature_columns不符。
你所使用的feature_columns有4個鍵值
feature_names = [
SepalLength,
SepalWidth,
PetalLength,
PetalWidth
]
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]
而你製作的tfrecord的時候把4個features融合在了一起(這種做法是可行的),但是需要在return的部分做對應的調整。
example = tf.train.Example(
features=tf.train.Features(
feature={
label: _int64_feature(label),
features: _float_feature(features)
}
)
)
修改方法1:對parser的返回值進行修改
def parser(record):
keys_to_features = {
label: tf.FixedLenFeature((), dtype=tf.int64),
features: tf.FixedLenFeature(shape=(4,), dtype=tf.float32),
}
parsed = tf.parse_single_example(record, keys_to_features)
my_features = {SepalLength: parsed[features][0],
SepalWidth: parsed[features][1],
PetalLength: parsed[features][2],
PetalWidth: parsed[features][3]
}
return my_features, parsed[label]
def my_input_fn(is_shuffle=False, repeat_count=1):
#略
#改回去
return features, labels
最後給你兩個非常詳細的教程。
YJango:YJango:TensorFlow中層API Datasets+TFRecord的數據導入?zhuanlan.zhihu.comYJango:YJango:TensorFlow高層API Custom Estimator建立CNN+RNN的演示?zhuanlan.zhihu.comfeature應該是字典類型的,先做個
feature = dict(feature)
推薦閱讀:
※python3.6 安裝後沒有pip?
※網頁爬蟲,但總是出現中文亂碼,求大神幫忙解決或看看問題在哪?
※Python訪問網頁報錯,ValueError: unknown url type,求問什麼原因?
※在Python中下面這句話怎麼理解?
※你看好 Python 3 嗎?
TAG:Python3x | TensorFlow |