Tensorflow:使用对角线/超对角线创建输入的对角矩阵

用户名

我有以下代码:

import tensorflow as tf

N = 10
X = tf.ones([N,], dtype=tf.float64)
D = tf.linalg.diag(X, k=1, num_rows=N+1, num_cols=N+1)

print(D)

根据TF2文档,我希望返回第一个11x11张量,在第一个超对角线上插入X(即使没有可选num_rowsnum_cols参数)。但是,结果是

tf.Tensor(
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]], shape=(10, 10), dtype=float64)

有什么明显的我想念的东西吗?

sohv89

我可以告诉您为什么这行不通,但是我不知道解决方法是什么。可能引发github问题。

如果您在中查看array_ops.py它会进行兼容性检查,tf.compat.forward_compatible以查看兼容性窗口是否已过期。返回False(对于TF 2.0.0和2.1.0rc0)。由于这个原因,它执行

return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)

你可以看到,没有一个knum_rowsnum_cols是在调用中使用。因此,如果该tf.compat.forward_compatible检查失败,则该方法当前完全不考虑这些参数

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章