如何在Tensorflow 2中解码示例(从1.12版本移植)

斯蒂芬·福克

我有以下方法,应该从序列化中解码样本TFRecordDataset

def decode_example(self, serialized_example):
    """Return a dict of Tensors from a serialized tensorflow.Example."""
    data_fields, data_items_to_decoders = self.example_reading_spec()
    # Necessary to rejoin examples in the correct order with the Cloud ML Engine
    # batch prediction API.
    data_fields['batch_prediction_key'] = tf.io.FixedLenFeature([1], tf.int64, 0)
    if data_items_to_decoders is None:
        data_items_to_decoders = {
            field: tf.contrib.slim.tfexample_decoder.Tensor(field)
            for field in data_fields
        }

    decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(data_fields, data_items_to_decoders)

    decode_items = list(sorted(data_items_to_decoders))
    decoded = decoder.decode(serialized_example, items=decode_items)
    return dict(zip(decode_items, decoded))

但是,这在Tensorflow 2下不起作用。

tf.contrib 不再存在,我找不到可用于解码这些示例的任何内容。

TFExampleDecoder安装后我什至找不到tensorflow-data-validation

知道那里有什么问题和/或如何解码示例吗?

斯蒂芬·福克

我能够使用来使其工作tf.io.parse_single_example

我们必须像往常一样声明数据字段(example_reading_spec),然后才能对示例进行解码:

def example_reading_spec():

    data_fields = {
        'inputs': tf.io.VarLenFeature(tf.float32),
        'targets': tf.io.VarLenFeature(tf.int64),
    }

    return data_fields

def decode_example(serialized_example):
    """Return a dict of Tensors from a serialized tensorflow.Example."""
    return tf.io.parse_single_example(
        serialized_example,
        features=example_reading_spec()
    )

现在我们可以Dataset.map像这样加载数据集碎片:

record_dataset = tf.data.TFRecordDataset(filenames, buffer_size=1024)
record_dataset = record_dataset.map(decode_example)

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

如何在熊猫数据框中验证X版本的条目A是否比X-1版本的条目A更新?

如何在Maven中指定avro-mapred的hadoop2版本?

如何在jaxb2-maven-plugin 2.5.0版本中排除情节文件的生成?

如何在Ubuntu 12.04 LTS上安装最新的grub2版本

如何在Kafka 2.1.0版本中清除或删除主题

如何在conda环境中安装keras 2.0.5版本

如何在AWS Lambda函数中降级boto3版本

如何在Typo3版本7.6中覆盖语言文件?

如何在smartAdmin AngularJs 1.8版本中更改URL?

如何在exoplayer 2.11.8版本中传递负载控制

如何在iOS 6.0和7.0版本中隐藏UIApplication状态栏

如何在Windows上的python 3.4版本中安装请求模块?

如何在xampp中升级到php 7.2版本

如何在Visual Studio 2017版本15.3中使用F#?

如何在Visual Studio 2017版本15.8中删除水平白线?

如何在Visual Studio 2019中创建XUnit项目的.NET Framework 4.6版本?

如何在vee validate 3.0版本中验证十进制值

如何在tinyMCE 5.0.11版本中设置每行的最大字符数?

如何在Ubuntu 14.04中安装PgAdmin3版本1.20.0

如何在Visual Basic 2010中获取Windows 32或64位版本?

如何在Ubutu 15.10中将PHP 5.56版本降级到5.4

如何在cytoscape.js 2.6版本中添加自定义布局

如何在集线器中删除对旧统一版本的引用,以便重新安装该版本

无法从Typem 3.8.3版本的NodeJs 12中从Promise.allSettled获取值

如何在 android 中为下一版本的私有 RealmList<String> 数据列表进行 RealmMigration

如何在Angular 6中将ag-grid从18.1.0版本更新到20.1.0?

如何在UWP应用(C#或WinJS)中获取Windows 10版本(例如1809、1903、1909等)?

如何在C#中将Global.asax页面添加到asp.net 4.5版本中?

Delta Lake:如何在下一版本的delta表中不携带已删除的记录?