tensorflow訓練自己製作的tfrecords文件時出現ValueError?

問題的詳細描述及代碼在此鏈接里:

簡單描述一下,就是我直接用dataset讀取csv文件拿去訓練的話不出現問題;但是我將csv文件製作成tfrecords文件的話將會出現ValueError: features should be a dictionary of `Tensor`s. Given type: &錯誤。

谷歌了好幾天,也查了好幾天依然沒有解決。

求助各位大神幫忙看一下!

人類身份驗證 - SegmentFault?

segmentfault.com


初看了下,tfrecord文件本身應該沒有問題,錯誤應該是在你的input_fn上。

即下面這段代碼

def my_input_fn(is_shuffle=False, repeat_count=1):
#略
return features, labels

正如錯誤提示:

ValueError: features should be a dictionary of `Tensor`s. Given type: &

簡單的改成

def my_input_fn(is_shuffle=False, repeat_count=1):
#略
return {features:features}, labels

如果還有問題。就按下面幾步進行排查:

  1. 檢查tfrecord文件是否生成正確。將features, labels = iterator.get_next()

features, labels = iterator.get_next()
sess = tf.InteractiveSession()
print(features.eval())
# 可以看到具體的features值就沒錯

2. 這裡你既然使用了feature_columns,就要確保feature_columns中的key和你input_fn中返回的features里的key值相符


使用了

def my_input_fn(is_shuffle=False, repeat_count=1):
#略
return {features:features}, labels

後,錯誤變成了ValueError: Feature PetalLength is not in features dictionary。

這是第二個錯誤。

原因就是上面說的第二點,input_fn中返回的features里的key值和feature_columns不符。

你所使用的feature_columns有4個鍵值

feature_names = [
SepalLength,
SepalWidth,
PetalLength,
PetalWidth
]
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

而你製作的tfrecord的時候把4個features融合在了一起(這種做法是可行的),但是需要在return的部分做對應的調整。

example = tf.train.Example(
features=tf.train.Features(
feature={
label: _int64_feature(label),
features: _float_feature(features)
}
)
)

修改方法1:對parser的返回值進行修改

def parser(record):
keys_to_features = {
label: tf.FixedLenFeature((), dtype=tf.int64),
features: tf.FixedLenFeature(shape=(4,), dtype=tf.float32),
}
parsed = tf.parse_single_example(record, keys_to_features)
my_features = {SepalLength: parsed[features][0],
SepalWidth: parsed[features][1],
PetalLength: parsed[features][2],
PetalWidth: parsed[features][3]
}
return my_features, parsed[label]
def my_input_fn(is_shuffle=False, repeat_count=1):
#略
#改回去
return features, labels

最後給你兩個非常詳細的教程。

YJango:YJango:TensorFlow中層API Datasets+TFRecord的數據導入?

zhuanlan.zhihu.com圖標YJango:YJango:TensorFlow高層API Custom Estimator建立CNN+RNN的演示?

zhuanlan.zhihu.com圖標


feature應該是字典類型的,先做個

feature = dict(feature)


推薦閱讀:

python3.6 安裝後沒有pip?
網頁爬蟲,但總是出現中文亂碼,求大神幫忙解決或看看問題在哪?
Python訪問網頁報錯,ValueError: unknown url type,求問什麼原因?
在Python中下面這句話怎麼理解?
你看好 Python 3 嗎?

TAG:Python3x | TensorFlow |