使用tf.contrib.learn解决基本的物流分类器

罗恩·贡捷(Rowan Gontier)

我正在Tensorflow中学习有关tf.contrib.learn的信息,并且正在使用自制练习。练习是将x1和x2作为输入,将三个区域分类如下,标签为三角形/圆形/十字形:在此处输入图片说明

我的代码能够拟合数据并对其进行评估。但是,我似乎无法获得预期的效果。代码如下。有任何想法吗?

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

from six.moves import urllib

import pandas as pd
import tensorflow as tf
import numpy as np

FLAGS = None

myImportedDatax1_np = np.array([[.1],[.1],[.2],[.2],[.4],[.4],[.5],[.5],[.1],[.1],[.2],[.2]],dtype=float)
myImportedDatax2_np = np.array([[.1],[.2],[.1],[.2],[.1],[.2],[.1],[.2],[.4],[.5],[.4],[.5]],dtype=float)
combined_Imported_Data_x = np.append(myImportedDatax1_np, myImportedDatax2_np, axis=1)
myImportedDatay_np = np.array([[0],[0],[0],[0],[1],[1],[1],[1],[2],[2],[2],[2]],dtype=int)

def build_estimator(model_dir, model_type):
  x1 = tf.contrib.layers.real_valued_column("x1")
  x2 = tf.contrib.layers.real_valued_column("x2")

  wide_columns = [x1, x2]
  m = tf.contrib.learn.LinearClassifier(model_dir=model_dir, feature_columns=wide_columns)
  return m

def input_fn(input_batch, output_batch):
  inputs = {"x1": tf.constant(input_batch[:,0]), "x2": tf.constant(input_batch[:,1])}
  label = tf.constant(output_batch)
  print(inputs)
  print(label)
  print(input_batch)
  # Returns the feature columns and the label.
  return inputs, label

def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
  model_dir = tempfile.mkdtemp() if not model_dir else model_dir
  print("model directory = %s" % model_dir)
  m = build_estimator(model_dir, model_type)
  m.fit(input_fn=lambda: input_fn(combined_Imported_Data_x, myImportedDatay_np), steps=train_steps)
  results = m.evaluate(input_fn=lambda: input_fn(np.array([[.4, .1],[.4, .2]], dtype=float), np.array([[0], [0]], dtype=int)), steps=1)
  for key in sorted(results):
    print("%s: %s" % (key, results[key]))
  predictions = list(m.predict(input_fn=({"x1": tf.constant([[.1]]),"x2": tf.constant([[.1]])})))
 # print(predictions)

def main(_):
  train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps,
                 FLAGS.train_data, FLAGS.test_data)

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--model_dir",
      type=str,
      default="",
      help="Base directory for output models."
  )
  parser.add_argument(
      "--model_type",
      type=str,
      default="wide_n_deep",
      help="Valid model types: {'wide', 'deep', 'wide_n_deep'}."
  )
  parser.add_argument(
      "--train_steps",
      type=int,
      default=200,
      help="Number of training steps."
  )
  parser.add_argument(
      "--train_data",
      type=str,
      default="",
      help="Path to the training data."
  )
  parser.add_argument(
      "--test_data",
      type=str,
      default="",
      help="Path to the test data."
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
amo-ej1

要解决此具体问题,您可以添加以下输入函数,该输入函数与现有输入函数相似,不同之处在于它返回None作为元组中的第二个元素

def input_fn_predict():
  inputs = {"x1": tf.constant([0.1]), "x2": tf.constant([0.2])}
  print(inputs)
  return inputs, None

在下一阶段,您可以使用以下命令调用它:

predictions = list(m.predict(input_fn=lambda: input_fn_predict()))

如果您注释掉了打印件,则应该可以使用。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

使用tf.contrib.learn.Experiment需要tf.train.replica_device_setter吗?

tensorflow tf.contrib.learn.SVM 如何重新加载训练好的模型并使用预测对新数据进行分类

Tensorflow服务-使用tf.contrib.learn.Experiment训练的模型的“无可服务版本”消息

如何使用tf.contrib.learn.Experiment中的train_and_evaluate函数正确应用辍学

使用 tf.contrib.learn.LinearClassifier 后如何保存和加载张量流模型?

使用“tf.contrib.factorization.KMeansClustering”

将tf.contrib.learn输入馈入DNNClassifier

如何为 tf.contrib.learn.DNNRegressor 选择参数

如何在MNIST上使用tf.contrib.model_pruning?

如何使用tf.contrib.keras.optimizers.Adamax?

远离tf.contrib.learn:具有专用评估程序的分布式培训

tf.contrib.learn.io.numpy_input_fn 参数是什么意思?

tf.contrib.learn load_csv_with_header在TensorFlow 1.1中不起作用

线性回归,Tensorflow,非线性方程,tf.contrib.learn

tf.contrib.learn.RunConfig(save_checkpoints_secs = 1))引发意外的关键字TypeError

使用tf.contrib.data.parallel_interleave并行化tf.from_generator

数据集API,迭代器和tf.contrib.data.rejection_resample

如何更新tensorflow以支持tf.contrib?

Tensorflow:tf.contrib弃用

无法在tensorflow会话中保存tf.contrib.learn宽模型和深模型并在TensorFlow Serving上提供

tf.contrib.learn.LinearRegressor为具有一个功能的数据构建意外不良模型

TensorFlow-tf.layers与tf.contrib.layers

tf.reshape与tf.contrib.layers.flatten

tf.nn.relu与tf.contrib.layers.relu?

为什么要使用“ slim = tf.contrib.slim”而不是通常导入苗条?

当我使用tf.contrib.rnn.LayerNormBasicLSTMCell时,类型错误“张量”对象不可迭代

如何使用tf.contrib.opt.ScipyOptimizerInterface获取损失函数历史

如何使用 tensorflow 函数 tf.contrib.legacy_seq2seq.sequence_loss_by_example 的“权重”参数?

如何使用grunt http服务器访问区域设置json文件(grunt-contrib-connect)