这里分析几个项目的注册机制的设计:
torch-tensorrt
这里分析torchscript-IR转换为TensorRT-op的converter的注册机制。
使用方法
通过get_node_converter_for这个函数来查询全局的注册converter:
...
auto schema = n->maybeSchema();
TORCHTRT_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) << " (conversion.AddLayer)");
auto converter = converters::get_node_converter_for(schema);
TORCHTRT_CHECK(
converter,
"Unable to convert node: "
<< util::node_info(n) << " (conversion.AddLayer)\nSchema: " << *schema << "\nConverter for " << schema->name()
<< " requested, but no such converter was found.\nIf you need a converter for this operator, you can try implementing one yourself\n"
<< "or request a converter: https://www.github.com/NVIDIA/Torch-TensorRT/issues");
...
get_node_converter_for的函数实现如下:
OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature) {
return get_converter_registry().GetConverter(signature);
}
具体实现
这个部分是关于 PyTorch 和 TensorRT 之间模型转换的一部分。具体地说,它通过定义和注册转换模式(Conversion Patterns)来实现 PyTorch 计算图节点(Node)到 TensorRT 层的转换。
在给定的代码中,具体的注册机制实现细节并没有完全展示,但从给出的代码结构和函数签名来看,这种注册机制通常是基于以下几个关键组件:
-
存储转换函数:通常,你会有一个全局或静态的数据结构(通常是一个哈希表或字典),用于存储已注册的转换函数。这个数据结构的键可能是操作的签名或名字,值是对应的转换函数。
std::unordered_map<std::string, OpConverter> converter_map;
-
注册函数:
register_node_converter
函数将新的转换函数添加到数据结构中。例如:void register_node_converter(std::string signature, OpConverter& converter) { converter_map[signature] = converter; }
-
查询函数:
get_node_converter_for
等函数用于查询给定操作的转换函数。OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature) { // 这里简化了,实际实现可能更复杂 return converter_map[signature->name()]; }
-
使用转换函数:在需要进行节点转换时,可以使用查询函数找到合适的转换函数,并调用它。
if (node_is_convertable(node)) { auto converter = get_node_converter_for(node->schema()); // 调用 converter }
-
自动注册:最后的
auto cat_registrations
部分使用了RegisterNodeConversionPatterns
类来自动注册一个特定的转换模式。这是一种用于自动执行注册的常用技术。在这里,pattern
函数接收一个ConversionPattern
,并在其内部调用register_node_converter
。RegisterNodeConversionPatterns&& pattern(ConversionPattern p) && { register_node_converter(p.signature, p.converter); return std::move(*this); }
这样的注册机制使得代码更加模块化和可扩展,允许用户或库开发者轻松地添加新的转换函数或修改现有函数。
这只是一种可能的实现方法,具体的实现可能会有所不同。但基本的模式——使用数据结构存储转换函数,通过注册函数添加新函数,通过查询函数获取函数——通常是一致的。
主要组件
-
类型定义:定义了两种类型
args
和OpConverter
。args
是一个包含变量(可能是Var
类型)的向量。OpConverter
是一个函数对象,接受转换上下文、一个 PyTorch JIT 节点和args
参数,返回一个布尔值表示转换是否成功。 -
ConversionPattern 结构体:定义了转换模式,包含一个字符串
signature
和一个OpConverter
对象。 -
注册函数:提供了几种不同的注册函数(
register_node_converter
)以用于注册转换模式。 -
RegisterNodeConversionPatterns 类:用于构建并注册转换模式。它具有一个
pattern
成员函数,该函数接受一个ConversionPattern
对象并可能返回该类的右值引用。 -
查询函数:如
node_is_convertable
和get_node_converter_for
,用于查询给定的节点是否可转换以及获取对应的转换函数。
自动注册
代码的最后一部分使用 RegisterNodeConversionPatterns
类来自动注册一个转换模式。这里,它为 PyTorch 的 aten::cat
操作注册了一个转换函数。
该转换函数的逻辑如下:
-
获取输入参数:从
args
中获取输入张量和维度。 -
处理张量:对输入张量进行一系列处理,包括类型提升(
promote_types
)和类型转换(castITensor
)。 -
添加 Concatenation 层:在 TensorRT 网络中添加一个 Concatenation 层,并设置其轴(axis)。
-
输出关联:将 TensorRT 输出张量与 PyTorch 计算图节点的输出关联。
-
记录调试信息:记录输出张量的形状。
总结
这段代码实现了一个灵活和可扩展的注册机制,允许用户为特定的 PyTorch 计算图节点(如 aten::cat
)定义定制的转换函数。这种机制非常有用,因为它使得 PyTorch 到 TensorRT 的模型转换更加模块化和可维护。用户可以轻松地添加对新操作的支持或修改现有操作的转换逻辑。
既然是注册机制,必须要有
在 GetConverter
函数中,使用了 std::unordered_map
作为转换器查找表(LUT, Lookup Table):
auto iter = converter_lut_.find(name);
if (iter == converter_lut_.end()) {
// ...
}
return iter->second;
std::unordered_map
是一个哈希表实现,其查找操作的时间复杂度通常是 O(1)。但这取决于多个因素,包括哈希函数的质量和哈希表的负载因子等。在最坏的情况下,时间复杂度可能达到 O(n),但这通常是极少见的。
从给出的代码来看,GetConverter
函数应该具有很高的性能:
- 哈希查找: 使用
std::unordered_map
进行快速的哈希查找。 - 早期返回: 如果找不到对应的转换器,函数会尽早返回,不会进行无用的计算。
关于代码中的注释 // ASK: Is there a better way than returning a nullptr?
:
返回 nullptr
是一种有效的方式来表示“没有找到对应的转换器”。但这也意味着调用方需要检查返回值是否为 nullptr
。如果忽略这一点,可能会导致运行时错误。
其他可能的方法包括:
- 抛出异常: 当转换器不存在时,可以抛出一个特定类型的异常。这样,调用方就会被迫处理这种异常情况。
- 返回一个可选类型: 如
std::optional<OpConverter>
,这样调用方必须明确检查是否有值。 - 返回一个空的/默认的转换器: 这样可以避免返回
nullptr
,但可能会引入其他问题,如需要额外的逻辑来处理这种特殊情况。
选择哪一种方式取决于你的具体需求和设计理念。