如何比较持有numpy.ndarray(bool(a == b)引发ValueError)的数据类的相等性?

nyanpasu64

如果我创建一个包含Numpy ndarray的Python数据类,我将无法再使用自动生成的数据__eq__

import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))

ValueError:具有多个元素的数组的真值不明确。使用a.any()或a.all()

这是因为ndarray.__eq__ 有时返回ndarray真值,通过比较a[0]b[0],等等等等到2。这是相当复杂的,不直观的时间越长,而实际上只提高在阵列不同的形状,或者有错误不同的值或其他东西。

如何安全比较@dataclass持有Numpy数组的ES?


@dataclass的实现__eq__是使用生成的eval()它的来源从stacktrace中丢失了,并且无法使用进行查看inspect,但实际上是在使用元组比较(调用bool(foo))。

import dis
dis.dis(Instr.__eq__)

摘抄:

  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
切尔

解决方案是放入您自己的__eq__方法并进行设置,eq=False以使数据类不会生成自己的方法(尽管检查文档不是最后一步是必要的,但无论如何我还是觉得很明确)。

import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)

编辑

通用数据类的通用快速解决方案,其中某些值是numpy数组,而另一些则不是

import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Check if a and b are equal, even if they are numpy arrays"""
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a.shape == b.shape and (a == b).all()
    try:
        return a == b
    except TypeError:
        return NotImplemented

def dc_eq(dc1, dc2) -> bool:
   """checks if two dataclasses which hold numpy arrays are equal"""
   if dc1 is dc2:
        return True
   if dc1.__class__ is not dc2.__class__:
       return NotImplmeneted  # better than False
   t1 = astuple(dc1)
   t2 = astuple(dc2)
   return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

# usage
@dataclass(eq=False)
class T:

   a: int
   b: np.ndarray
   c: np.ndarray

   def __eq__(self, other):
        return dc_eq(self, other)

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

TypeError:无法比较类型'ndarray(dtype = bool)'和'str'

C-Numpy:如何根据现有数据创建字符串的固定宽度ndarray

如何为持有std :: vector的类实现operator []

如何设计用于持有Dominion卡的数据库?

如何获得对boost :: any持有的数据的const引用?

将数据框列与numpy ndarray比较并更新数据框中的值

MXNET - 无效的数据类型 '<type 'numpy.ndarray'>',应该是 NDArray、numpy.ndarray,

Numpy ndArray:访问每个类的输入特征

常量持有$ this类名

numpy ndarray到熊猫数据框

标准化numpy ndarray数据

从numpy ndarray提取特定数据

如何重写基类中的相等性?

如何获取有关随机智能合约地址的数据。比如创建日期,链,持有者(持有多少钱包)

如何扩展numpy.ndarray

您如何访问由共享指针持有的类方法?

如何从 Compose 状态持有者类中观察 ViewModel LiveData 的变化?

如何检查元素是否实际持有一些数据

如何在核心数据中创建持有人/交易实体?

比较伪类实例时的clojure相等性

关于重塑和视图的 Numpy ndarray 数据所有权问题

C#:在检查相等性时如何评估对象的multipe比较器类?

用numpy ndarray索引numpy ndarray

如何比较功能类型的相等性?

从数据帧为keras数据生成numpy-ndarray

如何同时比较多个numpy数组是否相等?

使用 Boost Python Numpy ndarray 作为类成员变量

在PHP中持有函数的类变量

如何将numpy ndarray列表放入pandas数据框的列中?