如何使用 Tensorflow 功能 API 沿批量维度进行广播?

亚历克斯·特雷维西克

在一些应用中,比如 slot attention(这里在 Pytorch 中实现的),需要沿着批次维度进行广播。但是,我看不到如何使用功能 API 来做到这一点。例如,

import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, input.shape)

引发以下错误:

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)

因此,我求助于 subclassing tf.keras.Model,但我想将我的代码保留在功能 API 中。有谁知道如何做到这一点?

亚历克斯·特雷维西克

最后通过使用找到了答案tf.keras.backend.shape

const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, [tf.keras.backend.shape(input)[0], 4] )

# Shape of const is now (None, 4)

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

使用Tensorflow Keras功能API在专家模型混合中进行短路计算

使用tf.keras功能API时如何在Tensorflow / Keras中使用for循环?

Tensorflow 2:如何在 keras 功能 API 中使用密集层堆栈?

如何在功能性API中正确使用Tensorflow插件的指标?

Tensorflow Keras功能API模型可以训练Tensorflow变量吗?可以在功能性API模型中使用Tensorflow操作吗?

如何使用fetch进行通用API调用功能

TensorFlow 2:如何使用分散功能?

如何使用Tensorflow的Estimator API进行分布式培训?

无法使用TensorFlow Go API进行预测

如何沿批处理维度广播numpy索引?

如何使用新功能扩展log api

如何在TensorFlow中使用“ group_by_window”功能

如何使用功能 API 在 keras 中实现 Merge 功能

如何使用C API遍历Tensorflow图?

Tensorflow-如何将int32转换为字符串(使用Python API进行Tensorflow)

在TensorFlow中使用矩阵乘法功能

如何在功能组件中使用react-redux进行api调用

为什么我在使用 tensorflow 时收到警告/错误(使用功能性 API 和未实现的错误)

如何使用Spark SQL广播功能

功能性API中的Tensorflow Keras乙状结肠激活

在TensorFlow中的功能API和模型训练期间定义输入

使用Clojure进行多个API请求的功能方法

通过feature_columns使用数据集API将免费文本功能获取到Tensorflow固定估计器中

使用高级keras API在Tensorflow中进行量化感知训练

使用tensorflow API对检测到的对象进行计数

使用tensorflow对象检测API进行性别识别

Tensorflow冻结图Protobuf无法使用C API进行预测

使用TensorFlow-Keras API进行数据增强

使用对象检测 tensorflow API 对录制的视频进行预测