CPP代码规范之注册机制

这里分析几个项目的注册机制的设计:

torch-tensorrt

这里分析torchscript-IR转换为TensorRT-op的converter的注册机制。

使用方法

通过get_node_converter_for这个函数来查询全局的注册converter:

...
  auto schema = n->maybeSchema();
  TORCHTRT_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) << " (conversion.AddLayer)");

  auto converter = converters::get_node_converter_for(schema);
  TORCHTRT_CHECK(
      converter,
      "Unable to convert node: "
          << util::node_info(n) << " (conversion.AddLayer)\nSchema: " << *schema << "\nConverter for " << schema->name()
          << " requested, but no such converter was found.\nIf you need a converter for this operator, you can try implementing one yourself\n"
          << "or request a converter: https://www.github.com/NVIDIA/Torch-TensorRT/issues");
...

get_node_converter_for的函数实现如下:

OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature) {
  return get_converter_registry().GetConverter(signature);
}

具体实现

这个部分是关于 PyTorch 和 TensorRT 之间模型转换的一部分。具体地说,它通过定义和注册转换模式(Conversion Patterns)来实现 PyTorch 计算图节点(Node)到 TensorRT 层的转换。

在给定的代码中,具体的注册机制实现细节并没有完全展示,但从给出的代码结构和函数签名来看,这种注册机制通常是基于以下几个关键组件:

  1. 存储转换函数:通常,你会有一个全局或静态的数据结构(通常是一个哈希表或字典),用于存储已注册的转换函数。这个数据结构的键可能是操作的签名或名字,值是对应的转换函数。

    std::unordered_map<std::string, OpConverter> converter_map;
    
  2. 注册函数register_node_converter 函数将新的转换函数添加到数据结构中。例如:

    void register_node_converter(std::string signature, OpConverter& converter) {
      converter_map[signature] = converter;
    }
    
  3. 查询函数get_node_converter_for 等函数用于查询给定操作的转换函数。

    OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature) {
      // 这里简化了,实际实现可能更复杂
      return converter_map[signature->name()];
    }
    
  4. 使用转换函数:在需要进行节点转换时,可以使用查询函数找到合适的转换函数,并调用它。

    if (node_is_convertable(node)) {
      auto converter = get_node_converter_for(node->schema());
      // 调用 converter
    }
    
  5. 自动注册:最后的 auto cat_registrations 部分使用了 RegisterNodeConversionPatterns 类来自动注册一个特定的转换模式。这是一种用于自动执行注册的常用技术。在这里,pattern 函数接收一个 ConversionPattern,并在其内部调用 register_node_converter

    RegisterNodeConversionPatterns&& pattern(ConversionPattern p) && {
      register_node_converter(p.signature, p.converter);
      return std::move(*this);
    }
    

这样的注册机制使得代码更加模块化和可扩展,允许用户或库开发者轻松地添加新的转换函数或修改现有函数。

这只是一种可能的实现方法,具体的实现可能会有所不同。但基本的模式——使用数据结构存储转换函数,通过注册函数添加新函数,通过查询函数获取函数——通常是一致的。

主要组件

  1. 类型定义:定义了两种类型 argsOpConverterargs 是一个包含变量(可能是 Var 类型)的向量。OpConverter 是一个函数对象,接受转换上下文、一个 PyTorch JIT 节点和 args 参数,返回一个布尔值表示转换是否成功。

  2. ConversionPattern 结构体:定义了转换模式,包含一个字符串 signature 和一个 OpConverter 对象。

  3. 注册函数:提供了几种不同的注册函数(register_node_converter)以用于注册转换模式。

  4. RegisterNodeConversionPatterns 类:用于构建并注册转换模式。它具有一个 pattern 成员函数,该函数接受一个 ConversionPattern 对象并可能返回该类的右值引用。

  5. 查询函数:如 node_is_convertableget_node_converter_for,用于查询给定的节点是否可转换以及获取对应的转换函数。

自动注册

代码的最后一部分使用 RegisterNodeConversionPatterns 类来自动注册一个转换模式。这里,它为 PyTorch 的 aten::cat 操作注册了一个转换函数。

该转换函数的逻辑如下:

  1. 获取输入参数:从 args 中获取输入张量和维度。

  2. 处理张量:对输入张量进行一系列处理,包括类型提升(promote_types)和类型转换(castITensor)。

  3. 添加 Concatenation 层:在 TensorRT 网络中添加一个 Concatenation 层,并设置其轴(axis)。

  4. 输出关联:将 TensorRT 输出张量与 PyTorch 计算图节点的输出关联。

  5. 记录调试信息:记录输出张量的形状。

总结

这段代码实现了一个灵活和可扩展的注册机制,允许用户为特定的 PyTorch 计算图节点(如 aten::cat)定义定制的转换函数。这种机制非常有用,因为它使得 PyTorch 到 TensorRT 的模型转换更加模块化和可维护。用户可以轻松地添加对新操作的支持或修改现有操作的转换逻辑。

既然是注册机制,必须要有

GetConverter 函数中,使用了 std::unordered_map 作为转换器查找表(LUT, Lookup Table):

auto iter = converter_lut_.find(name);
if (iter == converter_lut_.end()) {
  // ...
}
return iter->second;

std::unordered_map 是一个哈希表实现,其查找操作的时间复杂度通常是 O(1)。但这取决于多个因素,包括哈希函数的质量和哈希表的负载因子等。在最坏的情况下,时间复杂度可能达到 O(n),但这通常是极少见的。

从给出的代码来看,GetConverter 函数应该具有很高的性能:

  1. 哈希查找: 使用 std::unordered_map 进行快速的哈希查找。
  2. 早期返回: 如果找不到对应的转换器,函数会尽早返回,不会进行无用的计算。

关于代码中的注释 // ASK: Is there a better way than returning a nullptr?

返回 nullptr 是一种有效的方式来表示“没有找到对应的转换器”。但这也意味着调用方需要检查返回值是否为 nullptr。如果忽略这一点,可能会导致运行时错误。

其他可能的方法包括:

  1. 抛出异常: 当转换器不存在时,可以抛出一个特定类型的异常。这样,调用方就会被迫处理这种异常情况。
  2. 返回一个可选类型: 如 std::optional<OpConverter>,这样调用方必须明确检查是否有值。
  3. 返回一个空的/默认的转换器: 这样可以避免返回 nullptr,但可能会引入其他问题,如需要额外的逻辑来处理这种特殊情况。

选择哪一种方式取决于你的具体需求和设计理念。