如何优化这些嵌套循环?

亚历克斯

我试图计算一种颜色(或最接近17种颜色之一)出现在图像像素中的次数(给定为300x300x3 np数组,浮点值[0,1])。我已经写了这个,但是看起来效率很低:

for w in range(300):
    for h in range(300):
        colordistance = float('inf')
        colorindex = 0
        for c in range(17):
            r1 = color[c, 0]
            g1 = color[c, 1]
            b1 = color[c, 2]
            r2 = img[w, h, 0]
            g2 = img[w, h, 1]
            b2 = img[w, h, 2]
            distance = math.sqrt(
                ((r2-r1)*0.3)**2 + ((g2-g1)*0.59)**2 + ((b2-b1)*0.11)**2)
            if distance < colordistance:
                colordistance = distance
                colorindex = c
        colorcounters[colorindex] = colorcounters[colorindex] + 1

有什么方法可以提高效率吗?我已经在对外部循环使用多处理了。

c2huc2hu

您提到您正在使用numpy,因此应避免尽可能地迭代。我的矢量化实现速度提高了约40倍。我对您的代码进行了一些更改,以便它们可以使用相同的数组,从而可以验证正确性。这可能会影响速度。

import numpy as np
import time
import math

num_images = 1
hw = 300 # height and width
nc = 17  # number of colors
img = np.random.random((num_images, hw, hw, 1, 3))
colors = np.random.random((1, 1, 1, nc, 3))

## NUMPY IMPLEMENTATION
t = time.time()

dist_sq = np.sum(((img - colors) * [0.3, 0.59, 0.11]) ** 2, axis=4)  # calculate (distance * coefficients) ** 2
largest_color = np.argmin(dist_sq, axis=3)  # find the minimum
color_counters = np.unique(largest_color, return_counts=True) # count
print(color_counters)
# should return an object like [[1, 2, 3, ... 17], [count1, count2, count3, ...]]

print("took {} s".format(time.time() - t))

## REFERENCE IMPLEMENTATION
t = time.time()
colorcounters = [0 for i in range(nc)]
for i in range(num_images):
    for h in range(hw):
        for w in range(hw):
            colordistance = float('inf')
            colorindex = 0
            for c in range(nc):
                r1 = colors[0, 0, 0, c, 0]
                g1 = colors[0, 0, 0, c, 1]
                b1 = colors[0, 0, 0, c, 2]
                r2 = img[i, w, h, 0, 0]
                g2 = img[i, w, h, 0, 1]
                b2 = img[i, w, h, 0, 2]

                # dist_sq
                distance = math.sqrt(((r2-r1)*0.3)**2 + ((g2-g1)*0.59)**2 + ((b2-b1)*0.11)**2)  # not using sqrt gives a 14% improvement

                # largest_color
                if distance < colordistance:
                    colordistance = distance
                    colorindex = c
            # color_counters
            colorcounters[colorindex] = colorcounters[colorindex] + 1
print(colorcounters)
print("took {} s".format(time.time() - t))

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章