Pytorch模型加速系列(一)——新的Torch-TensorRT以及TorchScript/FX/dynamo

2023年真是令人兴奋的一年啊,AI技术层出不穷,前有Stable Diffusion,后有大语言模型。这些技术已经下放到生活中老长一段时间了,比如妙鸭相机和chatgpt,已经达到可以变现的程度,火爆程度无需多说。

除了这些新模型的快速迭代,Pytorch也升级到了2.0,可以使用一行代码提速你的模型:torch.compile。而作为compile函数中重要部分的TorchDynamo,也是2.0的重点:这个新的trace计算图的方式。

因此之前聊过的一些操作:

都会随之变化。

本系列文章主要内容,会讲述在新的dynamo工具出现后,Pytorch依赖TensorRT加速模型的方法有哪些改变、如何使用和一些个人看法等等。

因为TensorRT也马上出来9.0了,Pytorch2.0和Torch-TensorRT也在不停的更新迭代中,所以文章中有些内容之后可能随时会过时,关于过时的部分会有说明

Pytorch to tensorrt 一路走来

首先我们理一下关系。

2021年,Pytorch在1.8版本推出一个叫做FX的编译器,可以做 python-to-python code transformation,我们可以用这个工具去做很多事情,例如优化模型(fuse op)或者对模型做量化(PTQ、QAT),我也写过几篇相关的文章:

其中上述优化或者量化模型的第一个步骤就是先用FX中的trace工具去trace你的模型,然后再blablabla。第三篇和转TensorRT也是利用了fx的trace,将模型按照FX中的IR一个一个去构建trt的网络,这个过程使用了fx2trt这个库(这个库通过FX可以将pytorch模型转换为TensorRT,类似于onnx2trt,这个代码之前是在pytorch主仓库中进行维护)。

后来,写第三篇FX2TRT的时候,fx2trt的仓库分支由之前的Pytorch仓库合并到了pytorch/TensorRT仓库中,作为了Torch-TensorRT的一部分。原本Torch-TensorRT只支持走torchscript这条路线,因为fx2trt的迁移,现在又多了一条FX路线到TensorRT:

但又在去年12月份的时候,Pytorch更新2.0,发布了dynamo工具。TorchDynamo的作用是从 PyTorch 应用中抓取计算图 ,相比于TorchScript 和 TorchFX,TorchDynamo更加灵活、可靠性更高。于是Torch-TensorRT也同样开始支持dynamo的trace模式:

不过随着torchdynamo的快速发展,以及Pytorch对torchscript不再进行功能迭代(目前torchscript还未被废弃,但不进行功能维护了),同时fx的trace也有很多限制,最近Torch-TensorRT的模型trace路径已经从torchscript切换为dynamo

所以目前最新,也是未来我们需要学习的pytorch to TensorRT路径,就是torchdynamo+FX的路径。

因为工作中需要利用TensorRT去部署,而TensorRT也不可能支持所有模型,我们也没有必要必须把整个模型都使用TensorRT去构建,另外对于快速的模型开发迭代上线,纯使用TensorRT的路径,去手写那些不支持的op开发量很大,

因此,就需要Torch-TensorRT这样的编译器去快速支持TensorRT无法完全cover住的模型,不支持的子图拆分出来,支持的子图使用TensorRT编译,主打的就是一个兼容性强。

本系列文章会借着Torch-TensorRT这个工具以及torchdynamo去做一些实际的模型部署例子,当然核心后端还是trt,但也会尝试一些新的后端(dynamo出来很多后端都可以进行适配,比如triton-openai)。包含一些常见的使用场景,转trt、python runtime、C++ runtime、量化等等。

dynamo/fx/torchscript(graph capture)使用场景

官方提到,这三类相关人员需要了解graph capture:

  • 芯片设计
  • 模型加速开发相关人员
  • 编译器开发人员

当然这也可以映射到我们实际的使用场景中,除了torch2trt这个路径,它还能干很多事。对于我来说,我需要使用这个工具去分析模型,从而做一些有意思的事儿:

  • profiling模型
  • trace模型查看结构、shape信息
  • trace模型后修改模型结构、执行自定义pass操作,比如量化模型

同样,如果你有类似需求,最简单的,你想利用TensorRT加速你的Pytorch模型,除了onnx2trt,这条新路子你也可以试一试。

TorchDynamo and FX Graphs

torchdynamo作为核心,可以拦截我们的Python代码执行,并将其转换为FX中间表示(IR),并将其存储在称为FX图的特殊数据结构中。差不多就是这个样子:

具体细节的话就比较多了,推荐一篇大佬的文章:

对细节感兴趣的也可以看看官方的介绍

One important component of torch.compile is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into FX graphs, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.

这里暂时不详细展开讲dynamo了,我们接下来要说的Torch2TRT这个路径,属于下面pt2路径图中,绿色方框的others:

Accelerating Inference in PyTorch 2.0 with TensorRT

使用Pytorch 2.0转trt的优点:

  • torch.compile 一句就可以转
  • Python based converter library and workflow for easy extensions and customization 转换代码基本都是python写的,想要写个转换算子也比较简单
  • Deployment in Python or C++ via TorchScript 后续可以通过torchscript进行c++部署

使用torchdynamo路径的转TensorRT的流程是啥,我举一个例子,这里仅仅是简单演示,更详细的介绍会在之后的系列中。

首先将model传入,注意看这里的ir是torch_compile,这也是新版torch-TensorRT中默认的IR。这个函数的作用就是将一个Pytorch module转换为一个包含TensorRT engine的Pytorch module(这个module内部的实现已经替换为了trt)。

# Build and compile the model with torch.compile, using Torch-TensorRT backend
optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
)

这个torch_tensorrt.compile是torch-tensorrt的前端用法,其实这样写也是可以的:optimized_model = torch.compile(model, backend="torch_tensorrt", options={"enabled_precisions": enabled_precisions, ...}); optimized_model(*inputs)

看下就知道了,我们进入torch_tensorrt.compile函数内部,根据ir的设置,接下来走的是torch_compile分支:

    elif target_ir == _IRType.torch_compile:
        return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)

torch_complie函数的实看一下现比较简单,核心就是torch_tensorrt_backend,然后内部调用的也就是刚刚说的torch.compile,torch_tensorrt.compile只是包了一层前端:

def torch_compile(module, **kwargs):
    """
    Returns a boxed model which is the output of torch.compile.
    This does not compile the model to TRT. Execute this model on
    sample inputs to compile the model to TRT.
    """
    from torch_tensorrt.dynamo.backend import torch_tensorrt_backend

    boxed_fn = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs})

    return boxed_fn

这个backend是dynamo中的概念,torchdynamo可以从pytorch模型中抓取计算图( TorchDynamo Graph),然后backend可以拿计算图去进行相应的优化,这里的backend就是走TensorRT路径优化的后端。

TorchDynamo has a growing list of backends, which can be found in the backends folder or torch._dynamo.list_backends() each of which with its optional dependencies.

我们看下torch_tensorrt_backend的实现,又套了一个aot_torch_tensorrt_aten_backend

@td.register_backend(name="torch_tensorrt")
def torch_tensorrt_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
):
    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)

接着看aot_torch_tensorrt_aten_backend的实现,在拿到dynamo返回的计算图后,调用AOTAutograd将计算图中的torch IR转化为Aten IR,随后再将包含Aten IR的FX计算图转换为TensorRT的形式,这也是PT2.0的新路径。

@td.register_backend(name="aot_torch_tensorrt_aten")
def aot_torch_tensorrt_aten_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
):
    settings = parse_dynamo_kwargs(kwargs)

    custom_backend = partial(
        _pretraced_backend,
        settings=settings,
    )

    # Perform Pre-AOT Lowering for Module-Level Replacement
    gm = pre_aot_substitutions(gm)

    # Invoke AOTAutograd to translate operators to aten
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=make_boxed_compiler(custom_backend),
        decompositions=get_decompositions(),
    )

aot_torch_tensorrt_aten_backend函数中,转换trt的代码在_pretraced_backend里头,包的比较深,从函数名称和介绍可以得知这是个helper function会在转换trt失败后返回啥也没变的GraphModule forward。

def _pretraced_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
):
    """Helper function to manage translation of traced FX module to TRT engines

    Args:
        module: FX GraphModule to convert
        inputs: Inputs to the module
        settings: Compilation settings
    Returns:
        Compiled FX GraphModule
    """
    try:
        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

        trt_compiled = _compile_module(
            gm,
            sample_inputs,
            settings=settings,
        )
        return trt_compiled
    except:
        if not settings.pass_through_build_failures:
            logger.warning(
                "TRT conversion failed on the subgraph. See trace above. "
                + "Returning GraphModule forward instead.",
                exc_info=True,
            )
            return gm.forward
        else:
            logger.critical(
                "Halting compilation on build failure since "
                + "pass_through_build_failures was specified as True. "
                + "To return the default Torch implementation and avoid "
                + "halting compilation on engine build failures, "
                + "specify pass_through_build_failures=False."
            )
            raise

_pretraced_backend只是个helper函数,其会调用_compile_module函数,如下面的代码所示,这个函数会将trace后的FX模型按照设置以及现有支持的op将模型拆分成多个部分,完全被trt支持的模型将被转化为tensorrt:


def _compile_module(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
    """Compile a traced FX module

    Includes: Partitioning + Conversion Phases

    Args:
        module: FX GraphModule to convert
        inputs: Inputs to the module
        settings: Compilation settings
    Returns:
        Compiled FX GraphModule
    """
    # Partition module into components that can be TRT-accelerated
    partitioned_module = partition(
        gm,
        verbose=settings.debug,
        min_block_size=settings.min_block_size,
        torch_executed_ops=settings.torch_executed_ops,
    )

    # Store TRT replicas of Torch subgraphs
    trt_modules = {}

    # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those
    for name, _ in partitioned_module.named_children():
        submodule = getattr(partitioned_module, name)

        # Get submodule inputs
        submodule_inputs = get_submod_inputs(
            partitioned_module, submodule, sample_inputs
        )

        # Handle long/double inputs if requested by the user
        if settings.truncate_long_and_double:
            submodule_inputs = repair_long_or_double_inputs(
                partitioned_module, submodule, submodule_inputs, name
            )

        # Create TRT Module from submodule
        trt_mod = convert_module(
            submodule,
            submodule_inputs,
            settings=settings,
            name=name,
        )

        trt_modules[name] = trt_mod

    # Replace all FX Modules with TRT Modules
    for name, trt_mod in trt_modules.items():
        setattr(partitioned_module, name, trt_mod)

    return partitioned_module

核心实现在convert_module中,这个实现和之前的FX的路径一样,创建一个TRTInterpreter解释器(torch.fx中的概念),在run的时候遍历每一个op实现pytorch-op(aten op)到 trt op的转换,这里的转换过程也就是利用TensorRT的python api去搭建TensorRT network的过程,待搭建完成后执行build即可。

def convert_module(
    module: torch.fx.GraphModule,
    inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
    name: str = "",
):
    """Convert an FX module to a TRT module
    Args:
        module: FX GraphModule to convert
        inputs: Sequence of Tensors representing inputs to the module
        settings: Compilation settings
        name: TRT engine name
    Returns:
        _PythonTorchTRTModule or TorchTensorRTModule
    """
    # Specify module output data types to ensure TRT output types agree with
    # that of the equivalent Torch module
    module_outputs = module(*inputs)

    if not isinstance(module_outputs, (list, tuple)):
        module_outputs = [module_outputs]

    output_dtypes = list(output.dtype for output in module_outputs)
    interpreter = TRTInterpreter(
        module,
        Input.from_tensors(inputs, disable_memory_format_check=True),
        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
        output_dtypes=output_dtypes,
    )
    interpreter_result = interpreter.run(
        workspace_size=settings.workspace_size,
        precision=settings.precision,
        profiling_verbosity=(
            trt.ProfilingVerbosity.VERBOSE
            if settings.debug
            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
        ),
        max_aux_streams=settings.max_aux_streams,
        version_compatible=settings.version_compatible,
        optimization_level=settings.optimization_level,
    )

    if settings.use_python_runtime:
        return _PythonTorchTRTModule(
            engine=interpreter_result.engine,
            input_names=interpreter_result.input_names,
            output_names=interpreter_result.output_names,
        )

    else:
        from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

        with io.BytesIO() as engine_bytes:
            engine_bytes.write(interpreter_result.engine.serialize())
            engine_str = engine_bytes.getvalue()
        return TorchTensorRTModule(
            serialized_engine=engine_str,
            name=name,
            input_binding_names=interpreter_result.input_names,
            output_binding_names=interpreter_result.output_names,
        )

到了这里就整体流程就over了,流程图如下:

两种OP

Core aten ops是aten操作符的核心子集,可用于组合其他操作符。Core aten IR是fully functional,并且在此opset中没有inplace或_out变体。与Prims IR相比,core aten ops重用了“native_functions.yaml”中现有的aten ops,并且不会进一步将ops分解为explicit type promotion and broadcasting ops。该opset旨在作为与后端接口交互的功能性IR。

Aten Op是主要转换时使用的op ir。

Prims IR是一组primitive operators,可用于组合其他运算符。 Prims IR是比核心aten IR更低级的操作集,它进一步将操作分解为explicit type promotion and broadcasting ops:prims.convert_element_type和prims.broadcast_in_dim。 这个操作集旨在与编译器后端进行接口交互。

Prims IR主要是和编译器,硬件优化交互的,torch2TensorRT还没到这么底层。

部分参考:

C++路线相关信息

既然都转成TensorRT了,当然性能是第一考虑的,转成的TensorRT最好也可以直接嵌入到C++ runtime中,比如这样:

  • Get an model in Aten IR via tracing or dynamo (torch.export)
  • Compile using torch_tensorrt FX frontend
  • Target the Torch-TensorRT runtime instead of purepython runtime (Compiled module is savable using traditional PyTorch methods,Torch-TRT runtime modules are TorchScript traceable in orderto export to non Python environments)
  • Loadable just like TorchScript modules

这是新版torch-TensorRT的runtime方式,可以无缝迁移到C++中。不过这种方式和pt2中要推的torch.export是两码事,感兴趣可以看这个讨论:

TRTModuleNext

实现fx2trt后,可以包装成C++导出来的核心是TRTModuleNext,不过这个后来改名成TorchTensorRTModule了,不过核心没有变:

The FX frontend will return a torch.nn.Module containing torch_tensorrt.TRTModuleNext submodules instead of torch_tensorrt.fx.TRTModules. The features of these modules are nearly identical but with a few key improvements.

  1. TRTModuleNext profiling dumps a trace visualizable with Perfetto (see above for more details).
  2. TRTModuleNext modules are torch.jit.trace-able, meaning you can save FX compiled modules as TorchScript for python-less / C++ deployment scenarios. Traced compiled modules have the same deployment instructions as compiled modules produced by the TorchScript frontend.
  3. TRTModuleNext maintains the same serialization workflows TRTModule supports as well (state_dict / extra_state, torch.save/torch.load)

这个的出现可以使从FX路径生成的torch-trt模块也可以通过torchscript的方式导出,从而在C++中调用。

改名后的TorchTensorRTModule实现也可以看下,后续文章中会详细讲:

class TorchTensorRTModule(torch.nn.Module):
    """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.

    This module is backed by the Torch-TensorRT runtime and is fully compatibile with both
    FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
    well as TorchScript / C++ deployments since TorchTensorRTModule can be passed to ``torch.jit.trace``
    and then saved.

    The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
    the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``

    > Note: TorchTensorRTModule only supports engines built with explict batch

    Attributes:
        name (str): Name of module (for easier debugging)
        engine (torch.classess.tensorrt.Engine): Torch-TensorRT TensorRT Engine instance, manages [de]serialization, device configuration, profiling
        input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
        output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
    """

    def __init__(
        self,
        serialized_engine: bytearray = bytearray(),
        name: str = "",
        input_binding_names: List[str] = [],
        output_binding_names: List[str] = [],
        target_device: Device = Device._current_device(),
    ):
        """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule

        Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
        a PyTorch ``torch.nn.Module`` around it.

        If binding names are not provided, it is assumed that the engine binding names follow the following convention:

            - [symbol].[index in input / output array]
                - ex. [x.0, x.1, x.2] -> [y.0]

        Args:
            name (str): Name for module
            serialized_engine (bytearray): Serialized TensorRT engine in the form of a bytearray
            input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
            output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
            target_device: (torch_tensorrt.Device): Device to instantiate TensorRT engine on. Must be a compatible device i.e. same GPU model / compute capability as was used to build the engine

        Example:

            ..code-block:: py

                with io.BytesIO() as engine_bytes:
                    engine_bytes.write(trt_engine.serialize())
                    engine_str = engine_bytes.getvalue()

                trt_module = TorchTensorRTModule(
                    engine_str,
                    name="my_module",
                    input_binding_names=["x"],
                    output_binding_names=["output"],
                )

        """
        super(TorchTensorRTModule, self).__init__()

编译Torch-TensorRT

因为Torch-TensorRT变动很频繁,这里建议自己手动编译最新的Torch-TensorRT,既可以在官方更新后直接拉,也可以避免一些环境不匹配的问题。

官方默认的编译方式是bazel。因为我不喜欢用bazel,所以这里拆开编译:

  • 通过cmake编译bin文件
  • 通过setup.py编译python前端

cmake编译bin文件

执行命令:

cmake .. -DCMAKE_PREFIX_PATH=/server/convert/pytorch/torch/lib -DTensorRT_ROOT=/data/software/TensorRT-8.6.1.6 

libtorch的话,直接使用官方的会报缺失头文件的问题,所以这里干脆自己编译了Pytorch。

编译后有三个lib以及torchtrtc可执行文件:

tree
.
├── libtorchtrt_plugins.so
├── libtorchtrt_runtime.so
└── libtorchtrt.so

接下来编译python前端。

python前端

先说个坑,如果我们的Pytorch是通过pip安装的(pytorch.org):


pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

这个版本默认是old pre-cxx11 ABI,则需要刚才编译的lib也是old pre-cxx11 ABI才行,这个需要匹配。因为我这里是自行编译的Pytorch,所以我这边都是cxx11 ABI版本。

回到正题,进入/TensorRT/py目录,修改下setup.py文件(因为我这里没有使用bazel这种方式,而setup.py默认代码是搭配的bazel),注释掉相关调用bazel编译的代码:

class DevelopCommand(develop):
   ...
    def run(self):
        if FX_ONLY:
            gen_version_file()
            develop.run(self)
        else:
            global CXX11_ABI
            # build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI)
            gen_version_file()
            # copy_libtorchtrt()
            develop.run(self)

然后执行python setup.py develop,编译好后将上小节中得到libtorchtrt.so移动到TensorRT/py/torch_tensorrt/lib中,即可(如果目录不存在就创建一个)。

这样就安装最新版的torch-tensorrt了:

torch-tensorrt     2.0.0.dev0+65277c52 /home/oldpan/code/convert/TensorRT/py

未完待续 TODO

下篇中将介绍使用Torch-TensorRT:

  • 编译一个实际的模型,转换细节
  • 导出并且部署到C++中
  • 一些细节

关于dynamo、trace相关的知识也会写一些。

参考