当我们想使用分布式TensorFlow时,我们将使用以下命令创建参数服务器
tf.train.Server.join()
但是,除了杀死进程之外,我找不到关闭服务器的任何方法。join()的TensorFlow文档是
Blocks until the server has shut down.
This method currently blocks forever.
这让我很困扰,因为我想创建许多用于计算的服务器,并在一切完成后将其关闭。
是否有可能的解决方案。
谢谢
您可以通过使用session.run(dequeue_op)
来使参数服务器进程按需终止,而不是server.join()
在您希望该进程终止时使用,让另一个进程将某些东西排队到该队列中。
因此,对于k
参数服务器分片,您可以创建k
具有唯一shared_name
属性的队列,然后尝试dequeue
从该队列开始。当您要关闭服务器时,可以遍历所有队列并将enqueue
令牌循环到每个队列上。这将导致session.run
取消阻止,并且Python进程将运行到最后并退出,从而使服务器停机。
下面是一个包含两个分片的独立示例:https : //gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e
(有关多工作者/多分片的示例,请参见https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca)
import subprocess
import tensorflow as tf
import time
import sys
flags = tf.flags
flags.DEFINE_string("port1", "12222", "port of worker1")
flags.DEFINE_string("port2", "12223", "port of worker2")
flags.DEFINE_string("task", "", "internal use")
FLAGS = flags.FLAGS
# setup local cluster from flags
host = "127.0.0.1:"
cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]}
clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()
if __name__=='__main__':
if not FLAGS.task: # start servers and run client
# launch distributed service
def runcmd(cmd): subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT)
runcmd("python %s --task=0"%(sys.argv[0]))
runcmd("python %s --task=1"%(sys.argv[0]))
time.sleep(1)
# bring down distributed service
sess = tf.Session("grpc://"+host+FLAGS.port1)
queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")
with tf.device("/job:worker/task:0"):
add_op0 = tf.add(tf.ones(()), tf.ones(()))
with tf.device("/job:worker/task:1"):
add_op1 = tf.add(tf.ones(()), tf.ones(()))
print("Running computation on server 0")
print(sess.run(add_op0))
print("Running computation on server 1")
print(sess.run(add_op1))
print("Bringing down server 0")
sess.run(queue0.enqueue(1))
print("Bringing down server 1")
sess.run(queue1.enqueue(1))
else: # Launch TensorFlow server
server = tf.train.Server(clusterspec, config=None,
job_name="worker",
task_index=int(FLAGS.task))
print("Starting server "+FLAGS.task)
sess = tf.Session(server.target)
queue = tf.FIFOQueue(1, tf.int32, shared_name="queue"+FLAGS.task)
sess.run(queue.dequeue())
print("Terminating server"+FLAGS.task)
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句