如何将Float数组/列表转换为TFRecord?

这是用于将数据转换为TFRecord的代码

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

 def _bytes_feature(value):
   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _floats_feature(value):
   return tf.train.Feature(float_list=tf.train.FloatList(value=value))

with tf.python_io.TFRecordWriter("train.tfrecords") as writer:
    for row in train_data:
        prices, label, pip = row[0],row[1],row[2]
        prices = np.asarray(prices).astype(np.float32)
        example = tf.train.Example(features=tf.train.Features(feature={
                                           'prices': _floats_feature(prices),
                                           'label': _int64_feature(label[0]),
                                           'pip': _floats_feature(pip)
    }))
        writer.write(example.SerializeToString())

特征价格是一组形状(1,288)。它转换成功!但是,当使用解析函数和数据集API解码数据时。

def parse_func(serialized_data):
    keys_to_features = {'prices': tf.FixedLenFeature([], tf.float32),
                    'label': tf.FixedLenFeature([], tf.int64)}

    parsed_features = tf.parse_single_example(serialized_data, keys_to_features)
    return parsed_features['prices'],tf.one_hot(parsed_features['label'],2)

它给了我错误

C:\ tf_jenkins \ workspace \ rel-win \ M \ windows-gpu \ PY \ 36 \ tensorflow \ core \ framework \ op_kernel.cc:1202] OP_REQUIRES在example_parsing_ops.cc:240中失败:参数无效:密钥:价格。无法解析序列化的示例。2018-03-31 15:37:11.443073:WC:\ tf_jenkins \ workspace \ rel-win \ M \ windows-gpu \ PY \ 36 \ tensorflow \ core \ framework \ op_kernel.cc:1202] OP_REQUIRES在example_parsing_ops.cc中失败:240:无效的参数:键:价格。无法解析序列化的示例。2018-03-31 15:37:11.443313:WC:\ tf_jenkins \ workspace \ rel-win \ M \ windows-gpu \提高类型(e)(node_def,op,消息)PY \ 36 \ tensortensorflow.python.framework。 errors_impl.InvalidArgumentError:键:价格。无法解析序列化的示例。[[节点:ParseSingleExample / ParseSingleExample = ParseSingleExample [Tdense = [DT_INT64,DT_FLOAT],density_keys = [“ label”,“ prices”],density_shapes = [[],[]],num_sparse = 0,sparse_keys = [],sparse_types = []](arg0,ParseSingleExample / Const,ParseSingleExample / Const_1)]] [[节点:IteratorGetNext_1 = IteratorGetNextoutput_shapes = [[?], [?,2]],output_types = [DT_FLOAT,DT_FLOAT],_ device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”]]流\核心\框架\ op_kernel.cc: [1202] OP_REQUIRES在example_parsing_ops.cc:240处失败:无效参数:键:价格。无法解析序列化的示例。重点:价格。无法解析序列化的示例。重点:价格。无法解析序列化的示例。

我发现了问题。tf.io.FixedLenFeature使用tf.io.FixedLenSequenceFeature
而不是用于解析数组)(对于TensorFlow 1,使用tf.代替tf.io.

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章