TorchScript: Tracing vs. Scripting

本文翻译自 TorchScript: Tracing vs. Scripting - Yuxin's Blog

PyTorch提供了两种方法将nn.Module转化为TorchScript格式的图:tracing和scripting。

本文将:

  • 比较它们的优缺点,重点是tracing的实用技巧。
  • 我试图说服你,在部署复杂模型时,应优先选择torch.jit.trace而不是torch.jit.script。

第二点可能是一个不常见的观点:如果我在Google上搜索“tracing vs scripting”,第一篇文章推荐默认使用scripting。但是tracing有很多优点。事实上,在我离开的时候,Facebook/Meta产品中的所有检测和分割模型都是按照“默认使用tracing,仅在必要时使用scripting”的策略部署的。

为什么tracing更好?简而言之:(i) 它不会损害代码质量;(ii) 其主要局限性可以通过混合使用scripting来解决。

术语 Terminology

首先,我们来解释一些常见术语:

  • Export: 指将以eager-mode Python代码编写的模型转化为描述计算的图的过程。
  • Tracing: 一种导出方法。它使用特定的输入运行模型,并将执行的所有操作“追踪/记录”到一个图中。
    torch.jit.trace是一个使用tracing的导出API,使用方法如torch.jit.trace(model, input)。参见其教程和API。
  • Scripting: 另一种导出方法。它解析模型的Python源代码,并将代码编译成一个图。
    torch.jit.script是一个使用scripting的导出API,使用方法如torch.jit.script(model)。参见其教程和API。
  • TorchScript: 这是一个被广泛使用的术语
    • 它通常指导出图的表示/格式。
    • 但有时它指代scripting导出方法。
      为了避免混淆,本文中我不会单独使用“TorchScript”这个词。我将使用“TS-format”来指代格式,使用“scripting”来指代导出方法。
      由于这个术语使用起来有歧义,可能会给人一种“scripting”是创建TS-format模型的“官方/首选”方式的印象。但事实并非如此。
  • (Torch)Scriptable: 如果torch.jit.script(model)成功,则模型是“scriptable”的,即它可以通过scripting导出。
  • Traceable: 如果torch.jit.trace(model, input)在典型输入下成功,则模型是“traceable”的。
  • Generalize: 如果追踪模型(trace()返回的对象)在给出其他输入时能够正确推理,则称其“泛化”。Scripted模型总是可以泛化。
  • Dynamic control flow或data-dependent control flow: 运算符的执行依赖于输入数据的控制流,例如,对于张量x:
    • if x[0] == 4: x += 1 是一个动态控制流。
    • # 不是动态控制流。
      model: nn.Sequential = ...
      for m in model:
      x = m(x)
      
    • # 也不是动态控制流。
      class A(nn.Module):
      backbone: nn.Module
      head: Optional[nn.Module]
         def forward(self, x):
             x = self.backbone(x)
             if self.head is not None:
                  x = self.head(x)
          return x
      

Scriptability的代价

如果有人说“我们会通过编写一个编译器来改进Python”,你应该立即警觉,因为这是非常困难的。Python太庞大且动态。编译器最多只能支持其语法特性和内置函数的一个子集——PyTorch中的scripting编译器也不例外。

这个编译器支持Python的哪个子集?一个粗略的答案是:编译器对最基本的语法有很好的支持,但对更复杂的语法(类、内置函数如range和zip、动态类型等)支持中等或没有支持。但没有明确的答案:即使是编译器的开发者也通常需要运行代码来查看它是否可以编译。

不完整的Python编译器限制了用户编写代码的方式。虽然没有明确的约束列表,但从我的经验来看,这些约束对大型项目的影响是显而易见的:代码质量是scriptability的代价。

对大多数项目的影响

为了使代码可以通过scripting编译器进行scriptable/编译,大多数项目选择保持在“安全的一边”,仅使用Python的基本语法:没有/很少自定义结构、没有内置函数、没有继承、没有Union、没有**kwargs、没有lambda、没有动态类型等。

这是因为这些“高级”编译器特性要么完全不支持,要么支持不完整:它们在某些情况下可能有效,但在其他情况下会失败。由于没有明确的支持规范,用户无法推理或解决失败的问题。因此,最终用户转向并停留在安全的一边。

可怕的后果是:开发者由于scriptability的顾虑,停止创建抽象/探索有用的语言特性。

许多项目采用的相关技巧是为scripting重写部分代码:创建一个独立的,仅用于推理的forward代码路径,以使编译器满意。这也使得项目更难维护。

对Detectron2的影响

Detectron2支持scripting,但故事有点不同:它在我们非常重视的研究中没有降低代码质量。相反,借助一些创造力和PyTorch团队的直接支持(以及阿里巴巴工程师的一些志愿帮助),我们设法在不删除任何抽象的情况下,使大多数模型可scriptable。

然而,这并不是一件容易的事:我们不得不向编译器添加数十个语法修复,找到创造性的解决方法,并在detectron2中开发了一些hacky补丁(这些补丁老实说可能会长期影响可维护性)。我不建议其他大型项目追求“在不失去抽象的情况下实现scriptability”,除非他们也得到了PyTorch团队的密切支持。

建议

如果你认为“scripting似乎适用于我的项目”,那么让我们接受它,基于我过去支持scripting的一些项目的经验,我可能会反对这种做法,原因如下:

  • “可行”的东西可能比你想象的更脆弱(除非你局限于基本语法):你的代码现在可能编译通过,但有一天你会对模型进行一些无害的更改,发现编译器拒绝了它。

  • 基本语法是不够的:即使目前你的项目不需要更复杂的抽象,但如果项目预计会增长,未来将需要更多的语言特性。

    以多任务检测器为例:

    • 可能会有十几个输入,因此最好使用一些结构/类。
    • 相同的数据可能有不同的表示形式(例如,不同的分割掩码表示方式),这需要Union或更多动态类型。
    • 检测器有很多架构选择,这使得继承很有用。

    大型、不断增长的项目需要不断发展的抽象,以保持健康。

  • 代码质量可能会严重恶化:由于编译器语法限制,清晰代码有时无法编译,丑陋的代码开始积累。此外,由于编译器的语法限制,无法轻松创建抽象来清理丑陋代码。项目的健康状况逐渐恶化。

以下是在PyTorch问题中的抱怨。问题本身只是scripting的一小部分,但类似的抱怨已经听到很多次了。现状是:scripting迫使你编写丑陋代码,所以只在必要时使用它。

使模型可追踪和泛化

可追踪性的代价

使模型可追踪的要求非常明确,对代码健康的影响也很小。

  • 首先,如果模型不是一个适当的单设备、可在TS-format中表示的连接图,则无论是scripting还是tracing都无法工作。例如,如果模型有DataParallel子模块,或者模型将张量转换为numpy数组并调用OpenCV函数等,你将不得不重构它。

    除了这个明显的约束外,tracing只有两个额外要求。

  • 输入/输出格式

    模型的输入/输出必须是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]或它们的嵌套组合。请注意,字典中的值必须属于同一类型。

    类似的约束也存在于scripting中。然而,在tracing中,约束不适用于子模块:子模块可以使用任何输入/输出格式:Any的字典、类、kwargs、Python支持的任何格式。只有顶级模型需要使用约束格式。

    这使得约束非常容易满足。如果模型使用更丰富的格式,只需围绕它创建一个简单的包装器,将其转换为/从Tuple[Tensor]中转换。Detectron2通过一个通用包装器自动完成所有模型的此操作:

    outputs = model(inputs)   # inputs/outputs 是丰富结构,例如字典或类
    # torch.jit.trace(model, inputs)  # 失败!不支持的格式
    adapter = TracingAdapter(model, inputs)
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在可以追踪模型
    
    # 追踪模型只能生成扁平化输出(张量元组):
    flattened_outputs = traced(*adapter.flattened_inputs)
    # 适配器知道如何将其转换回丰富结构(new_outputs == outputs):
    new_outputs = adapter.outputs_schema(flattened_outputs)
    

    Automatically Flatten & Unflatten Nested Containers has more details on how this adapter is implemented.

  • Symbolic shapes::

    表达式如tensor.size(0)、tensor.size()[1]、tensor.shape[2]在eager模式下是整数,但在tracing模式下是张量。这种差异是必要的,以便在tracing过程中,形状计算可以作为图中的符号操作被捕获。在下一节关于泛化的示例中给出了一个例子。

    由于返回类型不同,如果模型的部分假定形状是整数,则可能无法追踪。这通常可以通过在代码中处理两种类型来轻松修复。一个有用的函数是torch.jit.is_tracing,用于检查代码是否在tracing模式下执行。

这就是追踪的所有要求——最重要的是,模型实现中允许使用任何Python语法,因为tracing根本不关心语法。

泛化问题

仅“可追踪”是不够的。tracing的最大问题是它可能无法泛化到其他输入。在以下情况下会出现此问题

  • 动态控制流:

    >>> def f(x):
    ...   return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    >>> m = torch.jit.trace(f, torch.tensor(3))
    >>> print(m.code)
    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)
    

    在此示例中,由于动态控制流,追踪仅保留了条件的一个分支,并且无法泛化到某些(负)输入。

  • 将变量捕获为常量:

    >>> a, b = torch.rand(1), torch.rand(2)
    >>> def f1(x): return torch.arange(x.shape[0])
    >>> def f2(x): return torch.arange(len(x))
    >>> # 看看两个追踪是否可以从a泛化到b:
    >>> torch.jit.trace(f1, a)(b)
    tensor([0, 1])
    >>> torch.jit.trace(f2, a)(b)
    tensor([0])  # 错误!
    >>> # 为什么f2不能泛化?让我们比较它们的代码:
    >>> print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
    def f1(x: Tensor) -> Tensor:
      _0 = ops.prim.NumToTensor(torch.size(x, 0))
      _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
      return _1
    def f2(x: Tensor) -> Tensor:
      _0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
      return _0
    

    非张量类型(在此情况下为int类型)的中间计算结果可能会被捕获为常量,使用追踪期间观察到的值。这导致追踪无法泛化。

    除了len(),此问题还可能出现在:

    • .item(),将张量转换为int/float。
    • 任何将torch类型转换为numpy/Python原语的代码。
    • 一些有问题的运算符,例如高级索引。
  • 捕获设备:

    >>> def f(x):
    ...   return torch.arange(x.shape[0], device=x.device)
    >>> m = torch.jit.trace(f, torch.tensor([3]))
    >>> print(m.code)
    def f(x: Tensor) -> Tensor:
      _0 = ops.prim.NumToTensor(torch.size(x, 0))
      _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
      return _1
    >>> m(torch.tensor([3]).cuda()).device
    device(type='cpu')  # 错误!
    

    类似地,接受设备参数的运算符将记住追踪期间使用的设备(这可以在m.code中看到)。因此,追踪可能无法泛化到不同设备上的输入。几乎不需要这种泛化,因为部署通常有目标设备。

让追踪泛化 Let Tracing Generalize

上述问题令人烦恼且通常是静默的(有警告,但没有错误),但它们可以通过良好的实践和工具成功解决:

  • 注意TracerWarning:在上述两个示例中,torch.jit.trace实际上会发出警告。第一个示例打印:

    a.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
    We can't record the data flow of Python values, so this value will be treated as a constant in the future.
    This means that the trace might not generalize to other inputs!
    if x.sum() > 0:
    

    注意这些警告(或更好的是,捕获它们)将暴露tracing的大多数泛化问题。

    请注意,“捕获设备”案例不会打印警告,因为tracing根本不支持这种泛化。

  • 并行性单元测试( Unittests for parity):导出后和部署前应进行单元测试,以验证导出模型生成的输出与原始eager-mode模型的输出相同,即:

    assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
    

    如果需要跨形状的泛化(并不总是需要),input2应具有不同于input1的形状。

    Detectron2有许多泛化测试,例如此测试和此测试。一旦发现差距,检查导出TS-format模型的代码可以揭示它无法泛化的地方。

  • 避免不必要的“特殊情况”条件:避免如下条件:

    if x.numel() > 0:
      output = self.layers(x)
    else:
      output = torch.zeros((0, C, H, W))  # 创建空输出
    

    这样处理特殊情况,如空输入。相反,改进self.layers或其底层内核,使其支持空输入。这将导致更清晰的代码,并且还会改进tracing。这就是为什么我参与了许多改进空输入支持的PyTorch问题,例如#12013、#36530、#56998。大多数PyTorch操作都能很好地处理空输入,因此很少需要这种分支。

  • 使用符号形状:如前所述,tensor.size()在tracing期间返回Tensor,以便在图中捕获形状计算。用户应避免意外地将张量形状变为常量:

    • 使用tensor.size(0)而不是len(tensor),因为后者是int。对于自定义类,实施.size方法或使用.len()而不是len(),例如像这里。
    • 不要通过int()或torch.as_tensor转换大小,因为它们会捕获常量。此辅助函数有助于以在tracing和eager模式下均适用的方式将大小转换为张量。
  • 混合tracing和scripting:它们可以混合使用,因此可以对tracing无法正常工作的少部分代码使用scripting。这样可以解决tracing的几乎所有问题。更多内容见下文。

混合Tracing和Scripting

Tracing和scripting都有各自的问题,最好的解决方案通常是将它们混合使用。这让我们可以获得两者的优点。

为了最大限度地减少对代码质量的负面影响,我们应在大多数逻辑上使用tracing,仅在必要时使用scripting。

  • 使用@script_if_tracing:在torch.jit.trace内部,@script_if_tracing装饰器可以通过scripting编译函数。通常,这只需要对forward逻辑进行一些小的重构,以分离需要编译的部分(使用控制流的部分):

    def forward(self, ...):
      # ... 一些forward逻辑
      @torch.jit.script_if_tracing
      def _inner_impl(x, y, z, flag: bool):
          # 使用控制流等
          return ...
      output = _inner_impl(x, y, z, flag)
      # ... 其他forward逻辑
    

    通过仅编译需要的部分,代码质量损害比使整个模型scriptable要小得多,并且完全不影响模块的forward接口。

    由@script_if_tracing装饰的函数必须是纯函数,不包含模块。因此,有时需要进行更多重构:

    事实上,对于大多数视觉模型,动态控制流仅在少数子模块中需要,因此很容易scriptable。为了显示其需求的稀缺性,整个detectron2仅有两个函数由于控制流被@script_if_tracing装饰:paste_masks和heatmaps_to_keypoints,均用于后处理。其他一些函数也被装饰以跨设备泛化(非常罕见的需求)。

  • 使用scripted/追踪子模块:

    model.submodule = torch.jit.script(model.submodule)
    torch.jit.trace(model, inputs)
    

    在此示例中,假设子模块无法正确追踪,我们可以在tracing之前对其进行scripting。然而,我不推荐这样做。如果可能,我建议在子模块的forward内部使用@script_if_tracing,使scripting仅限于子模块的内部,而不影响模块的接口。

    同样,

    model.submodule = torch.jit.trace(model.submodule, submodule_inputs)
    torch.jit.script(model)
    

    这在scripting期间使用追踪子模块。这看起来不错,但在实践中用处不大:它会影响子模块的接口,要求它只能接受/返回Tuple[Tensor]——这是一个大约束,可能比scripting更损害代码质量。

    一个罕见的“追踪子模块”有用场景是:

    class A(nn.Module):
      def forward(self, x):
        # 根据动态、数据依赖条件分派到不同子模块:
        return self.submodule1(x) if x.sum() > 0 else self.submodule2(x)
    

    @script_if_tracing无法编译这样的控制流,因为它只支持纯函数。如果子模块{1,2}复杂且无法scripting,使用追踪子模块在scripting父类A是最佳选择。

  • 合并多个追踪:

Scripted模型支持两种追踪模型不支持的功能:

  • 基于属性的控制流:一个scripted模块可以有可变属性(例如,一个布尔标志)影响控制流。追踪模块没有控制流。

  • 多个方法:一个追踪模块仅支持forward(),但一个scripted模块可以有多个方法。 实际上,上述两种功能都是做同一件事:它们允许导出模型以不同方式使用,即根据调用者的要求执行不同的操作序列。

    以下是一个示例场景,如果Detector被scripted,调用者可以修改其do_keypoint属性来控制其行为,或在需要时直接调用predict_keypoint方法。

    class Detector(nn.Module):
      do_keypoint: bool
    
      def forward(self, img):
          box = self.predict_boxes(img)
          if self.do_keypoint:
              kpts = self.predict_keypoint(img, box)
    
      @torch.jit.export
      def predict_boxes(self, img): pass
    
      @torch.jit.export
      def predict_keypoint(self, img, box): pass
    

    这种需求并不常见。但如果需要,如何在tracing中实现?我有一个不太干净的解决方案:

    tracing只能捕获一个操作序列,所以自然的方法是追踪模型两次:

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后我们可以别名它们的权重(不重复存储),并将两个追踪合并为一个模块进行scripting。

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
        if do_keypoint:
            return self[0](img)
        else:
            return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

性能

如果一个模型既可追踪又可scriptable,那么tracing总是生成相同或更简单的图(因此可能更快)。

为什么?因为scripting试图忠实地表示你的Python代码,即使其中一些是多余的。例如:它并不总是足够智能,无法意识到Python代码中的某些循环或数据结构实际上是静态的,可以移除:

class A(nn.Module):
  def forward(self, x1, x2, x3):
    z = [0, 1, 2]
    xs = [x1, x2, x3]
    for k in z: x1 += xs[k]
    return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

这个例子非常简单,所以它实际上有scripting的解决方法(使用元组而不是列表),或者循环可能会在后续优化过程中得到优化。但重点是:图编译器并不总是足够智能。对于复杂模型,scripting可能会生成带有不必要复杂性的图,难以优化。

总结

tracing有明显的局限性:本文大部分内容都是在讨论tracing的局限性以及如何解决它们。我实际上认为这是tracing的优势:它有明显的局限性(和解决方案),所以你可以推理它是否可行。

相反,scripting更像是一个黑匣子:在尝试之前没有人知道它是否可行。我没有提到任何修复scripting的技巧:它们有很多,但不值得你花时间去探究和修复一个黑匣子。

tracing的影响范围小:tracing和scripting都影响代码的编写方式,但tracing的影响范围更小,造成的损害也更少:

  • 它限制了输入/输出格式,但仅限于最外层的模块。(而且这个问题可以自动解决,如上所述。)
  • 它需要一些代码更改以泛化(例如,在tracing中混合scripting),但这些更改仅进入受影响模块的内部实现,而不影响它们的接口。

另一方面,scripting影响:

  • 涉及的每个模块和子模块的接口。
    • 在我看来,这是最大的损害:接口中需要高级语法特性,我不愿意在接口设计上妥协。
    • 这也可能影响训练,因为接口通常在训练和推理之间共享。 推理forward路径中的每一行代码。

影响范围大的原因是scripting可能会严重损害代码质量。

控制流与其他Python语法:PyTorch受到用户喜爱,因为他们可以“只写Python”,最重要的是写Python控制流。但Python的其他语法也很重要。如果能够写Python控制流(scripting)意味着失去其他优秀语法,我宁愿放弃写Python控制流的能力。

事实上,如果PyTorch对Python控制流不那么执着,并为我提供类似于tf.cond的符号控制流(例如torch.cond),我会很高兴使用这个,不再担心scripting:

def f(x):
  return torch.cond(x.sum() > 0, lambda: torch.sqrt(x), lambda: torch.square(x))

那么f可以被正确追踪,我会很高兴使用这个,不再担心scripting。TensorFlow AutoGraph是一个自动化这一想法的好例子。

但显然不符合torch易用的设计原则

参考