当我们在 tf.keras.preprocessing.image_dataset_from_directory 对象上使用 .next() 或 .take() 时,我们是否会丢失数据?

阿敏巴

我创建了一个这样的数据生成器:

# Create test_dataset
test_dataset = \
  tf.keras.preprocessing.image_dataset_from_directory(directory=test_dir,
                                                      labels='inferred', 
                                                      label_mode='int', 
                                                      class_names=None,
                                                      seed=42, 
                                                      )
# Explore the first batch
for images, labels in test_dataset.take(1):
  print(labels)

它返回:

tf.Tensor([5 3 8 3 8 5 7 6 3 8 4 2 4 5 5 4 0 1 0 5 5 2 6 0 7 9 9 0 4 9 6 4], shape=(32,), dtype=int32)

如果我重新运行最后一部分如下:

for images, labels in test_dataset.take(1):
  print(labels)

它返回与第一次不同的东西:

tf.Tensor([0 6 2 5 5 7 5 2 7 4 0 5 0 4 6 5 8 7 7 3 5 1 1 9 5 2 6 6 6 6 2 0], shape=(32,), dtype=int32)

如果我重新创建test_dataset和探索它如下:

# Create test_dataset
test_dataset = \
  tf.keras.preprocessing.image_dataset_from_directory(directory=test_dir,
                                                      labels='inferred', 
                                                      label_mode='int', 
                                                      class_names=None,
                                                      seed=42, 
                                                      )
# Explore the first batch
for images, labels in test_dataset.take(1):
  print(labels)

它返回与第一次相同的结果:

tf.Tensor([5 3 8 3 8 5 7 6 3 8 4 2 4 5 5 4 0 1 0 5 5 2 6 0 7 9 9 0 4 9 6 4], shape=(32,), dtype=int32)

好吧,我得出的结论是,当我使用该take方法时,批处理会弹出并丢失,并且无法再用于建模和验证等。

我的问题是:

  • 我对吗?如果我跑,第一批会丢失吗test_dataset.take(1)
  • 如果上述问题的答案是肯定的,那么在尝试探索tf.keras.preprocessing.image_dataset_from_directory对象中的批次时,有什么方法可以不松懈吗?
弗雷特拉

这不是关于丢失批次。函数tf.keras.preprocessing.image_dataset_from_directory有一个参数shuffle,默认值为True也就是说,数据集在每次迭代时都被打乱。

如果我们深入研究源代码

  if shuffle:
    # Shuffle locally at each iteration
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)

正如您所看到的,它创建了一个tf.data具有shuffle方法对象Shuffle Methodreshuffle_each_iteration = True默认有一个参数使用 2nd take 方法,您将再次迭代数据集,导致它再次被打乱。

如果shuffle = False为数据集设置,则数据将按字母数字顺序排序,并且每次迭代时其顺序都不会改变。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章