Running a basic distributed MNIST solver in TensorFlow

John Clive

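I am trying to train a model to predict MNIST classes with distributed TensorFlow. I have read the main distributed TensorFlow page, but I don't understand how to run the code to create a distributed TensorFlow model.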

For now, I am just using a linear classifier based on the code here.

How do I run this model? The page I got the code from says this command should be run in a terminal:

python dist_minst_softmax.py \
    --ps_hosts=localhost:2222,localhost:2223 \
    --worker_hosts=localhost:2224,localhost:2225 \
    --job_name=worker --task_index=1

If I run this command in a terminal, I get the following messages:

2018-04-23 11:02:35.034319: I tensorflow/core/distributed_runtime/master.cc:221] CreateSession still waiting for response from worker: /job:ps/replica:0/task:0
2018-04-23 11:02:35.034375: I tensorflow/core/distributed_runtime/master.cc:221] CreateSession still waiting for response from worker: /job:worker/replica:0/task:0

These messages repeat indefinitely. So how do I start the training process?
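Each log line names a task that the session is still waiting for. As a minimal sketch, assuming the log format shown above, the pending tasks can be extracted like this (`missing_tasks` is a hypothetical helper name, not part of TensorFlow):

```python
import re

# Matches the task path at the end of a "still waiting" log line.
WAITING_RE = re.compile(
    r"CreateSession still waiting for response from worker: "
    r"/job:(\w+)/replica:\d+/task:(\d+)"
)


def missing_tasks(log_text):
    """Return the sorted (job_name, task_index) pairs still being awaited."""
    return sorted({(job, int(idx)) for job, idx in WAITING_RE.findall(log_text)})


log = """\
2018-04-23 11:02:35.034319: I tensorflow/core/distributed_runtime/master.cc:221] CreateSession still waiting for response from worker: /job:ps/replica:0/task:0
2018-04-23 11:02:35.034375: I tensorflow/core/distributed_runtime/master.cc:221] CreateSession still waiting for response from worker: /job:worker/replica:0/task:0
"""

print(missing_tasks(log))  # [('ps', 0), ('worker', 0)]
```

Here the pending tasks are ps task 0 and worker task 0, neither of which was started by the command above.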

For reference, the model is defined as follows:

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    server = tf.train.Server(cluster, 
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":

        with tf.device(tf.train.replica_device_setter(
                           worker_device="/job:worker/task:%d" % FLAGS.task_index,
                           cluster=cluster)):

            global_step = tf.contrib.framework.get_or_create_global_step()

            with tf.name_scope("input"):
                mnist = input_data.read_data_sets("./input_data", one_hot=True)
                x = tf.placeholder(tf.float32, [None, 784], name="x-input")
                y_ = tf.placeholder(tf.float32, [None, 10], name="y-input")

            tf.set_random_seed(1)
            with tf.name_scope("weights"):
                W = tf.Variable(tf.zeros([784, 10]))
                b = tf.Variable(tf.zeros([10]))

            with tf.name_scope("model"):
                y = tf.matmul(x, W) + b

            with tf.name_scope("cross_entropy"):
                cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

            with tf.name_scope("train"):
                train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

            with tf.name_scope("acc"):
                init_op = tf.global_variables_initializer()
                correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Only the chief worker (task 0) initializes the variables;
        # the other workers wait until the model is initialized.
        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 global_step=global_step,
                                 init_op=init_op)

        with sv.prepare_or_wait_for_session(server.target) as sess:
            for _ in range(100):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

            print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    print(FLAGS, unparsed)
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
MPA

You should start your ps server first and then start your workers. An example with one ps and one worker:

python dist_minst_softmax.py \
       --ps_hosts=localhost:2222 \
       --worker_hosts=localhost:2223 \
       --job_name=ps --task_index=0

python dist_minst_softmax.py \
       --ps_hosts=localhost:2222 \
       --worker_hosts=localhost:2223 \
       --job_name=worker --task_index=0
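For the two-ps, two-worker cluster in the question, the same idea means four processes, one per task. The following is a minimal sketch of a launcher that builds one command per task, with the ps tasks first. `build_commands` and `launch_all` are hypothetical helper names introduced here for illustration; the script name and flags are the ones from the question.

```python
import subprocess
import sys


def build_commands(script, ps_hosts, worker_hosts):
    """Return one argv list per task in the cluster, ps tasks first."""
    common = [
        "--ps_hosts=" + ",".join(ps_hosts),
        "--worker_hosts=" + ",".join(worker_hosts),
    ]
    commands = []
    for job_name, hosts in (("ps", ps_hosts), ("worker", worker_hosts)):
        for task_index in range(len(hosts)):
            commands.append(
                [sys.executable, script] + common
                + ["--job_name=" + job_name, "--task_index=%d" % task_index]
            )
    return commands


def launch_all(commands):
    """Start every task as a separate process and return the handles."""
    return [subprocess.Popen(cmd) for cmd in commands]


if __name__ == "__main__":
    cmds = build_commands(
        "dist_minst_softmax.py",
        ps_hosts=["localhost:2222", "localhost:2223"],
        worker_hosts=["localhost:2224", "localhost:2225"],
    )
    for cmd in cmds:
        print(" ".join(cmd))
    # Uncomment to actually start the cluster. The ps processes block in
    # server.join() and never exit on their own, so stop them manually.
    # procs = launch_all(cmds)
```

Running each printed command in its own terminal works just as well; the point is that every task listed in `--ps_hosts` and `--worker_hosts` needs its own running process.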

I couldn't run the example code you provided because my machine doesn't have BLAS configured, but at least it tried to do something...
