我正在寻找与tf.unsorted_segment_sum类似的函数,但我不想对这些段求和,我想将每个段都作为张量。
因此,例如,我有以下代码:(实际上,我的张量形状为(10000,63),并且段数为2500)
to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
[0.3, 0.2, 0.2, 0.6, 0.3],
[0.9, 0.8, 0.7, 0.6, 0.5],
[2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3
tf.unsorted_segment_sum(to_be_sliced, indices, num_segments)
输出将在这里
array([sum(row1+row3), row4, row2]
我正在寻找的是具有不同形状的3张量(也许是张量的列表),第一张包含原始的第一行和第三行(形状为(2,5)),第二张包含第四行(形状为(1 ,5)),第三个包含第二行,如下所示:
[array([[0.1, 0.2, 0.3, 0.4, 0.5],
[0.9, 0.8, 0.7, 0.6, 0.5]]),
array([[2.0, 2.0, 2.0, 2.0, 2.0]]),
array([[0.3, 0.2, 0.2, 0.6, 0.3]])]
提前致谢!
您可以这样做:
import tensorflow as tf
to_be_sliced = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5],
[0.3, 0.2, 0.2, 0.6, 0.3],
[0.9, 0.8, 0.7, 0.6, 0.5],
[2.0, 2.0, 2.0, 2.0, 2.0]])
indices = tf.constant([0, 2, 0, 1])
num_segments = 3
result = [tf.boolean_mask(to_be_sliced, tf.equal(indices, i)) for i in range(num_segments)]
with tf.Session() as sess:
print(*sess.run(result), sep='\n')
输出:
[[0.1 0.2 0.3 0.4 0.5]
[0.9 0.8 0.7 0.6 0.5]]
[[2. 2. 2. 2. 2.]]
[[0.3 0.2 0.2 0.6 0.3]]
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句