TensorRT plugins: the Efficient NMS Plugin

A fairly efficient official NMS plugin that supports batched input and multi-class NMS.

Inputs

The plugin accepts two input layouts and picks the mode from the number of inputs it receives:

    1. Standard NMS Mode: Only two input tensors are given, (i) the bounding box coordinates and (ii) the corresponding classification scores for each box.
    2. Fused Box Decoder Mode: Three input tensors are given, (i) the raw localization predictions for each box originating directly from the localization head of the network, (ii) the corresponding classification scores originating from the classification head of the network, and (iii) the default anchor box coordinates usually hardcoded as constant tensors in the network.

Parameters

The parameters show which features are supported, along with some details:

| Type | Parameter | Description |
|------|-----------|-------------|
| float | score_threshold * | The scalar threshold for score (low scoring boxes are removed). |
| float | iou_threshold | The scalar threshold for IOU (additional boxes that have high IOU overlap with previously selected boxes are removed). |
| int | max_output_boxes | The maximum number of detections to output per image. |
| int | background_class | The label ID for the background class. If there is no background class, set it to -1. |
| bool | score_activation * | Set to true to apply sigmoid activation to the confidence scores during NMS operation. |
| bool | class_agnostic | Set to true to do class-independent NMS; otherwise, boxes of different classes would be considered separately during NMS. Both modes are supported: NMS over all classes together, or per class. |
| int | box_coding | Coding type used for boxes (and anchors if applicable), 0 = BoxCorner (the xyxy form), 1 = BoxCenterSize (the xywh form). |

(*) Parameters marked with * have a noticeable effect on runtime latency; see the performance notes below.
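
To make the parameters concrete, here is a minimal pure-Python sketch of greedy NMS for a single image. It is an illustrative reference under stated assumptions, not the plugin's implementation: the helper names (to_corners, iou, reference_nms) are hypothetical, background_class handling is left out, and the strict ">" comparisons against both thresholds are an assumption.

```python
def to_corners(box, box_coding):
    """Return (x1, y1, x2, y2). box_coding: 0 = corners, 1 = center/size."""
    if box_coding == 0:
        return tuple(box)
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def iou(a, b):
    """Intersection over union of two corner-coded boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def reference_nms(boxes, scores, score_threshold=0.25, iou_threshold=0.45,
                  max_output_boxes=100, class_agnostic=False, box_coding=0):
    """boxes: list of 4-tuples; scores: one list of per-class scores per box.

    Returns a list of (box_index, class_id, score), best score first.
    """
    corners = [to_corners(b, box_coding) for b in boxes]
    # score_threshold: drop low-scoring (box, class) pairs before any IoU work.
    candidates = [(s, i, c)
                  for i, row in enumerate(scores)
                  for c, s in enumerate(row)
                  if s > score_threshold]
    candidates.sort(key=lambda t: -t[0])
    kept = []
    for s, i, c in candidates:
        if len(kept) >= max_output_boxes:
            break
        # class_agnostic: compare against every kept box; otherwise only
        # against kept boxes of the same class.
        suppressed = any(
            (class_agnostic or c == kc)
            and iou(corners[i], corners[ki]) > iou_threshold
            for ki, kc, _ in kept)
        if not suppressed:
            kept.append((i, c, s))
    return kept


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [[0.9, 0.1], [0.1, 0.8], [0.05, 0.6]]
per_class = reference_nms(boxes, scores)
agnostic = reference_nms(boxes, scores, class_agnostic=True)
print(per_class)
print(agnostic)
```

In this toy input the first two boxes overlap heavily but belong to different classes, so per-class NMS keeps both, while class-agnostic NMS drops the lower-scoring one.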

How to use

import torch


class TRT8_NMS(torch.autograd.Function):
    '''TensorRT NMS operation'''
    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=1,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version='1',
        score_activation=0,
        score_threshold=0.25,
    ):
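        # forward() only needs to return tensors with the right shapes and
        # dtypes so that tracing works; the real NMS runs in the TensorRT plugin.
        batch_size, num_boxes, num_classes = scores.shape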
        num_det = torch.randint(0,
                                max_output_boxes, (batch_size, 1),
                                dtype=torch.int32)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4)
        det_scores = torch.randn(batch_size, max_output_boxes)
        det_classes = torch.randint(0,
                                    num_classes,
                                    (batch_size, max_output_boxes),
                                    dtype=torch.int32)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=1,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version='1',
                 score_activation=0,
                 score_threshold=0.25):
        out = g.op('TRT::EfficientNMS_TRT',
                   boxes,
                   scores,
                   background_class_i=background_class,
                   box_coding_i=box_coding,
                   iou_threshold_f=iou_threshold,
                   max_output_boxes_i=max_output_boxes,
                   plugin_version_s=plugin_version,
                   score_activation_i=score_activation,
                   score_threshold_f=score_threshold,
                   outputs=4)
        nums, boxes, scores, classes = out
        return nums, boxes, scores, classes
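
The four outputs have fixed shapes, as the forward() above shows: every image gets max_output_boxes slots no matter how many boxes survive. A minimal sketch of unpacking them after inference follows. It assumes that only the first num_det entries of each image are valid detections and that the tensors have already been converted to nested Python lists (for example with .tolist()). The helper name trim_detections is hypothetical.

```python
def trim_detections(num_det, det_boxes, det_scores, det_classes):
    """Keep only the first num_det[b][0] entries for each image in the batch.

    num_det:     [batch][1]                    number of valid detections
    det_boxes:   [batch][max_output_boxes][4]
    det_scores:  [batch][max_output_boxes]
    det_classes: [batch][max_output_boxes]
    """
    results = []
    for b, (count,) in enumerate(num_det):
        results.append([
            {
                "box": det_boxes[b][k],
                "score": det_scores[b][k],
                "class_id": det_classes[b][k],
            }
            for k in range(count)
        ])
    return results


# Toy batch of two images with max_output_boxes = 2.
num_det = [[1], [0]]
det_boxes = [[[0, 0, 10, 10], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0]]]
det_scores = [[0.9, 0.0], [0.0, 0.0]]
det_classes = [[3, 0], [0, 0]]
detections = trim_detections(num_det, det_boxes, det_scores, det_classes)
print(detections)
```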

Notes

Performance:

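  • Choosing score_threshold matters: the higher it is, the fewer boxes take part in the computation, and the faster NMS runs. If your model still emits many low-scoring boxes (for example, many negative boxes that have to be filtered out), expect it to be slower. As a rule of thumb, a score_threshold below 0.01 gets very slow.
  • The scores passed to NMS can be pre-sigmoid values, with the score_activation parameter turned on. In that case the sigmoid runs inside NMS, only on the final max_output_boxes boxes, instead of in the network on every box (which is what happens when the exported network applies sigmoid before handing the scores to NMS).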
  • When using networks with many anchors, such as EfficientDet or SSD, it may be more efficient to do box decoding within the NMS plugin. For this, pass the raw box predictions as the boxes input, and the default anchor coordinates as the optional third input to the plugin.
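
Why deferring the sigmoid is safe can be sketched in a few lines: sigmoid is strictly increasing, so thresholding raw logits against logit(score_threshold) selects the same boxes as thresholding the activated scores. The sketch below is an illustration of that idea only. It makes no claim about how the plugin implements score_activation internally, and the function names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    """Inverse of sigmoid, valid for 0 < p < 1."""
    return math.log(p / (1.0 - p))


def activate_then_filter(logits, score_threshold):
    """Sigmoid on every box first, then threshold the probabilities."""
    scores = [sigmoid(x) for x in logits]
    return [(i, s) for i, s in enumerate(scores) if s > score_threshold]


def filter_then_activate(logits, score_threshold):
    """Threshold in logit space, then run sigmoid only on the survivors."""
    raw_threshold = logit(score_threshold)
    return [(i, sigmoid(x)) for i, x in enumerate(logits) if x > raw_threshold]


logits = [-6.0, -2.0, 0.5, 3.0]
early = activate_then_filter(logits, 0.25)
late = filter_then_activate(logits, 0.25)
print(early == late)
```

Both paths keep the same boxes with the same scores, but the second one evaluates the sigmoid only for boxes that pass the threshold.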

References