torch.export机制

Export IR 是一种用于编译器的中间表示(IR),与 MLIR 和 TorchScript 有相似之处。它专门设计用来表达 PyTorch 程序的语义。Export IR 主要以一系列操作的streamlined list形式表示计算,对动态性如控制流的支持有限。

Export IR 建立在 torch.fx.Graph 之上。换句话说,所有的 Export IR 图都是有效的 FX 图,如果使用标准的 FX 语义解释,Export IR 可以被准确地解释。这意味着,导出的图可以通过标准的 FX 代码生成转换为有效的 Python 程序。

要创建一个 Export IR 图,可以使用一个前端,该前端通过trace-specializing mechanism准确捕获 PyTorch 程序。然后可以通过后端优化和执行生成的 Export IR。目前,这可以通过 torch.export.export() 来完成。

本文档将涵盖的关键概念包括:

  • ExportedProgram:包含 Export IR 程序的数据结构
  • Graph:由一系列节点组成。
  • Nodes:代表操作、控制流和存储在此节点上的元数据。
  • Values 由节点产生和消费。
  • Types 与值和节点相关联。
  • 值的大小和内存布局也被定义。

本文档将主要关注 Export IR 与 FX 在严格性方面的区别,而忽略它们之间的相似之处。

ExportedProgram

顶层的 Export IR 构造是 torch.export.ExportedProgram 类。它将 PyTorch 模型的计算图(通常是 torch.nn.Module)与模型的参数或权重捆绑在一起。

torch.export.ExportedProgram 类的一些显著属性(notable attributes)包括:

  • graph_module (torch.fx.GraphModule):包含 PyTorch 模型 flattened computational graph 的数据结构。可以通过 ExportedProgram.graph 直接访问图。
  • graph_signature (torch.export.ExportGraphSignature):The graph signature, which specifies the parameters and buffer names used and mutated within the graph. Instead of storing parameters and buffers as attributes of the graph, they are lifted as inputs to the graph. The graph_signature is utilized to keep track of additional information on these parameters and buffers.
    图签名,指定图中使用和改变的参数和缓冲区名称。参数和缓冲区不作为图的属性存储,而是提升为图的输入。graph_signature 用于追踪这些参数和缓冲区的额外信息。
  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含参数和buffers的数据结构。
  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):对于导出具有数据依赖行为的程序,每个节点上的元数据将包含符号形状(如 s0, i0)。此属性将符号形状映射到其下限/上限范围。

Graph

Export IR Graph 是以 DAG(有向无环图)形式表示的 PyTorch 程序。图中的每个节点代表特定的计算或操作,图的边由节点之间的引用组成。

我们可以将 Graph 视为具有以下架构:

class Graph:
  nodes: List[Node]

实际上,Export IR 的图以 torch.fx.Graph Python 类实现。

Export IR 图包含以下节点(节点将在下一节中更详细地描述):

  • 0 个或多个操作类型为 placeholder 的节点
  • 0 个或多个操作类型为 call_function 的节点
  • 正好 1 个操作类型为 output 的节点

推论: 最小的有效图将是一个节点。即 nodes 永不为空。

定义: 图的 placeholder 节点集合代表 Graph 或 GraphModule 的输入。图的输出节点代表 Graph 或 GraphModule 的输出

示例:

from torch import nn
class MyModule(nn.Module):
    def forward(self, x, y):
        return x + y

mod = torch._export.export(MyModule())
print(mod.graph)

graph():
    %l_x_ : [num_users=1] = placeholder[target=l_x_]
    %l_y_ : [num_users=1] = placeholder[target=l_y_]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%l_x_, %l_y_), kwargs = {})
    return (add,)

以上是图的文本表示,每一行都是一个节点。

节点 Node

节点代表特定的计算或操作,使用 Python 中的 torch.fx.Node 类来表示。节点之间的边以 Node 类的 args 属性中对其他节点的直接引用的形式表示。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,如操作符调用、占位符(即输入)、条件和循环。

节点具有以下架构:

class Node:
    name: str # 节点名称
    op_name: str # 操作类型
    # 下面字段的解释取决于 op_name
    target: [str|Callable]
    args: List[object]
    kwargs: Dict[str, object]
    meta: Dict[str, object]

FX 文本格式

如上面的示例所示,每行的格式如下:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

这种格式以紧凑的形式捕获了 Node 类中的所有内容,除了 meta

具体来说:

  • <name> 是节点的名称,如 node.name 中所示。
  • <op_name>node.op 字段,必须是以下之一:<call_function>, <placeholder>, <get_attr>, 或 <output>
  • <target>node.target 中的目标。这个字段的含义取决于 op_name
  • args1, … args 4…node.args 元组中列出的内容。如果列表中的值是一个 torch.fx.Node,则会特别用 % 表示。

例如,对加法操作符的调用将显示为:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x, %y 是两个名为 x 和 y 的其他节点。值得注意的是,字符串 torch.op.aten.add.Tensor 代表实际存储在目标字段中的可调用对象,而不仅仅是它的字符串名称。

这种文本格式的最后一行是:

return [add]

这是一个 op_name = output 的节点,表示我们返回这一个元素。

call_function

call_function 节点代表对操作符的调用。

定义

  • Functional(函数式): 我们说一个可调用对象是“函数式的”,如果它满足以下所有要求:
    • 非变异性:操作符不会改变其输入的值(对于张量,这包括元数据和数据)。
    • 无副作用:操作符不会改变外部可见的状态,比如改变模块参数的值。
  • Operator(操作符): 是一个具有预定义架构的函数式可调用对象。这样的操作符的例子包括 ATen 的函数式操作符。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与原生 FX call_function 的区别

  1. 在 FX 图中,call_function 可以引用任何可调用对象,在 Export IR 中,我们将其限制为只选取 ATen 操作符、自定义操作符和控制流操作符的子集。
  2. 在 Export IR 中,常量参数将嵌入图中。
  3. 在 FX 图中,get_attr 节点可以代表读取图模块中存储的任何属性。然而,在 Export IR 中,这被限制为只读取子模块,因为所有参数/buffers将作为输入传递给图模块。

元数据

Node.meta 是附加在每个 FX 节点上的字典。然而,FX 规范并没有指定哪些元数据可以或将会存在。Export IR 提供了更强的约束,特别是所有 call_function 节点将保证仅具有以下元数据字段:

  • node.meta["stack_trace"] 是一个包含引用原始 Python 源代码的 Python 堆栈跟踪的字符串。一个堆栈跟踪的示例看起来像这样:
File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
File "helper_utility.py", line 89, in dummy_helper
    return y + 1
  • node.meta["val"] 描述运行操作后的输出。它可以是 , , List[Union[FakeTensor, SymInt]]None 类型。
  • node.meta["nn_module_stack"] 描述节点来自哪个 torch.nn.Module,如果它是从 torch.nn.Module 调用的话。例如,如果一个包含 addmm 操作的节点是从 torch.nn.Linear 模块中的 torch.nn.Sequential 模块调用的,nn_module_stack 将如下所示:
{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 
 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
  • node.meta["source_fn_stack"] 包含节点调用前的 torch 函数或叶子 torch.nn.Module 类。例如,一个从 torch.nn.Linear 模块调用包含 addmm 操作的节点,其 source_fn 中将包含 torch.nn.Linear,而一个从 torch.nn.functional.Linear 模块调用包含 addmm 操作的节点,其 source_fn 中将包含 torch.nn.functional.Linear

placeholder

占位符代表图的输入。它的语义与 FX 中完全相同。占位符节点必须是图节点列表中的前 N 个节点。N 可以为零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

目标字段是一个字符串,是输入的名称。

args,如果非空,应该是大小为 1,表示此输入的默认值。

元数据

占位符节点也有 meta['val'],就像 call_function 节点一样。在这种情况下,val 字段表示图预期接收到的此输入参数的输入形状/数据类型。

output

输出调用代表函数中的返回语句;因此它终止当前图。只有一个输出节点,并且它将始终是图的最后一个节点。

在 FX 中的表示

output[](args = (%something, …))

这与 torch.fx 中的语义完全相同。args 表示要返回的节点。

元数据

输出节点具有与 call_function 节点相同的元数据。

get_attr

get_attr 节点代表从封装的 torch.fx.GraphModule 中读取子模块。与 torch.fx.symbolic_trace() 中的普通 FX 图不同,get_attr 节点用于从顶级 torch.fx.GraphModule 中读取属性,如参数和缓冲区,参数和缓冲区作为输入传递给图模块,并存储在顶级 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

图:

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

这行代码 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 读取子模块 true_graph_0,其中包含 sin 操作符。

参考

SymInt

SymInt 是一个既可以是字面整数也可以是表示整数的符号(在 Python 中由 sympy.Symbol 类表示)的对象。当 SymInt 是一个符号时,它描述了在编译时对图未知的整数类型的变量,即其值只在运行时才知道。

FakeTensor

FakeTensor 是一个包含张量元数据的对象。可以视为具有以下元数据的对象。

class FakeTensor:
    size: List[SymInt]
    dtype: torch.dtype
    device: torch.device
    dim_order: List[int] # 这个还不存在

FakeTensor 的 size 字段是一个整数或 SymInt 的列表。如果存在 SymInt,这意味着这个张量有动态形状。如果存在整数,则假定该张量将具有确切的静态形状。TensorMeta 的秩永远不是动态的。dtype 字段代表该节点输出的数据类型。Edge IR 中没有隐式类型提升。FakeTensor 中没有跨度。

换句话说:

  • 如果 node.target 中的操作符返回一个张量,则 node.meta['val'] 是描述该张量的 FakeTensor。
  • 如果 node.target 中的操作符返回一个张量的 n 元组,则 node.meta['val'] 是描述每个张量的 n 元组 FakeTensors。
  • 如果 node.target 中的操作符返回在编译时已知的 int/float/标量,则 node.meta['val'] 为 None。
  • 如果 node.target 中的操作符返回在编译时未知的 int/float/标量,则 node.meta['val'] 是 SymInt 类型。

例如:

  • aten::add 返回一个张量;因此其规范将是一个 FakeTensor,具有该操作符返回的张量的数据类型和大小。
  • aten::sym_size 返回一个整数;因此其 val 将是一个 SymInt,因为其值只在运行时才可用。
  • max_pool2d_with_indexes 返回一个 (Tensor, Tensor) 元组;因此规范也将是两个 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述返回值的第一个元素等。

Python 代码:

def add_one(x):
    return torch.ops.aten(x, 1)

图:

graph():
    %ph_0 : [#users=1] = placeholder[target=ph_0]
    %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
    return [add_tensor]

FakeTensor:

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

torch.export 生成一个干净的中间表示 (IR),具有以下不变性。更多关于 IR 的规范可以在这里找到。

  • Soundness: 它保证是原始程序的一个准确表示,并保持原始程序的相同调用约定。
  • Normalized: 图中没有 Python 语义。原始程序中的子模块被内联形成一个完全扁平化的计算图。
  • Defined Operator Set: 产生的图只包含一小部分定义的 Core ATen IR 操作集和注册的自定义操作符。
  • Graph properties: 图是纯函数式的,意味着它不包含有副作用的操作,如变异或别名。它不会改变任何中间值、参数或缓冲区。
  • Metadata: 图包含在跟踪期间捕获的元数据,例如用户代码的堆栈跟踪。

在底层,torch.export 利用以下最新技术:

  • TorchDynamo (torch._dynamo) 是一个内部 API,使用一个名为 Frame Evaluation API 的 CPython 功能来安全地跟踪 PyTorch 图。这提供了一个大幅改进的图形捕获体验,以便完全跟踪 PyTorch 代码需要更少的重写。
  • AOT Autograd 提供了一个功能化的 PyTorch 图,并确保图被分解/降级到小的定义的 Core ATen 操作集。
  • Torch FX (torch.fx) 是图的底层表示,允许灵活的基于 Python 的转换。

torch.compile() 也使用与 torch.export 相同的 PT2 栈,但有一些不同之处:

  • JIT vs. AOT: torch.compile() 是一个 JIT 编译器,它不是用来生成部署外的编译工件的。
  • 部分 vs. 完整图捕获: 当 torch.compile() 遇到模型中无法追踪的部分时,它将“图断裂”并回退到在急切的 Python 运行时中运行程序。相比之下,torch.export 旨在获取 PyTorch 模型的完整图表示,因此当遇到无法追踪的内容时会出错。由于 torch.export 产生的完整图与任何 Python 特性或运行时无关,因此这个图可以保存、加载并在不同的环境和语言中运行。
  • 可用性权衡: 由于 torch.compile() 能够在遇到无法追踪的内容时回退到 Python 运行时,因此它更加灵活。相比之下,torch.export 将要求用户提供更多信息或重写代码以使其可追踪。

torch.fx.symbolic_trace() 相比,torch.export 使用 TorchDynamo 进行追踪,该技术在 Python 字节码级别上操作,使其能够追踪 Python 操作符重载不支持的任意 Python 构造。此外,torch.export 精细地跟踪张量元数据,以便在张量形状等条件上不会导致追踪失败。总的来说,torch.export 预计能在更多用户程序上工作,并产生更低级别的图(在 torch.ops.aten 操作符级别)。请注意,用户仍然可以在 torch.export 之前使用 torch.fx.symbolic_trace() 作为预处理步骤。

torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为覆盖 Python 字节码更容易)。产生的图更简单,只有直线控制流(除了显式控制流操作符)。

torch.jit.trace() 相比,torch.export 是准确的:它能够追踪在尺寸上进行整数计算的代码,并记录所有必要的副条件,以证明特定追踪对其他输入有效。

参考