torch.export 机制

torch.export 生成一个干净的中间表示 (IR),具有以下不变性。更多关于 IR 的规范可以在这里找到。

  • Soundness: 它保证是原始程序的一个准确表示,并保持原始程序的相同调用约定。
  • Normalized: 图中没有 Python 语义。原始程序中的子模块被内联形成一个完全扁平化的计算图。
  • Defined Operator Set: 产生的图只包含一小部分定义的 Core ATen IR 操作集和注册的自定义操作符。
  • Graph properties: 图是纯函数式的,意味着它不包含有副作用的操作,如变异或别名。它不会改变任何中间值、参数或缓冲区。
  • Metadata: 图包含在跟踪期间捕获的元数据,例如用户代码的堆栈跟踪。

在底层,torch.export 利用以下最新技术:

  • TorchDynamo (torch._dynamo) 是一个内部 API,使用一个名为 Frame Evaluation API 的 CPython 功能来安全地跟踪 PyTorch 图。这提供了一个大幅改进的图形捕获体验,以便完全跟踪 PyTorch 代码需要更少的重写。
  • AOT Autograd 提供了一个功能化的 PyTorch 图,并确保图被分解/降级到小的定义的 Core ATen 操作集。
  • Torch FX (torch.fx) 是图的底层表示,允许灵活的基于 Python 的转换。

torch.compile() 也使用与 torch.export 相同的 PT2 栈,但有一些不同之处:

  • JIT vs. AOT: torch.compile() 是一个 JIT 编译器,它不是用来生成部署外的编译工件的。
  • 部分 vs. 完整图捕获: 当 torch.compile() 遇到模型中无法追踪的部分时,它将“图断裂”并回退到在急切的 Python 运行时中运行程序。相比之下,torch.export 旨在获取 PyTorch 模型的完整图表示,因此当遇到无法追踪的内容时会出错。由于 torch.export 产生的完整图与任何 Python 特性或运行时无关,因此这个图可以保存、加载并在不同的环境和语言中运行。
  • 可用性权衡: 由于 torch.compile() 能够在遇到无法追踪的内容时回退到 Python 运行时,因此它更加灵活。相比之下,torch.export 将要求用户提供更多信息或重写代码以使其可追踪。

torch.fx.symbolic_trace() 相比,torch.export 使用 TorchDynamo 进行追踪,该技术在 Python 字节码级别上操作,使其能够追踪 Python 操作符重载不支持的任意 Python 构造。此外,torch.export 精细地跟踪张量元数据,以便在张量形状等条件上不会导致追踪失败。总的来说,torch.export 预计能在更多用户程序上工作,并产生更低级别的图(在 torch.ops.aten 操作符级别)。请注意,用户仍然可以在 torch.export 之前使用 torch.fx.symbolic_trace() 作为预处理步骤。

torch.jit.script() 相比,torch.export 不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为覆盖 Python 字节码更容易)。产生的图更简单,只有直线控制流(除了显式控制流操作符)。

torch.jit.trace() 相比,torch.export 是准确的:它能够追踪在尺寸上进行整数计算的代码,并记录所有必要的副条件,以证明特定追踪对其他输入有效。

参考