torch.export
生成一个干净的中间表示 (IR),具有以下不变性。更多关于 IR 的规范可以在这里找到。
- Soundness: 它保证是原始程序的一个准确表示,并保持原始程序的相同调用约定。
- Normalized: 图中没有 Python 语义。原始程序中的子模块被内联形成一个完全扁平化的计算图。
- Defined Operator Set: 产生的图只包含一小部分定义的 Core ATen IR 操作集和注册的自定义操作符。
- Graph properties: 图是纯函数式的,意味着它不包含有副作用的操作,如变异或别名。它不会改变任何中间值、参数或缓冲区。
- Metadata: 图包含在跟踪期间捕获的元数据,例如用户代码的堆栈跟踪。
在底层,torch.export
利用以下最新技术:
- TorchDynamo (torch._dynamo) 是一个内部 API,使用一个名为 Frame Evaluation API 的 CPython 功能来安全地跟踪 PyTorch 图。这提供了一个大幅改进的图形捕获体验,以便完全跟踪 PyTorch 代码需要更少的重写。
- AOT Autograd 提供了一个功能化的 PyTorch 图,并确保图被分解/降级到小的定义的 Core ATen 操作集。
- Torch FX (torch.fx) 是图的底层表示,允许灵活的基于 Python 的转换。
torch.compile()
也使用与 torch.export
相同的 PT2 栈,但有一些不同之处:
- JIT vs. AOT:
torch.compile()
是一个 JIT 编译器,它不是用来生成部署外的编译工件的。 - 部分 vs. 完整图捕获: 当
torch.compile()
遇到模型中无法追踪的部分时,它将“图断裂”并回退到在急切的 Python 运行时中运行程序。相比之下,torch.export
旨在获取 PyTorch 模型的完整图表示,因此当遇到无法追踪的内容时会出错。由于torch.export
产生的完整图与任何 Python 特性或运行时无关,因此这个图可以保存、加载并在不同的环境和语言中运行。 - 可用性权衡: 由于
torch.compile()
能够在遇到无法追踪的内容时回退到 Python 运行时,因此它更加灵活。相比之下,torch.export
将要求用户提供更多信息或重写代码以使其可追踪。
与 torch.fx.symbolic_trace()
相比,torch.export
使用 TorchDynamo 进行追踪,该技术在 Python 字节码级别上操作,使其能够追踪 Python 操作符重载不支持的任意 Python 构造。此外,torch.export
精细地跟踪张量元数据,以便在张量形状等条件上不会导致追踪失败。总的来说,torch.export
预计能在更多用户程序上工作,并产生更低级别的图(在 torch.ops.aten
操作符级别)。请注意,用户仍然可以在 torch.export
之前使用 torch.fx.symbolic_trace()
作为预处理步骤。
与 torch.jit.script()
相比,torch.export
不捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为覆盖 Python 字节码更容易)。产生的图更简单,只有直线控制流(除了显式控制流操作符)。
与 torch.jit.trace()
相比,torch.export
是准确的:它能够追踪在尺寸上进行整数计算的代码,并记录所有必要的副条件,以证明特定追踪对其他输入有效。