如何在张量流中停止 LSTMStateTuple 的梯度

Yu Li

我正在运行用于语言建模的基本 lstm 代码。但我不想做BPTT我想做类似的事情tf.stop_gradient(state)

with tf.variable_scope("RNN"):
  for time_step in range(N):
    if time_step > 0: tf.get_variable_scope().reuse_variables()
    (cell_output, state) = cell(inputs[:, time_step, :], state)

但是,stateLSTMStateTuple,所以我试过:

for lli in range(len(state)):
    print(state[lli].c, state[lli].h)
    state[lli].c = tf.stop_gradient(state[lli].c)
    state[lli].h = tf.stop_gradient(state[lli].h)

但我得到了一个AttributeError: can't set attribute错误:

File "/home/liyu-iri/IRRNNL/word-rnn/ptb/models/decoupling.py", line 182, in __init__
state[lli].c = tf.stop_gradient(state[lli].c)
AttributeError: can't set attribute

我也尝试使用tf.assign,但state[lli].c不是变量。

所以,我想知道我怎么能停止梯度LSTMStateTuple或者,我怎么能阻止 BPTT?我只想在单帧中做 BP。

非常感谢!

卢卡斯·凯撒

我认为这是一个纯 python 问题:LSTMStateTuple 只是一个 collections.namedtuple 并且 python 不允许你在那里分配元素(就像在其他元组中一样)。解决方案是创建一个全新的,例如像 instopped_state = LSTMStateTuple(tf.stop_gradient(old_tuple.c), tf.stop_gradient(old_tuple.h))然后使用它(或那些列表)作为状态。如果你坚持要替换现有的元组,我认为 namedtuple 有一个 _replace 方法,见这里,如old_tuple._replace(c=tf.stop_gradient(...)). 希望有帮助!

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章