Writing a Conv2D-like operation in TensorFlow

Oliver Urbann

For my CNN I need a layer that performs an operation like Conv2D, but with a subtraction instead of a multiplication. I already have working code where inputs[0] is the full image and inputs[1] is a tensor with a shape like (None, 5, 3, 512). I have implemented a custom layer in Keras, and this is part of its call():

    ...
    lines = []
    for x in range(0, x_max, x_step):
        line_parts = []
        for y in range(0, y_max, y_step):
            # Subtract the small tensor from the current window of the image
            line_parts.append(inputs[0][:, x:x + x_step, y:y + y_step] - inputs[1])
        # Stitch the windows of this row back together along the width axis
        line = K.concatenate(line_parts, 2)
        lines.append(line)
    # Stitch the rows back together along the height axis
    img = K.concatenate(lines, 1)
    ...

However, as x_step and y_step get smaller, this loop grows too large. How should a loop like this be implemented using the low-level parts of TensorFlow, without having to write it in C++ or CUDA?

I tried slicing inputs[0] and then using tf.map_fn, but I could not find an operation that slices out all of the needed smaller tensors at once without a loop. I also tried tf.while_loop, but ran into problems creating an empty tf.Variable with shape [None, ...], and I do not see a way to build the final tensor with tf.concat inside the loop.

Thanks in advance!

Jdehesa

I think you can do it like this:

import tensorflow as tf

def subtract_patches(imgs, patches):
    # Get dimensions
    img_shape = tf.shape(imgs)
    img_h = img_shape[1]
    img_w = img_shape[2]
    img_c = img_shape[3]
    patch_shape = tf.shape(patches)
    patch_h = patch_shape[1]
    patch_w = patch_shape[2]
    # Reshape image into patches
    imgs = tf.reshape(imgs, [-1, img_h // patch_h, patch_h, img_w // patch_w, patch_w, img_c])
    # Do subtraction
    out = imgs - tf.expand_dims(tf.expand_dims(patches, 1), 3)
    # Reshape result back
    out = tf.reshape(out, img_shape)
    return out

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    imgs = tf.reshape(tf.range(2 * 6 * 8 * 2, dtype=tf.float32), (2, 6, 8, 2))
    patches = 0.1 * tf.reshape(tf.range(2 * 3 * 4 * 2, dtype=tf.float32), (2, 3, 4, 2))
    out = subtract_patches(imgs, patches)
    print(sess.run(out))
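
The trick is that the reshape splits each spatial dimension into a (number of tiles, tile size) pair: with the test shapes above, imgs of shape (2, 6, 8, 2) becomes (2, 2, 3, 2, 4, 2), while the two tf.expand_dims calls turn patches of shape (2, 3, 4, 2) into (2, 1, 3, 1, 4, 2), so broadcasting subtracts the same patch from every tile of the 2×2 grid in a single vectorized op. Note that this assumes the image height and width are exact multiples of the patch height and width; otherwise the reshape fails.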

Output:

[[[[  0.    0.9]
   [  1.8   2.7]
   [  3.6   4.5]
   [  5.4   6.3]
   [  8.    8.9]
   [  9.8  10.7]
   [ 11.6  12.5]
   [ 13.4  14.3]]

  [[ 15.2  16.1]
   [ 17.   17.9]
   [ 18.8  19.7]
   [ 20.6  21.5]
   [ 23.2  24.1]
   [ 25.   25.9]
   [ 26.8  27.7]
   [ 28.6  29.5]]

  [[ 30.4  31.3]
   [ 32.2  33.1]
   [ 34.   34.9]
   [ 35.8  36.7]
   [ 38.4  39.3]
   [ 40.2  41.1]
   [ 42.   42.9]
   [ 43.8  44.7]]

  [[ 48.   48.9]
   [ 49.8  50.7]
   [ 51.6  52.5]
   [ 53.4  54.3]
   [ 56.   56.9]
   [ 57.8  58.7]
   [ 59.6  60.5]
   [ 61.4  62.3]]

  [[ 63.2  64.1]
   [ 65.   65.9]
   [ 66.8  67.7]
   [ 68.6  69.5]
   [ 71.2  72.1]
   [ 73.   73.9]
   [ 74.8  75.7]
   [ 76.6  77.5]]

  [[ 78.4  79.3]
   [ 80.2  81.1]
   [ 82.   82.9]
   [ 83.8  84.7]
   [ 86.4  87.3]
   [ 88.2  89.1]
   [ 90.   90.9]
   [ 91.8  92.7]]]


 [[[ 93.6  94.5]
   [ 95.4  96.3]
   [ 97.2  98.1]
   [ 99.   99.9]
   [101.6 102.5]
   [103.4 104.3]
   [105.2 106.1]
   [107.  107.9]]

  [[108.8 109.7]
   [110.6 111.5]
   [112.4 113.3]
   [114.2 115.1]
   [116.8 117.7]
   [118.6 119.5]
   [120.4 121.3]
   [122.2 123.1]]

  [[124.  124.9]
   [125.8 126.7]
   [127.6 128.5]
   [129.4 130.3]
   [132.  132.9]
   [133.8 134.7]
   [135.6 136.5]
   [137.4 138.3]]

  [[141.6 142.5]
   [143.4 144.3]
   [145.2 146.1]
   [147.  147.9]
   [149.6 150.5]
   [151.4 152.3]
   [153.2 154.1]
   [155.  155.9]]

  [[156.8 157.7]
   [158.6 159.5]
   [160.4 161.3]
   [162.2 163.1]
   [164.8 165.7]
   [166.6 167.5]
   [168.4 169.3]
   [170.2 171.1]]

  [[172.  172.9]
   [173.8 174.7]
   [175.6 176.5]
   [177.4 178.3]
   [180.  180.9]
   [181.8 182.7]
   [183.6 184.5]
   [185.4 186.3]]]]
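
If you need this inside a model, here is a minimal sketch of wrapping subtract_patches in a custom Keras layer like the one described in the question. The class name SubtractPatches is my own; the original answer does not include a layer wrapper:

import tensorflow as tf

# Hypothetical wrapper (an assumption, not part of the original answer):
# a custom Keras layer whose call() delegates to subtract_patches() above.
class SubtractPatches(tf.keras.layers.Layer):
    def call(self, inputs):
        imgs, patches = inputs  # inputs[0]: full image, inputs[1]: patch tensor
        return subtract_patches(imgs, patches)

    def compute_output_shape(self, input_shape):
        # The output has the same shape as the full-image input
        return input_shape[0]

# Usage: out = SubtractPatches()([imgs, patches])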
