原文: http://numba.pydata.org/numba-doc/latest/developer/rewrites.html
本节介绍中间表示(IR)重写,以及如何使用它们来实现优化。
正如前面在“阶段 6a:重写类型 IR ”中所讨论的那样,重写 Numba IR 允许我们执行在较低 LLVM 级别执行将更加困难的优化。与 Numba 类型和降低子系统类似,重写子系统是用户可扩展的。这种可扩展性为 Numba 提供了支持各种特定领域优化(DSO)的可能性。
其余小节详细介绍了实现重写的机制,使用重写注册表注册重写,并提供添加新重写的示例,以及数组表达式优化传递的内部结构。最后,我们回顾一下示例中公开的一些用例,并回顾开发人员应该注意的任何问题。
重写通道有一个简单的 match()
和 apply()
接口。匹配和重写之间的划分遵循如何在声明性域特定语言(DSL)中定义术语重写。在这样的 DSL 中,可以按如下方式编写重写:
<match> => <replacement>
<match>
和<replacement>
符号表示 IR 术语表达式,其中左侧表示要匹配的模式,右侧表示 IR 术语构造函数以在匹配时构建。每当重写与 IR 模式匹配时,左侧的任何自由变量都绑定在自定义环境中。应用时,重写使用模式匹配环境绑定右侧的任何自由变量。
由于 Python 不常用于声明性容量,因此 Numba 使用对象状态来处理匹配和应用程序步骤之间的信息传输。
7.6.2.1。 Rewrite
基类
class Rewrite
Rewrite
类只是为 Numba 重写定义了一个抽象基类。开发人员应将重写定义为此基类型的子类,重载 match()
和 apply()
方法。
pipeline
pipeline 属性包含当前正在编译正在考虑重写的函数的numba.compiler.Pipeline
实例。
__init__(self, pipeline, *args, **kws)
用于重写的基础构造函数只是将其参数存储到同名的属性中。除非用于调试或测试,否则重写应仅由RewriteRegistry.apply()
方法中的RewriteRegistry
构造,并且构造接口应保持稳定(尽管管道通常包含所知的所有内容)。
match(self, block, typemap, callmap)
match()
方法除 self 外还有四个参数:
- func_ir :这是被重写函数的
numba.ir.FunctionIR
实例。 - _ 块 _:这是
numba.ir.Block
的一个实例。匹配方法应迭代numba.ir.Block.body
成员中包含的指令。 - typemap :这是一个 Python
dict
实例,它从 IR 中的符号名称(表示为字符串)映射到 Numba 类型。 - callmap :这是另一个
dict
实例从呼叫(表示为numba.ir.Expr
实例)映射到它们对应的呼叫站点类型签名,表示为numba.typing.templates.Signature
实例。
match()
方法应返回 bool
结果。 True
结果应表明发现了一个或多个匹配, apply()
方法将返回新的替换numba.ir.Block
实例。 False
结果应表明未找到匹配项,随后对 apply()
的调用将返回未定义或无效的结果。
apply(self)
只有在成功调用 match()
后才能调用 apply()
方法。此方法不需要 self 以外的其他参数,并应返回替换numba.ir.Block
实例。
如上所述,调用 apply()
的行为是未定义的,除非 match()
已被调用并返回 True
。
7.6.2.2。子类 Rewrite
在进入任何 Rewrite
子类必须具有的重载方法的期望之前,让我们回过头来回顾一下这里发生的事情。通过提供可扩展的编译器,Numba 向用户定义的代码生成器开放,这些代码生成器可能不完整,或者更糟,不正确。当代码生成器出错时,它可能导致异常的程序行为或提前终止。用户定义的重写增加了一个新的复杂程度,因为它们不仅必须生成正确的代码,而且它们生成的代码应该确保编译器不会卡在匹配/应用循环中。编译器的非终止将直接导致用户函数调用的非终止。
有几种方法可以帮助确保重写终止:
- _ 键入 _:重写通常应尝试分解复合类型,并避免编写新类型。如果重写与特定类型匹配,则将表达式类型更改为较低级别类型将确保在应用重写后它们不会长期匹配。
- _ 特殊指令 _:重写可以合成自定义运算符或在目标 IR 中使用特殊函数。此技术再次生成不再位于原始匹配域内的代码,并且重写将终止。
在下面的“案例研究:数组表达式”小节中,我们将看到数组表达式重写器如何使用这两种技术。
7.6.2.3。重载 Rewrite.match()
每个重写开发人员都应尽量让 match()
的实现尽快返回 False
值。 Numba 是一个即时编译器,添加编译时间最终会增加用户的运行时间。当重写为给定块返回 False
时,注册表将不再使用该重写处理该块,并且编译器更接近于继续降低。
这种对及时性的需求必须与收集必要信息以进行重写相匹配。重写开发人员应该很乐意为其子类添加动态属性,然后使用这些新属性来指导替换基本块的构造。
7.6.2.4。重载 Rewrite.apply()
apply()
方法应返回替换numba.ir.Block
实例以替换包含重写匹配的基本块。如上所述,由 apply()
方法构建的 IR 应该保留用户代码的语义,但也试图避免为相同的重写或重写集生成另一个匹配。
如果要在重写过程中包含重写,则应将其注册到重写注册表。 numba.rewrites
模块提供抽象基类和类装饰器,用于挂钩到 Numba 重写子系统。以下说明了新重写的存根定义:
from numba import rewrites
@rewrites.register_rewrite
class MyRewrite(rewrites.Rewrite):
def match(self, block, typemap, calltypes):
raise NotImplementedError("FIXME")
def apply(self):
raise NotImplementedError("FIXME")
开发人员应注意,使用如上所示的类装饰器将在导入时注册重写。在编译开始之前,开发人员有责任确保加载扩展。
本小节更深入地介绍了数组表达式重写器。数组表达式重写器及其大部分支持功能可在numba.npyufunc.array_exprs
模块中找到。重写传递本身在RewriteArrayExprs
类中实现。除了重写器之外,array_exprs
模块还包括用于降低数组表达式的函数_lower_array_expr()
。整体优化过程如下:
RewriteArrayExprs.match()
:重写过程查找两个或多个形成数组表达式的数组操作。RewriteArrayExprs.apply()
:一旦找到数组表达式,重写器就会用一种新的 IR 表达式arrayexpr
替换各个数组操作。numba.npyufunc.array_exprs._lower_array_expr()
:在降低期间,只要找到arrayexpr
IR 表达式,代码生成器就会调用_lower_array_expr()
。
下面给出关于优化的每个步骤的更多细节。
数组表达式优化传递首先查找数组操作,包括对支持的ufunc
和用户定义的 DUFunc
的调用。 Numba IR 遵循静态单指派(SSA)语言的约定,这意味着搜索数组运算符首先要查找赋值指令。
当重写传递调用RewriteArrayExprs.match()
方法时,它首先检查它是否可以简单地拒绝基本块。如果方法确定块是匹配的候选者,则它在重写对象中设置以下状态变量:
- crnt_block :当前匹配的基本块。
- typemap :匹配函数的 typemap 。
- _ 匹配 _:引用数组表达式的变量名列表。
- array_assigns :从赋值变量名称到定义给定变量的实际赋值指令的映射。
- const_assigns :从赋值变量名到定义常量变量的常量值表达式的映射。
此时,匹配方法迭代迭代输入基本块中的赋值指令。对于每个赋值指令,匹配器会查找以下两种情况之一:
- 数组操作:如果赋值指令的右侧是表达式,并且该表达式的结果是数组类型,则匹配器检查表达式是已知数组操作还是对通用函数的调用。如果找到数组运算符,匹配器将左侧变量名和整个指令存储在 array_assigns 成员中。最后,匹配器测试以查看阵列操作的任何操作数是否也已被识别为其他阵列操作的目标。如果一个或多个操作数也是数组操作的目标,则匹配器还会将左侧变量名称附加到 _ 匹配 _ 成员。
- 常量:常量(甚至标量)可以是数组操作的操作数。不必担心数组表达式的常量分离,匹配器在 const_assigns 成员中存储常量名称和值。
匹配方法的结束只是检查非空 _ 匹配 _ 列表,如果有一个或多个匹配则返回 True
, False
当 _ 匹配 _ 为空时。
当RewriteArrayExprs.match()
找到一个或匹配的数组表达式时,重写过程将调用RewriteArrayExprs.apply()
。 apply 方法有两次通过。第一遍迭代所找到的匹配,并根据旧基本块中的指令构建映射到新基本块中的新指令。第二遍迭代旧基本块中的指令,复制未被重写改变的指令,以及替换或删除第一遍所识别的指令。
RewriteArrayExprs._handle_matches()
实现重写的代码生成部分的第一遍。对于每个匹配,此方法构建一个特殊的 IR 表达式,其中包含数组表达式的表达式树。为了计算表达式树的叶子,_handle_matches()
方法迭代所识别的根操作的操作数。如果操作数是另一个数组操作,则将其转换为表达式子树。如果操作数是常量,_handle_matches()
将复制常量值。否则,操作数被标记为由数组表达式使用。当该方法构建数组表达式节点时,它构建从旧指令到新指令( replace_map )的映射,以及可能已移动的变量集( used_vars )和变量应该完全删除( dead_vars )。这三个数据结构返回到调用RewriteArrayExprs.apply()
方法。
RewriteArrayExprs.apply()
方法的剩余部分迭代旧基本块中的指令。对于每条指令,此方法根据RewriteArrayExprs._handle_matches()
的结果替换,删除或复制该指令。以下列表描述了优化如何处理单个指令:
- 当指令是赋值时,
apply()
检查它是否在替换指令映射中。当在指令映射中找到赋值指令时,apply()
必须检查替换指令是否也在替换映射中。优化器继续此检查,直到它到达None
值或不在替换映射中的指令。删除具有None
的替换的指令。更换非None
替换的说明。不在替换映射中的赋值指令被附加到新的基本块而不进行任何更改。 - 当指令是删除指令时,重写检查它是否删除可能仍被稍后的数组表达式使用的变量,或者它是否删除了死变量。删除已使用变量的指令被添加到延迟删除指令的映射中,
apply()
使用它来移动它们超过该变量的任何用途。循环复制删除非死变量的指令,并忽略死变量的删除指令(有效地从基本块中删除它们)。 - 所有其他指令都附加到新的基本块。
最后,apply()
方法返回用于降低的新基本块。
如果我们只是重写,那么编译器的降低阶段就会失败,抱怨它不知道如何降低arrayexpr
操作。我们首先在编译器实例化RewriteArrayExprs
类时将降低函数挂钩到目标上下文中。只要遇到arrayexr
运算符,此挂钩就会导致降低传递调用_lower_array_expr()
。
这个功能有两个步骤:
- 合成实现数组表达式的 Python 函数:这个新的 Python 函数本质上就像一个 Numpy
ufunc
,在广播数组参数中的标量值上返回表达式的结果。降低功能通过从数组表达式树转换为 Python AST 来实现这一点。 - 将合成 Python 函数编译到内核中:此时,降低函数依赖于现有代码来降低 ufunc 和 DUFunc 内核,在定义如何降低对合成函数的调用之后调用
numba.targets.numpyimpl.numpy_ufunc_kernel()
。
最终结果类似于 Numba 对象模式中的循环提升。
我们已经了解了如何在 Numba 中实现重写,从界面开始,以实际优化结束。本节的重点是:
- 在编写好的插件时,匹配器应该尽快获得 go / no-go 结果。
- 重写应用程序部分的计算成本可能更高,但仍应生成不会在编译器中导致无限循环的代码。
- 我们使用对象状态将任何匹配结果传递给重写应用程序传递。