在一些应用中,比如 slot attention(这里是在 Pytorch 中实现的),需要沿着批次维度进行广播。但是,我看不到如何使用功能 API 来做到这一点。例如,
import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))
const = tf.broadcast_to(const, input.shape)
引发以下错误:
ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)
因此,我求助于 subclassing tf.keras.Model
,但我想将我的代码保留在功能 API 中。有谁知道如何做到这一点?
最后通过使用找到了答案tf.keras.backend.shape
:
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))
const = tf.broadcast_to(const, [tf.keras.backend.shape(input)[0], 4] )
# Shape of const is now (None, 4)
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句