检查两个numpy数组是否相同

交流电

假设我有一堆数组,包括xy,我想检查它们是否相等。通常,我可以使用np.all(x == y)(除非我现在忽略了一些笨拙的案例)。

但是,这会评估的整个数组(x == y),通常不需要。我的阵列是真的大了,我有很多的人,和两个数组相等的概率很小,因此在所有的可能性,我真的只需要评估的一个非常小的一部分(x == y)之前all函数可以返回False,所以这对我来说不是最佳解决方案。

我试过结合使用内置all函数itertools.izipall(val1==val2 for val1,val2 in itertools.izip(x, y))

不过,这似乎只是在两个数组的情况下慢得多相等的,即总体而言,它使用过的STIL不值得np.all我猜想是因为内置all的通用性。并且np.all不适用于发电机。

有没有一种方法可以更快地完成我想做的事情?

我知道这个问题类似于先前提出的问题(例如,比较两个numpy数组是否相等,逐个元素),但是它们特别不涉及提前终止的情况。

风神

在以numpy本机实现之前,您可以编写自己的函数并使用numba对其进行jit编译

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def arrays_equal(a, b):
    if a.shape != b.shape:
        return False
    for ai, bi in zip(a.flat, b.flat):
        if ai != bi:
            return False
    return True


a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)


%timeit np.all(a==b)  # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a)  # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b)  # 100000 loops, best of 3: 691 ns per loop

最坏的情况下的性能(数组相等)等效np.all于这种情况,并且在尽早停止编译功能的情况下,其性能可能会大大超出np.all性能

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章