从tensorflow中的`tfds`数据中获取验证数据的任何方法?

汉密尔顿

我对从tensorflow_datasetstensorflow中创建验证数据集感到好奇,因为我不清楚如何拆分来自的训练数据tfds我知道使用train_test_splitfrom创建验证数据很容易sklearn,但是我不确定应该如何对from进行验证tfda有谁知道这样做的可能方法吗?有什么想法吗?

试图

我可以按照以下方式进行验证:

from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split

(X_tr, y_tr), (X_test, y_test) = mnist.load_data()
X_train, X_val, y_train, y_val = train_test_split(X_tr, y_tr, test_size=0.1, stratify=y_tr)

但是我们应该如何从中获取验证数据:

import tensorflow_datasets as tfds
mnst= tfds.load('mnist')
train_data = mnst['train']
test_data = mnst['test']

由此我们如何制作验证数据?有什么想法吗?谢谢!

弗雷特拉

加载数据时,您可以指定拆分,如下所示:

(train_data, validation_data) = tfds.load(
    'mnist',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
)

拆分可以指定为'train''test'从文档

所有DatasetBuilder都公开了定义为拆分的各种数据子集(例如:训练,测试)

也可以通过一种简单的方法来检查它们:

(training_set, validation_set, test_set) = tfds.load(
    'mnist',
    split=['train[:80%]', 'train[80%:]', 'test'],
    as_supervised=True,
)

将它们转换为numpy数组并检查其形状,将仅显示一个用于演示,其他遵循相同的逻辑,我们使用以下方法进行tfds迭代as_numpy

test_set = tfds.as_numpy(test_set)

x_test = [] # will be containing numpy arrays, I defined them as a list to check.
y_test = []


for features_labels in test_set: # features_labels is a tuple 
                                 # containing features and labels here.
    x_test.append(features_labels[0])
    y_test.append(features_labels[1])
    
x_test = np.array(x_test)    
y_test = np.array(y_test)

现在您可以检查形状:

x_test.shape
>>> (10000, 28, 28, 1)

y_test.shape
>>> (10000,)


x_val.shape
>>> (12000, 28, 28, 1)


x_train.shape
>>> (48000, 28, 28, 1)

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

Django 1.8迁移:从不再具有模型的数据库表中获取数据的任何方法是什么?

如果帐户不在身份验证数据库中,是否有任何方法可以禁用帐户创建?

在MS-Access中,有任何方法可以绕过/覆盖/轻松地临时禁用INSERT SQL语句的数据验证

在SQL批量复制期间检查数据库中数据是否存在的任何方法

任何方法来条件化 jsonencoded 数据中的变量?

ELMAH是否提供任何方法来读取数据库中的错误?

在httpservlet请求中获取路径参数的任何方法

从Flutter中的DocumentSnapshot获取DocumentReference的任何方法?

任何简单的方法来获取Imagenet数据集以在Tensorflow中训练自定义模型?

tfds.load()之后如何在TensorFlow 2.0中应用数据增强

Java:从字节数组中获取ZipFile的任何方法(或带有直接getEntry方法的任何方法)?

将字符串数据从 JSON API 转换为 js 中的 int 的任何方法或解决方案?

是否有任何方法可以使用.net驱动程序从cassandra数据库中忽略实体的列

是否有任何方法或选项来选择大写/小写字段,例如pyspark的数据框中的ABC / abc?

熊猫中是否有任何方法可以将数据帧从天转换为默认的d / m / y格式?

是否有任何方法可以根据多次出现的标志条件提取熊猫数据框中的块

核心数据:以任何方式获取多个实体?

javafx 或 fxml 中是否有任何方法来验证 TextField 字符的长度?

使用zip4j库在Java中获取压缩方法的任何方法吗?

是否有任何方法可以获取在FlagsAttributed枚举中不是类型的所有Flag?

是否有任何方法可以在Java代码中获取Zipkin的TraceId

在Tensorflow中获取数据集的长度

xslt以任何方式从模板中获取调用模板的名称

在Django模型中,是否有任何方法可以通过非pk列从数据库的多个表中选择属性,而只需一次点击DB?

mongodb 中是否有任何方法可以创建仅对数据库的特定视图具有“读取”访问权限的“用户”?

方法不会覆盖其超类中的任何方法

在我的类中的任何方法之前调用特殊方法

验证错误:数据与“ oneOf”中的任何架构都不匹配

Laravel 中是否有任何方法可以自动向经过身份验证的用户添加模型?