原文: http://numba.pydata.org/numba-doc/latest/extending/high-level.html
此扩展 API 通过 numba.extending
模块公开。
@overload
装饰器允许您实现在 nopython 模式功能中使用的任意函数。用@overload
修饰的函数在编译时使用函数运行时参数的 _ 类型 _ 调用。它应该返回一个 callable,表示给定类型的函数的 _ 实现 _。返回的实现由 Numba 编译,就像它是用@jit
修饰的普通函数一样。 @jit
的其他选项可以使用jit_options
参数作为字典传递。
例如,让我们假装 Numba 不支持元组上的 len()
函数。以下是使用@overload
实现它的方法:
from numba import types
from numba.extending import overload
@overload(len)
def tuple_len(seq):
if isinstance(seq, types.BaseTuple):
n = len(seq)
def len_impl(seq):
return n
return len_impl
您可能想知道,如果使用除元组以外的其他内容调用 len()
会发生什么?如果用@overload
修饰的函数不返回任何内容(即返回 None),则尝试其他定义直到成功。因此,多个库可能会使 len()
超载不同类型,而不会相互冲突。
@overload_method
装饰器类似地允许在 Numba 众所周知的类型上实现方法。以下示例在 Numpy 数组上实现 take()
方法:
@overload_method(types.Array, 'take')
def array_take(arr, indices):
if isinstance(indices, types.Array):
def take_impl(arr, indices):
n = indices.shape[0]
res = np.empty(n, arr.dtype)
for i in range(n):
res[i] = arr[indices[i]]
return res
return take_impl
@overload_attribute
装饰器允许在类型上实现数据属性(或属性)。只能读取属性;只有低级 API 支持可写属性。
以下示例在 Numpy 数组上实现 nbytes
属性:
@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
def get(arr):
return arr.size * arr.itemsize
return get
函数get_cython_function_address
获取 Cython 扩展模块中 C 函数的地址。该地址可用于通过 ctypes.CFUNCTYPE()
回调访问 C 函数,从而允许在 Numba jitted 函数中使用 C 函数。例如,假设您有文件foo.pyx
:
from libc.math cimport exp
cdef api double myexp(double x):
return exp(x)
您可以通过以下方式从 Numba 访问myexp
:
import ctypes
from numba.extending import get_cython_function_address
addr = get_cython_function_address("foo", "myexp")
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
myexp = functype(addr)
函数myexp
现在可以在 jitted 函数中使用,例如:
@njit
def double_myexp(x):
return 2*myexp(x)
需要注意的是,如果您的函数使用 Cython 的融合类型,那么函数的名称将被破坏。要找出函数的错位名称,可以检查扩展模块的__pyx_capi__
属性。