我正在使用tensorflow.examples.tutorials.mnist训练具有5个隐藏层的nn。
这是我训练神经网络的方式:
with tf.Session() as sess:
init.run()
for epoch in range(n_epochs):
for iteration in range(len(mnist.test.labels)//batch_size):
X_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
我想训练神经网络以仅识别从0到4的数字。我将logits层更改为具有5个输出。
我如何过滤TensorFlow提供的mnist数据集,以便仅获取0到4之间的数字?
有很多方法可以做到这一点。其中之一就是当你解救自己时X_batch, y_batch = mnist.train.next_batch(batch_size)
。在这一步,您y_batch
将获得有关数字值的信息(数字值或数字的一个整数)。
您遍历批处理中的示例,并检查该数字是否为您所关心的数字。如果是,则将其添加到中cleaned_up_batch
。效率不是很高,但是会起作用。
回答评论:
它效率不高,因为您可能需要多次过滤相同的数据。我认为这不会成为问题,因为MNIST非常小。通常的方法是只过滤一次,创建一个新的数据集并编写自己的函数以从中获取下一批(实际上非常容易,因为您只是从数据集中随机选择了k个元素)
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句