我知道为了将元素添加到集合中,它必须是可哈希的,而numpy数组似乎不是。这引起了一些问题,因为我有以下代码:
fill_set = set()
for i in list_of_np_1D:
vecs = i + np_2D
for j in range(N):
tup = tuple(vecs[j,:])
fill_set.add(tup)
# list_of_np_1D is a list of 1D numpy arrays
# np_2D is a 2D numpy array
# np_2D could also be converted to a list of 1D arrays if it helped.
我需要使它运行得更快,并且将近50%的运行时间用于将2D numpy数组的切片转换为元组,以便可以将它们添加到集合中。
所以我一直在尝试找出以下
谢谢你的帮助!
首先创建一些数据:
import numpy as np
np.random.seed(1)
list_of_np_1D = np.random.randint(0, 5, size=(500, 6))
np_2D = np.random.randint(0, 5, size=(20, 6))
运行您的代码:
%%time
fill_set = set()
for i in list_of_np_1D:
vecs = i + np_2D
for v in vecs:
tup = tuple(v)
fill_set.add(tup)
res1 = np.array(list(fill_set))
输出:
CPU times: user 161 ms, sys: 2 ms, total: 163 ms
Wall time: 167 ms
这是一个加速版本,它使用广播.view()
方法将dtype转换为字符串,然后调用set()
将字符串转换回array:
%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
stype = "S%d" % (r.itemsize * np_2D.shape[1])
fill_set2 = set(r.ravel().view(stype).tolist())
res2 = np.zeros(len(fill_set2), dtype=stype)
res2[:] = list(fill_set2)
res2 = res2.view(r.dtype).reshape(-1, np_2D.shape[1])
输出:
CPU times: user 13 ms, sys: 1 ms, total: 14 ms
Wall time: 14.6 ms
要检查结果:
np.all(res1[np.lexsort(res1.T), :] == res2[np.lexsort(res2.T), :])
您也可以使用lexsort()
删除重复的数据:
%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
r = r.reshape(-1, r.shape[-1])
r = r[np.lexsort(r.T)]
idx = np.where(np.all(np.diff(r, axis=0) == 0, axis=1))[0] + 1
res3 = np.delete(r, idx, axis=0)
输出:
CPU times: user 13 ms, sys: 3 ms, total: 16 ms
Wall time: 16.1 ms
要检查结果:
np.all(res1[np.lexsort(res1.T), :] == res3)
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句