Skip to content

Latest commit

 

History

History
95 lines (64 loc) · 3.78 KB

71.md

File metadata and controls

95 lines (64 loc) · 3.78 KB

6.1。高级扩展 API

原文: http://numba.pydata.org/numba-doc/latest/extending/high-level.html

此扩展 API 通过 numba.extending 模块公开。

6.1.1。实现功能

@overload装饰器允许您实现在 nopython 模式功能中使用的任意函数。用@overload修饰的函数在编译时使用函数运行时参数的 _ 类型 _ 调用。它应该返回一个 callable,表示给定类型的函数的 _ 实现 _。返回的实现由 Numba 编译,就像它是用@jit修饰的普通函数一样。 @jit的其他选项可以使用jit_options参数作为字典传递。

例如,让我们假装 Numba 不支持元组上的 len() 函数。以下是使用@overload实现它的方法:

from numba import types
from numba.extending import overload

@overload(len)
def tuple_len(seq):
   if isinstance(seq, types.BaseTuple):
       n = len(seq)
       def len_impl(seq):
           return n
       return len_impl

您可能想知道,如果使用除元组以外的其他内容调用 len() 会发生什么?如果用@overload修饰的函数不返回任何内容(即返回 None),则尝试其他定义直到成功。因此,多个库可能会使 len() 超载不同类型,而不会相互冲突。

6.1.2。实施方法

@overload_method装饰器类似地允许在 Numba 众所周知的类型上实现方法。以下示例在 Numpy 数组上实现 take() 方法:

@overload_method(types.Array, 'take')
def array_take(arr, indices):
   if isinstance(indices, types.Array):
       def take_impl(arr, indices):
           n = indices.shape[0]
           res = np.empty(n, arr.dtype)
           for i in range(n):
               res[i] = arr[indices[i]]
           return res
       return take_impl

6.1.3。实现属性

@overload_attribute装饰器允许在类型上实现数据属性(或属性)。只能读取属性;只有低级 API 支持可写属性。

以下示例在 Numpy 数组上实现 nbytes 属性:

@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
   def get(arr):
       return arr.size * arr.itemsize
   return get

6.1.4。导入 Cython 函数

函数get_cython_function_address获取 Cython 扩展模块中 C 函数的地址。该地址可用于通过 ctypes.CFUNCTYPE() 回调访问 C 函数,从而允许在 Numba jitted 函数中使用 C 函数。例如,假设您有文件foo.pyx

from libc.math cimport exp

cdef api double myexp(double x):
    return exp(x)

您可以通过以下方式从 Numba 访问myexp

import ctypes
from numba.extending import get_cython_function_address

addr = get_cython_function_address("foo", "myexp")
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
myexp = functype(addr)

函数myexp现在可以在 jitted 函数中使用,例如:

@njit
def double_myexp(x):
    return 2*myexp(x)

需要注意的是,如果您的函数使用 Cython 的融合类型,那么函数的名称将被破坏。要找出函数的错位名称,可以检查扩展模块的__pyx_capi__属性。