torch.export IR Specification IR 说明

Export IR 是一种用于编译器的中间表示(IR),与 MLIR 和 TorchScript 有相似之处。它专门设计用来表达 PyTorch 程序的语义。Export IR 主要以一系列操作的streamlined list形式表示计算,对动态性如控制流的支持有限。

“streamlined list” 可以理解为“简化后的操作列表”或“优化后的操作列表”,强调的是通过精简操作和结构,使得表示更加直接和易于理解

Export IR 建立在 torch.fx.Graph 之上。换句话说,所有的 Export IR 图都是有效的 FX 图,如果使用标准的 FX 语义解释,Export IR 可以被准确地解释。这意味着,导出的图可以通过标准的 FX 代码生成转换为有效的 Python 程序。

要创建一个 Export IR 图,可以使用一个前端,该前端通过trace-specializing mechanism准确捕获 PyTorch 程序。然后可以通过后端优化和执行生成的 Export IR。目前,这可以通过 torch.export.export() 来完成。

本文档将涵盖的关键概念包括:

  • ExportedProgram:包含 Export IR 程序的数据结构
  • Graph:由一系列节点组成。
  • Nodes:代表操作、控制流和存储在此节点上的元数据。
  • Values 由节点产生和消费。
  • Types 与值和节点相关联。
  • 值的大小和内存布局也被定义。

本文档将主要关注 Export IR 与 FX 在严格性方面的区别,而忽略它们之间的相似之处。

ExportedProgram

顶层的 Export IR 构造是 torch.export.ExportedProgram 类。它将 PyTorch 模型的计算图(通常是 torch.nn.Module)与模型的参数或权重捆绑在一起。

torch.export.ExportedProgram 类的一些显著属性(notable attributes)包括:

  • graph_module (torch.fx.GraphModule):包含 PyTorch 模型 flattened computational graph 的数据结构。可以通过 ExportedProgram.graph 直接访问图。
  • graph_signature (torch.export.ExportGraphSignature):The graph signature, which specifies the parameters and buffer names used and mutated within the graph. Instead of storing parameters and buffers as attributes of the graph, they are lifted as inputs to the graph. The graph_signature is utilized to keep track of additional information on these parameters and buffers.
    图签名,指定图中使用和改变的参数和缓冲区名称。参数和缓冲区不作为图的属性存储,而是提升为图的输入。graph_signature 用于追踪这些参数和缓冲区的额外信息。
  • state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]):包含参数和buffers的数据结构。
  • range_constraints (Dict[sympy.Symbol, RangeConstraint]):对于导出具有数据依赖行为的程序,每个节点上的元数据将包含符号形状(如 s0, i0)。此属性将符号形状映射到其下限/上限范围。

在 PyTorch 中,“flattened computational graph” (平展计算图)通常指的是一个简化且去掉了层次结构的计算图表示。在标准的 torch.nn.Module 模型中,计算图通常会根据层次结构进行组织,例如模块内包含子模块,子模块内可能还包含其他模块。这种层次化的结构有助于组织和管理模型的不同部分,但在分析或优化模型的计算流程时,这种层次结构可能会带来复杂性。
“flattened” 或 “flattening” 过程涉及到将这种层次结构展开成一个单一的、平面的图,其中所有的操作和计算节点都在同一个层级上。

Graph

Export IR Graph 是以 DAG(有向无环图)形式表示的 PyTorch 程序。图中的每个节点代表特定的计算或操作,图的边由节点之间的引用组成。

我们可以将 Graph 视为具有以下架构:

class Graph:
  nodes: List[Node]

实际上,Export IR 的图以 torch.fx.Graph Python 类实现。

Export IR 图包含以下节点(节点将在下一节中更详细地描述):

  • 0 个或多个操作类型为 placeholder 的节点
  • 0 个或多个操作类型为 call_function 的节点
  • 正好 1 个操作类型为 output 的节点

推论: 最小的有效图将是一个节点。即 nodes 永不为空。

定义: 图的 placeholder 节点集合代表 Graph 或 GraphModule 的输入。图的输出节点代表 Graph 或 GraphModule 的输出

示例:

from torch import nn
class MyModule(nn.Module):
    def forward(self, x, y):
        return x + y

mod = torch._export.export(MyModule())
print(mod.graph)

graph():
    %l_x_ : [num_users=1] = placeholder[target=l_x_]
    %l_y_ : [num_users=1] = placeholder[target=l_y_]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%l_x_, %l_y_), kwargs = {})
    return (add,)

以上是图的文本表示,每一行都是一个节点。

节点 Node

节点代表特定的计算或操作,使用 Python 中的 torch.fx.Node 类来表示。节点之间的边以 Node 类的 args 属性中对其他节点的直接引用的形式表示。使用相同的 FX 机制,我们可以表示计算图通常需要的以下操作,如操作符调用、占位符(即输入)、条件和循环。

节点具有以下架构:

class Node:
    name: str # 节点名称
    op_name: str # 操作类型
    # 下面字段的解释取决于 op_name
    target: [str|Callable]
    args: List[object]
    kwargs: Dict[str, object]
    meta: Dict[str, object]

FX 文本格式

如上面的示例所示,每行的格式如下:

%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5})

这种格式以紧凑的形式捕获了 Node 类中的所有内容,除了 meta

具体来说:

  • <name> 是节点的名称,如 node.name 中所示。
  • <op_name>node.op 字段,必须是以下之一:<call_function>, <placeholder>, <get_attr>, 或 <output>
  • <target>node.target 中的目标。这个字段的含义取决于 op_name
  • args1, … args 4…node.args 元组中列出的内容。如果列表中的值是一个 torch.fx.Node,则会特别用 % 表示。

例如,对加法操作符的调用将显示为:

%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {})

其中 %x, %y 是两个名为 x 和 y 的其他节点。值得注意的是,字符串 torch.op.aten.add.Tensor 代表实际存储在目标字段中的可调用对象,而不仅仅是它的字符串名称。

这种文本格式的最后一行是:

return [add]

这是一个 op_name = output 的节点,表示我们返回这一个元素。

call_function

call_function 节点代表对操作符的调用。

定义

  • Functional(函数式): 我们说一个可调用对象是“函数式的”,如果它满足以下所有要求:
    • 非变异性:操作符不会改变其输入的值(对于张量,这包括元数据和数据)。
    • 无副作用:操作符不会改变外部可见的状态,比如改变模块参数的值。
  • Operator(操作符): 是一个具有预定义架构的函数式可调用对象。这样的操作符的例子包括 ATen 的函数式操作符。

在 FX 中的表示

%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

与原生 FX call_function 的区别

  1. 在 FX 图中,call_function 可以引用任何可调用对象,在 Export IR 中,我们将其限制为只选取 ATen 操作符、自定义操作符和控制流操作符的子集。
  2. 在 Export IR 中,常量参数将嵌入图中。
  3. 在 FX 图中,get_attr 节点可以代表读取图模块中存储的任何属性。然而,在 Export IR 中,这被限制为只读取子模块,因为所有参数/buffers将作为输入传递给图模块。

元数据

Node.meta 是附加在每个 FX 节点上的字典。然而,FX 规范并没有指定哪些元数据可以或将会存在。Export IR 提供了更强的约束,特别是所有 call_function 节点将保证仅具有以下元数据字段:

  • node.meta["stack_trace"] 是一个包含引用原始 Python 源代码的 Python 堆栈跟踪的字符串。一个堆栈跟踪的示例看起来像这样:
File "my_module.py", line 19, in forward
    return x + dummy_helper(y)
File "helper_utility.py", line 89, in dummy_helper
    return y + 1
  • node.meta["val"] 描述运行操作后的输出。它可以是 ‘<symint>’, ‘<FakeTensor>’, List[Union[FakeTensor, SymInt]]None 类型。
  • node.meta["nn_module_stack"] 描述节点来自哪个 torch.nn.Module,如果它是从 torch.nn.Module 调用的话。例如,如果一个包含 addmm 操作的节点是从 torch.nn.Linear 模块中的 torch.nn.Sequential 模块调用的,nn_module_stack 将如下所示:
{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 
 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
  • node.meta["source_fn_stack"] 包含节点调用前的 torch 函数或叶子 torch.nn.Module 类。例如,一个从 torch.nn.Linear 模块调用包含 addmm 操作的节点,其 source_fn 中将包含 torch.nn.Linear,而一个从 torch.nn.functional.Linear 模块调用包含 addmm 操作的节点,其 source_fn 中将包含 torch.nn.functional.Linear

placeholder

占位符代表图的输入。它的语义与 FX 中完全相同。占位符节点必须是图节点列表中的前 N 个节点。N 可以为零。

在 FX 中的表示

%name = placeholder[target = name](args = ())

目标字段是一个字符串,是输入的名称。

args,如果非空,应该是大小为 1,表示此输入的默认值。

元数据

占位符节点也有 meta['val'],就像 call_function 节点一样。在这种情况下,val 字段表示图预期接收到的此输入参数的输入形状/数据类型。

output

输出调用代表函数中的返回语句;因此它终止当前图。只有一个输出节点,并且它将始终是图的最后一个节点。

在 FX 中的表示

output[](args = (%something, …))

这与 torch.fx 中的语义完全相同。args 表示要返回的节点。

元数据

输出节点具有与 call_function 节点相同的元数据。

get_attr

get_attr 节点代表从封装的 torch.fx.GraphModule 中读取子模块。与 torch.fx.symbolic_trace() 中的普通 FX 图不同,get_attr 节点用于从顶级 torch.fx.GraphModule 中读取属性,如参数和缓冲区,参数和缓冲区作为输入传递给图模块,并存储在顶级 torch.export.ExportedProgram 中。

在 FX 中的表示

%name = get_attr[target = name](args = ())

示例

考虑以下模型:

from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])

图:

graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional

这行代码 %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] 读取子模块 true_graph_0,其中包含 sin 操作符。

参考

SymInt

SymInt 是一个既可以是字面整数也可以是表示整数的符号(在 Python 中由 sympy.Symbol 类表示)的对象。当 SymInt 是一个符号时,它描述了在编译时对图未知的整数类型的变量,即其值只在运行时才知道。

FakeTensor

FakeTensor 是一个包含张量元数据的对象。可以视为具有以下元数据的对象。

class FakeTensor:
    size: List[SymInt]
    dtype: torch.dtype
    device: torch.device
    dim_order: List[int] # 这个还不存在

FakeTensor 的 size 字段是一个整数或 SymInt 的列表。如果存在 SymInt,这意味着这个张量有动态形状。如果存在整数,则假定该张量将具有确切的静态形状。TensorMeta 的秩永远不是动态的。dtype 字段代表该节点输出的数据类型。Edge IR 中没有隐式类型提升。FakeTensor 中没有跨度。

换句话说:

  • 如果 node.target 中的操作符返回一个张量,则 node.meta['val'] 是描述该张量的 FakeTensor。
  • 如果 node.target 中的操作符返回一个张量的 n 元组,则 node.meta['val'] 是描述每个张量的 n 元组 FakeTensors。
  • 如果 node.target 中的操作符返回在编译时已知的 int/float/标量,则 node.meta['val'] 为 None。
  • 如果 node.target 中的操作符返回在编译时未知的 int/float/标量,则 node.meta['val'] 是 SymInt 类型。

例如:

  • aten::add 返回一个张量;因此其规范将是一个 FakeTensor,具有该操作符返回的张量的数据类型和大小。
  • aten::sym_size 返回一个整数;因此其 val 将是一个 SymInt,因为其值只在运行时才可用。
  • max_pool2d_with_indexes 返回一个 (Tensor, Tensor) 元组;因此规范也将是两个 FakeTensor 对象的 2 元组,第一个 TensorMeta 描述返回值的第一个元素等。

Python 代码:

def add_one(x):
    return torch.ops.aten(x, 1)

图:

graph():
    %ph_0 : [#users=1] = placeholder[target=ph_0]
    %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
    return [add_tensor]

FakeTensor:

FakeTensor(dtype=torch.int, size=[2,], device=CPU)

from https://pytorch.org/docs/stable/export.ir_spec.html