# 如何比较持有 `numpy.ndarray` 的数据类的相等性（`bool(a == b)` 会引发 `ValueError`）？

nyanpasu64

``````python
import numpy as np

from dataclasses import dataclass


@dataclass
class Instr:
    """Minimal reproduction: a dataclass whose fields are numpy arrays.

    The ``__eq__`` generated by ``@dataclass`` compares the field values with
    ``==`` and then takes the truth value of the result.  For ndarrays ``==``
    is element-wise, so comparing instances whose arrays have more than one
    element can raise ``ValueError`` (ambiguous truth value).
    """

    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
# Expected to raise ValueError: the generated __eq__ compares the fields with
# == and needs the truth value of the resulting element-wise boolean array.
print(Instr(arr, arr) == Instr(arr2, arr2))
``````

`ValueError`：具有多个元素的数组的真值不明确。请使用 `a.any()` 或 `a.all()`。

`@dataclass` 生成的 `__eq__` 是通过 `eval()` 动态创建的，因此它的源码不会出现在堆栈跟踪中，也无法用 `inspect` 查看；但实际上它使用的是元组比较，而元组比较会调用 `bool(foo)`。

``````python
import dis
# Disassemble the @dataclass-generated __eq__ to see how it compares fields
# (its source cannot be displayed with inspect).
dis.dis(Instr.__eq__)
``````

``````text
  3          12 LOAD_FAST                0 (self)
             ...
             20 BUILD_TUPLE              2
             ...
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
``````

``````python
import numpy as np

def array_eq(arr1, arr2):
    """Return True if both arguments are ndarrays with equal shape and elements.

    Non-ndarray arguments always compare unequal.  Arrays containing NaN
    never compare equal, because NaN != NaN element-wise.  The result is
    always a plain Python ``bool`` (never ``numpy.bool_``).
    """
    if not (isinstance(arr1, np.ndarray) and isinstance(arr2, np.ndarray)):
        return False
    # Check shapes first: element-wise == on different shapes would broadcast
    # (or fail) instead of simply meaning "not equal".
    if arr1.shape != arr2.shape:
        return False
    # .all() yields numpy.bool_; convert so callers such as __eq__ get a bool.
    return bool((arr1 == arr2).all())

from dataclasses import dataclass


@dataclass(eq=False)
class Instr:
    """Dataclass holding two numpy arrays, with array-aware equality.

    ``eq=False`` stops ``@dataclass`` from generating its own ``__eq__``,
    which would need the truth value of an element-wise array comparison
    and raise ``ValueError``.  Because ``__eq__`` is defined here without a
    ``__hash__``, instances are unhashable.
    """

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        """Return True when both ``foo`` and ``bar`` are equal arrays.

        Returns ``NotImplemented`` for operands that are not ``Instr`` so
        Python can try the other operand's ``__eq__`` before falling back
        to identity.
        """
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)

## 编辑

``````python
import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Return True if *a* and *b* are equal, even when they are numpy arrays.

    Two ndarrays are equal when they have the same shape and every element
    compares equal.  Any other pair is compared with ``==`` and the result is
    coerced to ``bool``.  Values that cannot be compared, or whose comparison
    has an ambiguous truth value (e.g. an ndarray against a list), are
    treated as unequal.
    """
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        # Check shape first: == on mismatched shapes would broadcast or fail.
        return a.shape == b.shape and bool((a == b).all())
    try:
        return bool(a == b)
    except (TypeError, ValueError):
        # Returning NotImplemented here would be truthy, so callers that
        # test the result with all()/if would treat these values as equal.
        return False

def dc_eq(dc1, dc2) -> bool:
    """Check whether two dataclass instances holding numpy arrays are equal.

    Fields are compared pairwise with ``array_safe_eq``.  Fields declared
    with ``field(compare=False)`` are skipped, matching the ``__eq__`` that
    ``@dataclass`` would generate.

    Returns ``NotImplemented`` (rather than False) when the operands are of
    different classes, so that when used as ``return dc_eq(self, other)``
    inside ``__eq__`` Python can try the other operand's ``__eq__``.
    """
    # `fields` is not among the names imported at the top of this snippet.
    from dataclasses import fields

    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented
    # Read attributes directly instead of using astuple(): astuple
    # deep-copies every field value (including whole arrays) and flattens
    # nested dataclasses into tuples whose == cannot handle arrays.
    return all(
        array_safe_eq(getattr(dc1, f.name), getattr(dc2, f.name))
        for f in fields(dc1)
        if f.compare
    )

# usage
@dataclass(init=True, repr=True, eq=False)
class T:
    """Example dataclass mixing a scalar field with numpy-array fields.

    ``eq=False`` suppresses the generated ``__eq__``.  Equality is instead
    delegated to ``dc_eq``, the array-aware comparison helper.
    """

    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other: object):
        """Compare with *other* via ``dc_eq``, which copes with ndarray fields."""
        return dc_eq(self, other)
``````

0 条评论