PyTorch提供了两种方法将nn.Module转化为TorchScript格式的图:tracing和scripting。
本文将:
- 比较它们的优缺点,重点是tracing的实用技巧。
- 我试图说服你,在部署复杂模型时,应优先选择torch.jit.trace而不是torch.jit.script。
第二点可能是一个不常见的观点:如果我在Google上搜索“tracing vs scripting”,第一篇文章推荐默认使用scripting。但是tracing有很多优点。事实上,在我离开的时候,Facebook/Meta产品中的所有检测和分割模型都是按照“默认使用tracing,仅在必要时使用scripting”的策略部署的。
为什么tracing更好?简而言之:(i) 它不会损害代码质量;(ii) 其主要局限性可以通过混合使用scripting来解决。
术语 Terminology
首先,我们来解释一些常见术语:
- Export: 指将以eager-mode Python代码编写的模型转化为描述计算的图的过程。
- Tracing: 一种导出方法。它使用特定的输入运行模型,并将执行的所有操作“追踪/记录”到一个图中。
torch.jit.trace是一个使用tracing的导出API,使用方法如torch.jit.trace(model, input)。参见其教程和API。 - Scripting: 另一种导出方法。它解析模型的Python源代码,并将代码编译成一个图。
torch.jit.script是一个使用scripting的导出API,使用方法如torch.jit.script(model)。参见其教程和API。 - TorchScript: 这是一个被广泛使用的术语
- 它通常指导出图的表示/格式。
- 但有时它指代scripting导出方法。
为了避免混淆,本文中我不会单独使用“TorchScript”这个词。我将使用“TS-format”来指代格式,使用“scripting”来指代导出方法。
由于这个术语使用起来有歧义,可能会给人一种“scripting”是创建TS-format模型的“官方/首选”方式的印象。但事实并非如此。
- (Torch)Scriptable: 如果torch.jit.script(model)成功,则模型是“scriptable”的,即它可以通过scripting导出。
- Traceable: 如果torch.jit.trace(model, input)在典型输入下成功,则模型是“traceable”的。
- Generalize: 如果追踪模型(trace()返回的对象)在给出其他输入时能够正确推理,则称其“泛化”。Scripted模型总是可以泛化。
- Dynamic control flow或data-dependent control flow: 运算符的执行依赖于输入数据的控制流,例如,对于张量x:
- if x[0] == 4: x += 1 是一个动态控制流。
-
# 不是动态控制流。 model: nn.Sequential = ... for m in model: x = m(x)
-
# 也不是动态控制流。 class A(nn.Module): backbone: nn.Module head: Optional[nn.Module] def forward(self, x): x = self.backbone(x) if self.head is not None: x = self.head(x) return x
Scriptability的代价
如果有人说“我们会通过编写一个编译器来改进Python”,你应该立即警觉,因为这是非常困难的。Python太庞大且动态。编译器最多只能支持其语法特性和内置函数的一个子集——PyTorch中的scripting编译器也不例外。
这个编译器支持Python的哪个子集?一个粗略的答案是:编译器对最基本的语法有很好的支持,但对更复杂的语法(类、内置函数如range和zip、动态类型等)支持中等或没有支持。但没有明确的答案:即使是编译器的开发者也通常需要运行代码来查看它是否可以编译。
不完整的Python编译器限制了用户编写代码的方式。虽然没有明确的约束列表,但从我的经验来看,这些约束对大型项目的影响是显而易见的:代码质量是scriptability的代价。
对大多数项目的影响
为了使代码可以通过scripting编译器进行scriptable/编译,大多数项目选择保持在“安全的一边”,仅使用Python的基本语法:没有/很少自定义结构、没有内置函数、没有继承、没有Union、没有**kwargs、没有lambda、没有动态类型等。
这是因为这些“高级”编译器特性要么完全不支持,要么支持不完整:它们在某些情况下可能有效,但在其他情况下会失败。由于没有明确的支持规范,用户无法推理或解决失败的问题。因此,最终用户转向并停留在安全的一边。
可怕的后果是:开发者由于scriptability的顾虑,停止创建抽象/探索有用的语言特性。
许多项目采用的相关技巧是为scripting重写部分代码:创建一个独立的,仅用于推理的forward代码路径,以使编译器满意。这也使得项目更难维护。
对Detectron2的影响
Detectron2支持scripting,但故事有点不同:它在我们非常重视的研究中没有降低代码质量。相反,借助一些创造力和PyTorch团队的直接支持(以及阿里巴巴工程师的一些志愿帮助),我们设法在不删除任何抽象的情况下,使大多数模型可scriptable。
然而,这并不是一件容易的事:我们不得不向编译器添加数十个语法修复,找到创造性的解决方法,并在detectron2中开发了一些hacky补丁(这些补丁老实说可能会长期影响可维护性)。我不建议其他大型项目追求“在不失去抽象的情况下实现scriptability”,除非他们也得到了PyTorch团队的密切支持。
建议
如果你认为“scripting似乎适用于我的项目”,那么让我们接受它,基于我过去支持scripting的一些项目的经验,我可能会反对这种做法,原因如下:
-
“可行”的东西可能比你想象的更脆弱(除非你局限于基本语法):你的代码现在可能编译通过,但有一天你会对模型进行一些无害的更改,发现编译器拒绝了它。
-
基本语法是不够的:即使目前你的项目不需要更复杂的抽象,但如果项目预计会增长,未来将需要更多的语言特性。
以多任务检测器为例:
- 可能会有十几个输入,因此最好使用一些结构/类。
- 相同的数据可能有不同的表示形式(例如,不同的分割掩码表示方式),这需要Union或更多动态类型。
- 检测器有很多架构选择,这使得继承很有用。
大型、不断增长的项目需要不断发展的抽象,以保持健康。
-
代码质量可能会严重恶化:由于编译器语法限制,清晰代码有时无法编译,丑陋的代码开始积累。此外,由于编译器的语法限制,无法轻松创建抽象来清理丑陋代码。项目的健康状况逐渐恶化。
以下是在PyTorch问题中的抱怨。问题本身只是scripting的一小部分,但类似的抱怨已经听到很多次了。现状是:scripting迫使你编写丑陋代码,所以只在必要时使用它。
使模型可追踪和泛化
可追踪性的代价
使模型可追踪的要求非常明确,对代码健康的影响也很小。
-
首先,如果模型不是一个适当的单设备、可在TS-format中表示的连接图,则无论是scripting还是tracing都无法工作。例如,如果模型有DataParallel子模块,或者模型将张量转换为numpy数组并调用OpenCV函数等,你将不得不重构它。
除了这个明显的约束外,tracing只有两个额外要求。
-
输入/输出格式
模型的输入/输出必须是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]或它们的嵌套组合。请注意,字典中的值必须属于同一类型。
类似的约束也存在于scripting中。然而,在tracing中,约束不适用于子模块:子模块可以使用任何输入/输出格式:Any的字典、类、kwargs、Python支持的任何格式。只有顶级模型需要使用约束格式。
这使得约束非常容易满足。如果模型使用更丰富的格式,只需围绕它创建一个简单的包装器,将其转换为/从Tuple[Tensor]中转换。Detectron2通过一个通用包装器自动完成所有模型的此操作:
outputs = model(inputs) # inputs/outputs 是丰富结构,例如字典或类 # torch.jit.trace(model, inputs) # 失败!不支持的格式 adapter = TracingAdapter(model, inputs) traced = torch.jit.trace(adapter, adapter.flattened_inputs) # 现在可以追踪模型 # 追踪模型只能生成扁平化输出(张量元组): flattened_outputs = traced(*adapter.flattened_inputs) # 适配器知道如何将其转换回丰富结构(new_outputs == outputs): new_outputs = adapter.outputs_schema(flattened_outputs)
Automatically Flatten & Unflatten Nested Containers has more details on how this adapter is implemented.
-
Symbolic shapes::
表达式如tensor.size(0)、tensor.size()[1]、tensor.shape[2]在eager模式下是整数,但在tracing模式下是张量。这种差异是必要的,以便在tracing过程中,形状计算可以作为图中的符号操作被捕获。在下一节关于泛化的示例中给出了一个例子。
由于返回类型不同,如果模型的部分假定形状是整数,则可能无法追踪。这通常可以通过在代码中处理两种类型来轻松修复。一个有用的函数是torch.jit.is_tracing,用于检查代码是否在tracing模式下执行。
这就是追踪的所有要求——最重要的是,模型实现中允许使用任何Python语法,因为tracing根本不关心语法。
泛化问题
仅“可追踪”是不够的。tracing的最大问题是它可能无法泛化到其他输入。在以下情况下会出现此问题
-
动态控制流:
>>> def f(x): ... return torch.sqrt(x) if x.sum() > 0 else torch.square(x) >>> m = torch.jit.trace(f, torch.tensor(3)) >>> print(m.code) def f(x: Tensor) -> Tensor: return torch.sqrt(x)
在此示例中,由于动态控制流,追踪仅保留了条件的一个分支,并且无法泛化到某些(负)输入。
-
将变量捕获为常量:
>>> a, b = torch.rand(1), torch.rand(2) >>> def f1(x): return torch.arange(x.shape[0]) >>> def f2(x): return torch.arange(len(x)) >>> # 看看两个追踪是否可以从a泛化到b: >>> torch.jit.trace(f1, a)(b) tensor([0, 1]) >>> torch.jit.trace(f2, a)(b) tensor([0]) # 错误! >>> # 为什么f2不能泛化?让我们比较它们的代码: >>> print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code) def f1(x: Tensor) -> Tensor: _0 = ops.prim.NumToTensor(torch.size(x, 0)) _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False) return _1 def f2(x: Tensor) -> Tensor: _0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False) return _0
非张量类型(在此情况下为int类型)的中间计算结果可能会被捕获为常量,使用追踪期间观察到的值。这导致追踪无法泛化。
除了len(),此问题还可能出现在:
- .item(),将张量转换为int/float。
- 任何将torch类型转换为numpy/Python原语的代码。
- 一些有问题的运算符,例如高级索引。
-
捕获设备:
>>> def f(x): ... return torch.arange(x.shape[0], device=x.device) >>> m = torch.jit.trace(f, torch.tensor([3])) >>> print(m.code) def f(x: Tensor) -> Tensor: _0 = ops.prim.NumToTensor(torch.size(x, 0)) _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False) return _1 >>> m(torch.tensor([3]).cuda()).device device(type='cpu') # 错误!
类似地,接受设备参数的运算符将记住追踪期间使用的设备(这可以在m.code中看到)。因此,追踪可能无法泛化到不同设备上的输入。几乎不需要这种泛化,因为部署通常有目标设备。
让追踪泛化 Let Tracing Generalize
上述问题令人烦恼且通常是静默的(有警告,但没有错误),但它们可以通过良好的实践和工具成功解决:
-
注意TracerWarning:在上述两个示例中,torch.jit.trace实际上会发出警告。第一个示例打印:
a.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if x.sum() > 0:
注意这些警告(或更好的是,捕获它们)将暴露tracing的大多数泛化问题。
请注意,“捕获设备”案例不会打印警告,因为tracing根本不支持这种泛化。
-
并行性单元测试( Unittests for parity):导出后和部署前应进行单元测试,以验证导出模型生成的输出与原始eager-mode模型的输出相同,即:
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
如果需要跨形状的泛化(并不总是需要),input2应具有不同于input1的形状。
Detectron2有许多泛化测试,例如此测试和此测试。一旦发现差距,检查导出TS-format模型的代码可以揭示它无法泛化的地方。
-
避免不必要的“特殊情况”条件:避免如下条件:
if x.numel() > 0: output = self.layers(x) else: output = torch.zeros((0, C, H, W)) # 创建空输出
这样处理特殊情况,如空输入。相反,改进self.layers或其底层内核,使其支持空输入。这将导致更清晰的代码,并且还会改进tracing。这就是为什么我参与了许多改进空输入支持的PyTorch问题,例如#12013、#36530、#56998。大多数PyTorch操作都能很好地处理空输入,因此很少需要这种分支。
-
使用符号形状:如前所述,tensor.size()在tracing期间返回Tensor,以便在图中捕获形状计算。用户应避免意外地将张量形状变为常量:
- 使用tensor.size(0)而不是len(tensor),因为后者是int。对于自定义类,实施.size方法或使用.len()而不是len(),例如像这里。
- 不要通过int()或torch.as_tensor转换大小,因为它们会捕获常量。此辅助函数有助于以在tracing和eager模式下均适用的方式将大小转换为张量。
-
混合tracing和scripting:它们可以混合使用,因此可以对tracing无法正常工作的少部分代码使用scripting。这样可以解决tracing的几乎所有问题。更多内容见下文。
混合Tracing和Scripting
Tracing和scripting都有各自的问题,最好的解决方案通常是将它们混合使用。这让我们可以获得两者的优点。
为了最大限度地减少对代码质量的负面影响,我们应在大多数逻辑上使用tracing,仅在必要时使用scripting。
-
使用@script_if_tracing:在torch.jit.trace内部,@script_if_tracing装饰器可以通过scripting编译函数。通常,这只需要对forward逻辑进行一些小的重构,以分离需要编译的部分(使用控制流的部分):
def forward(self, ...): # ... 一些forward逻辑 @torch.jit.script_if_tracing def _inner_impl(x, y, z, flag: bool): # 使用控制流等 return ... output = _inner_impl(x, y, z, flag) # ... 其他forward逻辑
通过仅编译需要的部分,代码质量损害比使整个模型scriptable要小得多,并且完全不影响模块的forward接口。
由@script_if_tracing装饰的函数必须是纯函数,不包含模块。因此,有时需要进行更多重构:
事实上,对于大多数视觉模型,动态控制流仅在少数子模块中需要,因此很容易scriptable。为了显示其需求的稀缺性,整个detectron2仅有两个函数由于控制流被@script_if_tracing装饰:paste_masks和heatmaps_to_keypoints,均用于后处理。其他一些函数也被装饰以跨设备泛化(非常罕见的需求)。
-
使用scripted/追踪子模块:
model.submodule = torch.jit.script(model.submodule) torch.jit.trace(model, inputs)
在此示例中,假设子模块无法正确追踪,我们可以在tracing之前对其进行scripting。然而,我不推荐这样做。如果可能,我建议在子模块的forward内部使用@script_if_tracing,使scripting仅限于子模块的内部,而不影响模块的接口。
同样,
model.submodule = torch.jit.trace(model.submodule, submodule_inputs) torch.jit.script(model)
这在scripting期间使用追踪子模块。这看起来不错,但在实践中用处不大:它会影响子模块的接口,要求它只能接受/返回Tuple[Tensor]——这是一个大约束,可能比scripting更损害代码质量。
一个罕见的“追踪子模块”有用场景是:
class A(nn.Module): def forward(self, x): # 根据动态、数据依赖条件分派到不同子模块: return self.submodule1(x) if x.sum() > 0 else self.submodule2(x)
@script_if_tracing无法编译这样的控制流,因为它只支持纯函数。如果子模块{1,2}复杂且无法scripting,使用追踪子模块在scripting父类A是最佳选择。
-
合并多个追踪:
Scripted模型支持两种追踪模型不支持的功能:
-
基于属性的控制流:一个scripted模块可以有可变属性(例如,一个布尔标志)影响控制流。追踪模块没有控制流。
-
多个方法:一个追踪模块仅支持forward(),但一个scripted模块可以有多个方法。 实际上,上述两种功能都是做同一件事:它们允许导出模型以不同方式使用,即根据调用者的要求执行不同的操作序列。
以下是一个示例场景,如果Detector被scripted,调用者可以修改其do_keypoint属性来控制其行为,或在需要时直接调用predict_keypoint方法。
class Detector(nn.Module): do_keypoint: bool def forward(self, img): box = self.predict_boxes(img) if self.do_keypoint: kpts = self.predict_keypoint(img, box) @torch.jit.export def predict_boxes(self, img): pass @torch.jit.export def predict_keypoint(self, img, box): pass
这种需求并不常见。但如果需要,如何在tracing中实现?我有一个不太干净的解决方案:
tracing只能捕获一个操作序列,所以自然的方法是追踪模型两次:
det1 = torch.jit.trace(Detector(do_keypoint=True), inputs) det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
然后我们可以别名它们的权重(不重复存储),并将两个追踪合并为一个模块进行scripting。
det2.submodule.weight = det1.submodule.weight class Wrapper(nn.ModuleList): def forward(self, img, do_keypoint: bool): if do_keypoint: return self[0](img) else: return self[1](img) exported = torch.jit.script(Wrapper([det1, det2]))
性能
如果一个模型既可追踪又可scriptable,那么tracing总是生成相同或更简单的图(因此可能更快)。
为什么?因为scripting试图忠实地表示你的Python代码,即使其中一些是多余的。例如:它并不总是足够智能,无法意识到Python代码中的某些循环或数据结构实际上是静态的,可以移除:
class A(nn.Module):
def forward(self, x1, x2, x3):
z = [0, 1, 2]
xs = [x1, x2, x3]
for k in z: x1 += xs[k]
return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
# z = [0, 1, 2]
# xs = [x1, x2, x3]
# x10 = x1
# for _0 in range(torch.len(z)):
# k = z[_0]
# x10 = torch.add_(x10, xs[k])
# return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
# x10 = torch.add_(x1, x1)
# x11 = torch.add_(x10, x2)
# return torch.add_(x11, x3)
这个例子非常简单,所以它实际上有scripting的解决方法(使用元组而不是列表),或者循环可能会在后续优化过程中得到优化。但重点是:图编译器并不总是足够智能。对于复杂模型,scripting可能会生成带有不必要复杂性的图,难以优化。
总结
tracing有明显的局限性:本文大部分内容都是在讨论tracing的局限性以及如何解决它们。我实际上认为这是tracing的优势:它有明显的局限性(和解决方案),所以你可以推理它是否可行。
相反,scripting更像是一个黑匣子:在尝试之前没有人知道它是否可行。我没有提到任何修复scripting的技巧:它们有很多,但不值得你花时间去探究和修复一个黑匣子。
tracing的影响范围小:tracing和scripting都影响代码的编写方式,但tracing的影响范围更小,造成的损害也更少:
- 它限制了输入/输出格式,但仅限于最外层的模块。(而且这个问题可以自动解决,如上所述。)
- 它需要一些代码更改以泛化(例如,在tracing中混合scripting),但这些更改仅进入受影响模块的内部实现,而不影响它们的接口。
另一方面,scripting影响:
- 涉及的每个模块和子模块的接口。
- 在我看来,这是最大的损害:接口中需要高级语法特性,我不愿意在接口设计上妥协。
- 这也可能影响训练,因为接口通常在训练和推理之间共享。 推理forward路径中的每一行代码。
影响范围大的原因是scripting可能会严重损害代码质量。
控制流与其他Python语法:PyTorch受到用户喜爱,因为他们可以“只写Python”,最重要的是写Python控制流。但Python的其他语法也很重要。如果能够写Python控制流(scripting)意味着失去其他优秀语法,我宁愿放弃写Python控制流的能力。
事实上,如果PyTorch对Python控制流不那么执着,并为我提供类似于tf.cond的符号控制流(例如torch.cond),我会很高兴使用这个,不再担心scripting:
def f(x):
return torch.cond(x.sum() > 0, lambda: torch.sqrt(x), lambda: torch.square(x))
那么f可以被正确追踪,我会很高兴使用这个,不再担心scripting。TensorFlow AutoGraph是一个自动化这一想法的好例子。
但显然不符合torch易用的设计原则