用cython比struct.pack更快

周杰伦

我正在努力做得比struct.pack

以打包整数的特定情况为例,通过对这个问题的回答,我可以在其中打包一个整数列表pack_ints.pyx

# cython: language_level=3, boundscheck=False
import cython

@cython.boundscheck(False)
@cython.wraparound(False)
def pack_ints(int_col):

    int_buf = bytearray(4*len(int_col))
    cdef int[::1] buf_view = memoryview(int_buf).cast('i')

    idx: int = 0
    for idx in range(len(int_col)):
        buf_view[idx] = int_col[idx]


    return int_buf

在ipython中使用以下测试代码:

from struct import pack 
import pyximport; pyximport.install(language_level=3) 
import pack_ints 

amount = 10**7 
ints = list(range(amount)) 

res1 = pack(f'{amount}i', *ints) 
res2 = pack_ints.pack_ints(ints) 
assert(res1 == res2) 

%timeit pack(f'{amount}i', *ints)  
%timeit pack_ints.pack_ints(ints)      

我得到:

304 ms ± 2.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
212 ms ± 6.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

我尝试将其键入int_bufarray('b'),但没有看到任何改进。

还有其他方法可以改善此操作,或者以其他方式使用cython来加快此操作的速度吗?

ad

该答案试图给出一个估计,并行化版本可以产生多大的提速。但是,由于此任务是受内存带宽限制的(Python整数对象至少占用32个字节,并且可以分散在内存中,因此会有很多缓存未命中),因此我们不应期望太多。

第一个问题是如何处理错误(元素不是整数或值太大)。我将遵循策略/简化:当对象

  • 不是整数,
  • 是负整数,
  • 或整数> = 2 ^ 30

它将强制转换为特殊数字(-1),表示出现问题。只允许使用非负整数<2^30使我的生活变得更轻松,因为我必须重新实现PyLong_AsLongAndOverflow以解决引发错误,否则检测溢出通常很麻烦(但是,请参见答案末尾的版本以获取更复杂的方法)。

Python整数对象的内存布局可以在这里找到

struct _longobject {
    PyObject_VAR_HEAD
    digit ob_digit[1];
};

成员ob_size/ macroPy_SIZE告诉我们在整数表示中使用了多少个30位数字(ob_size对于负整数为负数)。

因此,我的简单规则转化为以下C代码(我使用C而不是Cython,因为它是使用Python C-API的更简单/更自然的方式):

#include <Python.h>

// returns -1 if vv is not an integer,
//            negative, or > 2**30-1
int to_int(PyObject *vv){ 
   if (PyLong_Check(vv)) {
       PyLongObject * v = (PyLongObject *)vv;
       Py_ssize_t i = Py_SIZE(v);
       if(i==0){
           return 0;
       }
       if(i==1){//small enought for a digit
           return v->ob_digit[0];
       }
       //negative (i<0) or too big (i>1)
       return -1;
   }
   return -1;
}

现在给出一个列表,我们可以将其转换为int与以下使用omp的C函数并行-buffer:

void convert_list(PyListObject *lst, int *output){
    Py_ssize_t n = Py_SIZE(lst);
    PyObject **data = lst->ob_item;
    #pragma omp parallel for
    for(Py_ssize_t i=0; i<n; ++i){
        output[i] = to_int(data[i]);
    }
}

没什么可说的PyListObject--API用于并行访问列表的元素。之所以可以这样做,是因为to_int-函数中没有引用计数/竞赛条件

现在,将所有内容与Cython捆绑在一起:

%%cython -c=-fopenmp --link-args=-fopenmp
import cython

cdef extern from *:
    """
    #include <Python.h>

    int to_int(PyObject *vv){ 
       ... code
    }

    void convert_list(PyListObject *lst, int *output){
        ... code
    }
    """
    void convert_list(list lst, int *output)

@cython.boundscheck(False)
@cython.wraparound(False)
def pack_ints_ead(list int_col):
    cdef char[::1] int_buf = bytearray(4*len(int_col))
    convert_list(int_col, <int*>(&int_buf[0]))
    return int_buf.base

一个重要的细节是:convert_list 切勿虚假(因为不是)!Omp线程和Python线程(受GIL影响)完全不同。

在使用带有缓冲区协议的对象时,可以(但不是必须)释放GIL以进行omp操作-因为这些对象是通过缓冲区协议锁定的,不能从其他Python线程更改。Alist没有这种锁定机制,因此,如果释放了GIL,则可以在另一个线程中更改列表,并且我们所有的指针都可能无效。

现在来看时间(列表稍大):

amount = 5*10**7 
ints = list(range(amount)) 


%timeit pack(f'{amount}i', *ints)  
# 1.51 s ± 38.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit pack_ints_DavidW(ints) 
# 284 ms ± 3.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit pack_ints_ead(ints) 
# 177 ms ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

btw关闭并行化pack_ints_ead会导致209 ms的运行时间。

因此,鉴于ca的适度改进。33%,我会选择功能更强大的DavidW解决方案。


这是一种用略有不同的方式发送错误值的方式的实现:

  • 不是整数对象会导致-2147483648(即0x80000000)-32位整数可以存储的最小负值。
  • 整数>=2147483647(即>=0x7fffffff)将映射/存储为2147483647-32位整数可以存储的最大正数。
  • 整数<=-2147483647(即<=0x80000001)将映射为/存储为-2147483647
  • 所有其他整数都映射到其正确值。

主要优点是,它可以在较大范围的整数值中正常工作。该算法产生的运行时间与第一个简单版本的运行时间几乎相同(可能慢2-3%):

int to_int(PyObject *vv){ 
   if (PyLong_Check(vv)) {
       PyLongObject * v = (PyLongObject *)vv;
       Py_ssize_t i = Py_SIZE(v);
       int sign = i<0 ? -1 : 1;
       i = abs(i);
       if(i==0){
           return 0;
       }
       if(i==1){//small enought for a digit
           return sign*v->ob_digit[0];
       }
       if(i==2 && (v->ob_digit[1]>>1)==0){
           int add = (v->ob_digit[1]&1) << 30;
           return sign*(v->ob_digit[0]+add);
       }
       return sign * 0x7fffffff;
   }
   return 0x80000000;
}

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章