我正在Keras中使用ImageDataGenerator(),我想获取整个测试数据的标签。
当前,我正在使用以下代码来完成此任务:
test_batches = ImageDataGenerator().flow_from_directory(...)
test_labels = []
for i in range(0,3):
test_labels.extend(np.array(test_batches[i][1]))
但是,此代码仅起作用,因为我知道我总共有150张图像,并且我的批处理大小定义为50。
此外使用:
imgs, labels = next(test_batches)
如本主题类似文章中所建议,仅返回一批标签,而不返回整个数据集。因此,我想知道是否有比我上面使用的方法更有效的方法。
好吧-当您知道时,batch_size
您可以从flow_from_directory
对象获取许多图像:
test_batches = ImageDataGenerator().flow_from_directory(.., batch_size=n)
number_of_examples = len(test_batches.filenames)
number_of_generator_calls = math.ceil(number_of_examples / (1.0 * n))
# 1.0 above is to skip integer division
test_labels = []
for i in range(0,int(number_of_generator_calls)):
test_labels.extend(np.array(test_batches[i][1]))
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句