tensorflow tf.contrib.learn.SVM 如何重新加载训练好的模型并使用预测对新数据进行分类

鹰A

使用 tensorflow tf.contrib.learn.SVM 训练 svm 模型并保存模型;代码

feature_columns = [tf.contrib.layers.real_valued_column(feat) for feat in self.feature_columns]
model_dir = os.path.join(define.root, 'src', 'static_data', 'svm_model_dir')
model = svm.SVM(example_id_column='example_id',
                feature_columns=feature_columns,
                 model_dir=model_dir,
                            config=tf.contrib.learn.RunConfig(save_checkpoints_secs=10))
model.fit(input_fn=lambda: self.input_fun(self.df_train), steps=10000)
results = model.evaluate(input_fn=lambda: self.input_fun(self.df_test), steps=5, metrics=validation_metrics)
for key in sorted(results):
    print('% s: % s' % (key, results[key]))

hwo 重新加载经过训练的模型并使用 predict 对新数据进行分类?

杨利杰

训练时

您先调用svm.SVM(..., model_dir),然后再调用fit()andevaluate()方法。

测试时

您调用svm.SVM(..., model_dir)然后可以调用predict()方法。您的模型将在 中找到经过训练的模型model_dir并将加载经过训练的模型参数。

参考

TF 的第 3340 期

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

Tensorflow中的tf.contrib模块的目的是什么?

TensorFlow-tf.layers与tf.contrib.layers

使用Tensorflow构建SVM

无法在tensorflow会话中保存tf.contrib.learn宽模型和深模型并在TensorFlow Serving上提供

tf.contrib.learn.LinearRegressor为具有一个功能的数据构建意外不良模型

AttributeError:模块“ tensorflow.contrib.learn”没有属性“ TensorFlowDNNClassifier”

TensorFLow:tf.contrib.rnn模块对象不可调用

Tensorflow服务-使用tf.contrib.learn.Experiment训练的模型的“无可服务版本”消息

tensorflow.contrib.learn.ExportStrategy的示例

如何使用tf.contrib.learn.Experiment中的train_and_evaluate函数正确应用辍学

tf.contrib.learn load_csv_with_header在TensorFlow 1.1中不起作用

将tf.contrib.learn输入馈入DNNClassifier

使用tf.contrib.learn解决基本的物流分类器

Tensorflow:tf.nn.dropout和tf.contrib.rnn.DropoutWrapper有什么区别?

使用tf.contrib.learn.Experiment需要tf.train.replica_device_setter吗?

Tensorflow:tf.contrib弃用

使用scikit-learn训练数据时SVM多类分类停止

如何在tensorflow2.0中导入'tf.contrib.seq2seq.dynamic_decoder'?

Tensorflow 2中的tf.contrib.layers.fully_connected()吗?

如何将`tf.contrib.lookup.index_table_from_file`转换为Tensorflow v2

如何更新tensorflow以支持tf.contrib?

为什么Tensorflow tf.learn分类结果差异很大?

TensorFlow:如何在tf.contrib.metrics.streaming_mean_iou中获得total_cm

tensorflow0.12 的 contrib.learn.estimator()

使用 tf.contrib.learn.LinearClassifier 后如何保存和加载张量流模型?

线性回归,Tensorflow,非线性方程,tf.contrib.learn

使用“tf.contrib.factorization.KMeansClustering”

如何使用 tensorflow 函数 tf.contrib.legacy_seq2seq.sequence_loss_by_example 的“权重”参数?

如何为 tf.contrib.learn.DNNRegressor 选择参数