tf.contrib.data.TFRecordDataset无法从* .tfrecord读取

弗雷德里克·埃利施伯格(Frederik Elischberger)

在创建和加载.tfrecord文件的上下文中,我遇到以下问题:

生成dataset.tfrecord文件

文件夹/ Batch_manager / assets包含一些* .tif图像,这些图像用于生成dataset.tfrecord文件:

def _save_as_tfrecord(self, path, name):
    self.__filename = os.path.join(path, name + '.tfrecord')
    writer = tf.python_io.TFRecordWriter(self.__filename)
    print('Writing', self.__filename)
    for index, img in enumerate(self.load(get_iterator=True, n_images=1)):
        img = img[0]
        image_raw = img.tostring()
        rows = img.shape[0]
        cols = img.shape[1]
        try:
            depth = img.shape[2]
        except IndexError:
            depth = 1
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': self._int64_feature(rows), 
            'width': self._int64_feature(cols), 
            'depth': self._int64_feature(depth), 
            'label': self._int64_feature(int(self.target[index])), 
            'image_raw': self._bytes_feature(image_raw)
                }))
        writer.write(example.SerializeToString())
    writer.close()

从dataset.tfrecord文件中读取

接下来,我尝试使用路径指向数据集.tfrecord文件的位置读取该文件:

def dataset_input_fn(self, path):
    dataset = tf.contrib.data.TFRecordDataset(path)

    def parser(record):
        keys_to_features = {
            "height": tf.FixedLenFeature((), tf.int64, default_value=""),
            "width": tf.FixedLenFeature((), tf.int64, default_value=""),
            "depth": tf.FixedLenFeature((), tf.int64, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64, default_value=""),
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        }
        print(record)
        features = tf.parse_single_example(record, features=keys_to_features)
        print(features)
        label = features['label']
        height = features['height']
        width = features['width']
        depth = features['depth']
        image = tf.decode_raw(features['image_raw'], tf.float32) 
        image = tf.reshape(image, [height, width, -1])
        label = tf.cast(features["label"], tf.int32)

        return {"image_raw": image, "height": height, "width": width, "depth":depth, "label":label}

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(32)
    iterator = dataset.make_one_shot_iterator()

    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features = iterator.get_next()

    return Features

错误信息:

TypeError:预期为int64,取而代之的是类型为“ str”的“”。

这段代码有什么问题?我成功地验证了dataset.tfrecord实际上包含正确的图像和元数据!

弗雷德里克·埃利施伯格(Frederik Elischberger)

发生错误是因为我复制并粘贴了此示例,该示例将所有键值对的值设置为空字符串,由引起default_value=""从所有内容中删除该内容可以tf.FixedLenFeature解决此问题。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

Tensorflow中的tf.contrib模块的目的是什么?

TensorFlow-tf.layers与tf.contrib.layers

使用tf.contrib.data.parallel_interleave并行化tf.from_generator

tf.reshape与tf.contrib.layers.flatten

无法在tensorflow会话中保存tf.contrib.learn宽模型和深模型并在TensorFlow Serving上提供

tf.nn.relu与tf.contrib.layers.relu?

TensorFlow:tf.contrib.data API中的“无法通过值捕获有状态节点”

tf.contrib.data.Dataset似乎不支持SparseTensor

数据集API,迭代器和tf.contrib.data.rejection_resample

像队列一样将数据输入tf.contrib.data.Dataset

tf.contrib.data.Dataset以随机播放重复,是否注意到时代结束,时代混合?

TFRecord vs TF.image?

无法迭代tf.data.Dataset

用tf.estimator初始化tf.contrib.data.Iterator

无法通过优化学习tf.contrib.distributions.MultivariateNormalDiag的参数

使用tf.data API,TFRecordDataset和序列化时遇到问题

预取(tf.data)和prefetch_to_device(tf.contrib)之间的区别

如何使用tf.data API读取(解码)tfrecords

tensorflow-具有多个TFRecord文件+ tf.contrib.data.sliding_window_batch()的输入管道

Tensorflow:tf.contrib弃用

tf.contrib.metrics.f1_score无法导入

从.tfrecord到tf.data.Dataset到tf.keras.model.fit

使用tf.data API加载tfrecord数据并训练模型,结果没有改变

如何更新tensorflow以支持tf.contrib?

无法使用 tf-slim 框架从 mnist tfrecord 获取图像

`tf.train.shuffle_batch` 在 TensorFlow 中读取 `TFRecord` 文件时崩溃

Tensorflow:tf.contrib.data 中 dataset.map() 的类型不兼容

使用“tf.contrib.factorization.KMeansClustering”

tf.data API 读取 TFRecord 文件