过滤Numpy的数组数组

超级询问

使用numpy的ndarray将数据预处理到神经网络。它基本上包含用于传感器数据的几个固定长度的数组。因此,例如:

>>> type(arr)
<class 'numpy.ndarray'>

>>> arr.shape
(400,1,5,4)

>>> arr
 [
  [[ 9.4 -3.7 -5.2  3.8]
   [ 2.8  1.4 -1.7  3.4]
   [ 0.0  0.0  0.0  0.0]
   [ 0.0  0.0  0.0  0.0]
   [ 0.0  0.0  0.0  0.0]]
  ..
  [[ 0.0 -1.0  2.1  0.0]
   [ 3.0  2.8 -3.0  8.2]
   [ 7.5  1.7 -3.8  2.6]
   [ 0.0  0.0  0.0  0.0]
   [ 0.0  0.0  0.0  0.0]]
 ]

每个嵌套数组都是shape的(1, 5,4)目的是通过此过程arr并仅选择至少前三行为非零的那些数组(尽管单个条目可以为零,但不能整行)。

因此,在上面给出的示例中,应该删除第一个嵌套数组,因为只有2个第一行非零,而我们需要3个及以上。

安德烈亚斯·K。

您可以使用以下技巧:

mask = arr[:,:,:3].any(axis=3).all(axis=2)
arr_filtered = arr[mask]

快速说明:要保留一个嵌套数组,它至少应具有3个第一行(因此我们只需要查看arr[:,:,:3]),以便所有它们(因此.all(axis=2)位于末尾)至少具有一个非零条目(因此.any(axis=3))。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章