Talking about a stopgap C++ deployment option: libtorch with TorchScript

I wrote an article about libtorch a long time ago, back when torch had only just reached version 1.0; torch has since grown to 2.2.2. As one of the earliest C++ deployment options, libtorch has plenty of real use cases and still works today. Even in the PyTorch 2.0 era, libtorch deployment recipes and TorchScript can still pull their weight in AOT scenarios.

This post mainly covers how libtorch and TorchScript are used in common deployment scenarios. Follow-up posts will look at practical scenarios that combine them with AOTInductor.

What libtorch can do

libtorch is built on PyTorch's low-level implementation (ATen, autograd, and friends) with C++ as the API. It shares most of torch's native functionality with PyTorch: tensor computation, differentiation, building networks, and so on.

It consists of roughly these components:

  • ATen: The foundational tensor and mathematical operation library on which all else is built.
  • Autograd: Augments ATen with automatic differentiation.
  • C++ Frontend: High level constructs for training and evaluation of machine learning models.
  • TorchScript: An interface to the TorchScript JIT compiler and interpreter.
  • C++ Extensions: A means of extending the Python API with custom C++ and CUDA routines.

For details on each component, see: https://pytorch.org/cppdocs/
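As a quick taste of the C++ frontend, tensor computation and autograd look almost exactly like their Python counterparts; a minimal sketch:

#include <torch/torch.h>

#include <iostream>

int main() {
  // Tensors track gradients just like in Python.
  torch::Tensor a = torch::randn({2, 2}, torch::requires_grad());
  torch::Tensor b = (a * a).sum();
  b.backward();
  std::cout << a.grad() << std::endl;
  return 0;
}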

In deployment work we frequently need to bring a PyTorch model into a C++ environment. Most of the time the preferred route is to convert the whole model to TensorRT, ONNX, or some other high-performance deployment format. But sometimes the model is special:

  • Some ops are unsupported, e.g. certain ops don't convert to TRT, such as topk when k is greater than 3840 or k is dynamic
  • Or it isn't really a model at all: there are no weights, and it is mostly pre/post-processing logic, such as NMS

That is exactly when libtorch and TorchScript come in: script or trace the Python-side logic, save it, and execute it from C++.
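The canonical minimal loop looks like this (model.pt is a placeholder name for whatever you exported with torch.jit.script or torch.jit.trace):

#include <torch/script.h>

#include <vector>

int main() {
  // Load a TorchScript file exported from Python.
  torch::jit::script::Module module = torch::jit::load("model.pt");

  // Inputs to a scripted module are passed as IValues.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}));

  // Execute and unwrap the result.
  at::Tensor output = module.forward(inputs).toTensor();
  return 0;
}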

How to write libtorch code

Writing libtorch without losing your mind: the online C++ debugging notebook I use

Export the post-processing ops you need

A cling-based C++ notebook is a good way to try these operations hands-on.

The values that flow in and out of a TorchScript module are IValues; the PyTorch source (aten/src/ATen/core/ivalue.h) describes them as follows:

/// IValue (Interpreter Value) is a tagged union over the types
/// supported by the TorchScript interpreter. IValues contain their
/// values as an IValue::Payload, which holds primitive types
/// (int64_t, bool, double, Device) and Tensor as values,
/// and all other types as a c10::intrusive_ptr. In order to
/// optimize performance of the destructor and related operations by
/// making the Tensor and c10::intrusive_ptr paths generate the
/// same code, we represent a null c10::intrusive_ptr as
/// UndefinedTensorImpl::singleton(), not nullptr.
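In practice, working with IValue mostly means converting to and from concrete types; a sketch of the common conversions:

// Tensor <-> IValue
torch::jit::IValue iv = torch::ones({2, 2});
at::Tensor t = iv.toTensor();

// Primitives are stored inline in the payload
torch::jit::IValue d = 3.14;
double x = d.toDouble();

// A forward() that returns a tuple unpacks via toTuple()
// auto tup = module.forward(inputs).toTuple();
// at::Tensor first = tup->elements()[0].toTensor();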

Tensor memory operations

You can obtain a Tensor directly from a void* pointer:

// Wrap an OpenCV image (8-bit) as a tensor; note that cv::Mat is HWC in
// memory, so interpreting it as {1, 3, H, W} only makes sense if the buffer
// really is planar CHW (see the HWC + permute example further down).
at::Tensor tensor_image = torch::from_blob(image.data, {1, 3, image.rows, image.cols}, at::kByte);
tensor_image = tensor_image.to(at::kFloat);

// Wrap an output buffer that we own ourselves:
std::unique_ptr<float[]> outputData(new float[1 * 17 * 96 * 72]);
auto res_point = torch::from_blob(outputData.get(), {input_shape[0], input_shape[1], input_shape[2], input_shape[3]});

Here, I assume that image.data is 8-bit byte values. The to(at::kFloat) will convert the 8-bit values into 32-bit floating points just as if you wrote static_cast<float>(b) where b is a byte, just in case that wasn't clear. If image.data is already floats, you can just write at::kFloat in place of at::kByte and skip the conversion of course. What's super important to know is that from_blob does not take ownership of the data! It only interprets the data as a tensor, but doesn't store the data itself. It's easy to fix this if you want to, by calling .clone() on the tensor, since that will incur a copy of the data such that the resulting tensor will indeed own its data (which means the original cv::Mat can be destroyed and the cloned tensor will live on).

On the other side, it's actually easier. You can use tensor.data<T>() to access a tensor's underlying data through a T*. For example, tensor_image.data<float>() would give you a float* (in current libtorch versions, prefer the equivalent data_ptr<float>(), since data<T>() is deprecated). If you want a more raw void* because you're dumping the raw data somewhere else, there's also a data_ptr() method that gives you a raw byte pointer.

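A minimal sketch of the ownership point above (image is an 8-bit cv::Mat):

// from_blob only views the Mat's memory; clone() copies it so the tensor
// owns its storage and can outlive the original cv::Mat.
at::Tensor owned = torch::from_blob(image.data, {image.rows, image.cols, 3}, at::kByte).clone();

// Going the other way: typed and raw access to the tensor's storage.
at::Tensor f = owned.to(at::kFloat);
float* typed = f.data_ptr<float>();  // typed pointer
void* raw = f.data_ptr();            // raw pointer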

// Suppose the incoming buffer already lives on the GPU:
// void* input_buffer = ...;
auto options = torch::TensorOptions().device(at::kCUDA);
auto detections = torch::from_blob(input_buffer, {input_shape[0], input_shape[1], input_shape[2]}, options);

Basic tensor operations

A complete round trip through F::interpolate, with OpenCV handling image I/O:
cv::Mat image = cv::imread("/home/lll/Pictures/test.jpg");


torch::Tensor image_tensor = torch::from_blob(image.data, {image.rows, image.cols, 3}, torch::kByte);

image_tensor = image_tensor.permute({2, 0, 1}).toType(torch::kFloat).div_(255);
image_tensor.sub_(0.5).div_(0.5);
image_tensor = image_tensor.unsqueeze(0);
image_tensor = image_tensor.to(torch::kCUDA);
image_tensor = image_tensor.contiguous();  // necessary: permute leaves the tensor non-contiguous

namespace F = torch::nn::functional;
image_tensor = F::interpolate(
        image_tensor,
        F::InterpolateFuncOptions()
                .mode(torch::kBilinear)
                .size(std::vector<int64_t>({512, 512}))
                .align_corners(true)
);
image_tensor = image_tensor.mul(0.5).add(0.5).mul(255);
image_tensor = image_tensor.squeeze(0).permute({1, 2, 0}).toType(torch::kByte).to(torch::kCPU);

cv::Mat test_mat(512, 512, CV_8UC3);
std::memcpy(test_mat.data, image_tensor.data_ptr(), image_tensor.nbytes());
cv::imshow("test", test_mat);
cv::waitKey(0);

How to debug TorchScript/libtorch models

Often we need to confirm that the op results on the Python side and the C++ side actually match. The most direct, brute-force method: with identical inputs guaranteed, compare the outputs of both sides at the final output point. Two common metrics:

  • Cosine similarity: evaluates the overall similarity of model outputs, suitable for end-to-end comparison, e.g. comparing the coordinates a detection model finally outputs
  • Absolute/relative error: checks the accuracy of individual values; how strict to be depends on the scenario, though in theory the PyTorch and libtorch outputs should match exactly

First, export the PyTorch-side output by wrapping the values in a module:

import torch


class Container(torch.nn.Module):
    def __init__(self, my_values):
        super().__init__()
        for key in my_values:
            setattr(self, key, my_values[key])

my_values = {
    'res': prediction.cpu()
}

container = torch.jit.script(Container(my_values))
container.save("container.pt")

Once the .pt is saved, the results stored on the PyTorch side can be loaded in C++ like this:

torch::jit::script::Module container = torch::jit::load("/container.pt");
auto res = container.attr("res").toTensor();

With the values extracted, we can compare. The reshape here is just to make the comparison easier: it guarantees that the two objects being compared have the same dimensions:

std::cout << "cosine_similarity: "   
 << torch::cosine_similarity(res_libtorch.reshape({1,-1}), res.reshape({1,-1})) << "\n";
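For element-wise checks, a maximum absolute error (or torch::allclose) does the job as well; a minimal sketch:

std::cout << "max abs err: "
          << (res_libtorch.reshape({1, -1}) - res.reshape({1, -1})).abs().max().item<float>() << "\n";
std::cout << "allclose: "
          << torch::allclose(res_libtorch, res, /*rtol=*/1e-5, /*atol=*/1e-6) << "\n";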

Postscript

This article is continuously updated; the latest version lives at:


For reference, here is the full debugging script the snippets above were taken from (a thin-plate-spline remap serves as the example op):

from typing import Tuple

import torch
import torchvision
import torch.utils.benchmark as benchmark


class Container(torch.nn.Module):
    def __init__(self, my_values):
        super().__init__()
        for key in my_values:
            setattr(self, key, my_values[key])


class DebugModule(torch.nn.Module):
    """Thin-plate-spline (TPS) remap implemented with TorchScript-friendly ops."""
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)   
    
    def d(self, a, b):
        return torch.sqrt((a[:, None, :2] - b[None, :, :2]).pow(2).sum(-1))

    def u(self, r):
        return r.pow(2) * torch.log(r + 1e-6)

    def z(self, x, c, theta):
        x = torch.as_tensor(x, dtype=torch.float32)
        U = self.u(self.d(x, c))
        w, a = theta[:-3], theta[-3:]
        reduced = len(theta) == c.size(0) + 2
        
        if reduced:
            w = torch.cat((-w.sum(dim=0,keepdim=True), w))
            
        b = torch.matmul(U, w)
        return a[0] + a[1] * x[:, 0] + a[2] * x[:, 1] + b

    def fit(self, c, lambd:float = 0.0):
        n = c.size(0)
        U = self.u(self.d(c, c))
        K = U + torch.eye(n, dtype=torch.float32).cuda() * lambd
        P = torch.ones(n, 3, dtype=torch.float32).cuda()
        P[:, 1:] = c[:, :2]
        v = torch.zeros(n + 3, dtype=torch.float32).cuda()
        v[:n] = c[:, -1]
        A = torch.zeros(n + 3, n + 3, dtype=torch.float32).cuda()
        A[:n, :n] = K
        A[:n, -3:] = P
        A[-3:, :n] = P.transpose(0, 1)
        rlt = torch.linalg.solve(A, v.unsqueeze(1)).squeeze(1)
        return rlt

    def uniform_grid_pt(self, H:int, W:int):
        
        c = torch.empty(H, W, 2, dtype=torch.float32).cuda()
        c[:, :, 0] = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1).cuda()
        c[:, :, 1] = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W).cuda()
        return c

    def tps_theta_from_points_pt(self, c_src, c_dst):
        delta = c_src - c_dst
        cx = torch.cat((c_dst, delta[:, :1]), 1)
        cy = torch.cat((c_dst, delta[:, 1:2]), 1)
        theta_dx = self.fit(cx)
        theta_dy = self.fit(cy)
        return torch.stack((theta_dx, theta_dy), -1)

    def tps_grid_pt(self, theta, c_dst, dshape:Tuple[int, int]):
        
        ugrid = self.uniform_grid_pt(dshape[0], dshape[1])
        theta = theta.clone().detach().to(torch.float32)
        dx = self.z(ugrid.reshape(-1, 2), c_dst, theta[:, 0]).reshape(dshape[:2])
        dy = self.z(ugrid.reshape(-1, 2), c_dst, theta[:, 1]).reshape(dshape[:2])
        dgrid = torch.stack((dx, dy), -1)
        grid = dgrid + ugrid
        return grid 

    def tps_grid_to_remap_pt(self, grid, sshape:Tuple[int, int]):
        
        mx = (grid[:, :, 0] * sshape[1]).to(torch.float32)
        my = (grid[:, :, 1] * sshape[0]).to(torch.float32)
        return mx, my
    
    def forward(
        self,
        c_src:torch.Tensor,c_dst:torch.Tensor,
        height:int, width:int
    ):
        
        grid_size = (64, 64)
        theta = self.tps_theta_from_points_pt(c_src, c_dst)
        grid = self.tps_grid_pt(theta, c_dst, grid_size)
        mapx, mapy = self.tps_grid_to_remap_pt(grid, (height, width))

        return (mapx, mapy)


def get_remap_maps(c_src, c_dst, height, width):
    # NumPy reference path: tps_theta_from_points / tps_grid / tps_grid_to_remap
    # come from the original non-torch TPS implementation and are not defined
    # in this file.
    grid_size = (64, 64)
    theta = tps_theta_from_points(c_src, c_dst)
    grid = tps_grid(theta, c_dst, grid_size)
    mapx, mapy = tps_grid_to_remap(grid, (height, width))

    return (mapx, mapy)
    
   


if __name__ == '__main__':
    
    if False:
        
        # The .pt dumped directly inside triton server may fail to load here,
        # possibly due to a version mismatch; instead torch.save the tensor in
        # triton server's model.py, then re-wrap it into jit format in this demo.
        input_tensor = torch.load("/code/lab/input.pt")
        print(input_tensor.reshape(1,-1)[0][:100])
        container = torch.jit.script(Container({"input": input_tensor}))
        container.save("/code/lab/input_jit.pt")
        
        input_tensor_jit = torch.jit.load("/code/jit.pt")

        # c_src / c_dst and the reference maps mapx_ref / mapy_ref are assumed
        # to be prepared elsewhere (e.g. by the NumPy reference implementation).
        remap_model = torch.jit.script(DebugModule())
        mapx_pt, mapy_pt = remap_model(c_src.cuda(), c_dst.cuda(), 1920, 1920)
        
        print(torch.cosine_similarity(mapx_pt.cpu().reshape(1,-1), mapx_ref.reshape(1,-1)))
        print(torch.cosine_similarity(mapy_pt.cpu().reshape(1,-1), mapy_ref.reshape(1,-1)))


        # eager (non-scripted) baseline for timing
        remap_model_ori = DebugModule()

        t1 = benchmark.Timer(
            stmt='remap_model_ori(c_src, c_dst, 1920, 1920)',
            setup='from __main__ import remap_model_ori',
            globals={"c_src": c_src.cuda(), "c_dst": c_dst.cuda()})
        print(t1.timeit(1000))

        remap_model.save("/code/lab/model.pt")
        model_load = torch.jit.load("/code/lab/model.pt")
        
        mapx_pt, mapy_pt = model_load(c_src.cuda(), c_dst.cuda(), 1920, 1920)
        
        print(torch.cosine_similarity(mapx_pt.cpu().reshape(1,-1), mapx_ref.reshape(1,-1)))
        print(torch.cosine_similarity(mapy_pt.cpu().reshape(1,-1), mapy_ref.reshape(1,-1)))

        t1 = benchmark.Timer(
            stmt='model_load(c_src, c_dst, 1920, 1920)',
            setup='from __main__ import model_load',
            globals={"c_src": c_src.cuda(), "c_dst": c_dst.cuda()})
        print(t1.timeit(1000))
    
    
    
@torch.jit.script
def xywh2xyxy(x: torch.Tensor) -> torch.Tensor:
    # (cx, cy, w, h) -> (x1, y1, x2, y2); standard conversion, defined here so
    # the scripted function below is self-contained
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2
    y[:, 1] = x[:, 1] - x[:, 3] / 2
    y[:, 2] = x[:, 0] + x[:, 2] / 2
    y[:, 3] = x[:, 1] + x[:, 3] / 2
    return y


@torch.jit.script
def non_max_suppression(
        prediction,
        conf_thres: float,
        iou_thres: float,
        max_det: int,
        nc: int,  # number of classes
):

    # Class-agnostic NMS is assumed throughout.
    # Steps (the incoming boxes are expected transposed): 1) threshold scores,
    # 2) xywh2xyxy, 3) sort by confidence, 4) nms.
    bs = prediction.shape[0]  # batch size
    nm = prediction.shape[1] - int(nc) - 4
    mi = 4 + nc  # mask start index
    # candidates; e.g. 188 channels = 4 (box) + 15 (class scores) + 169 (mask
    # coefficients); take the max over the 15 class scores
    xc = prediction[:, 4:mi].amax(1) > conf_thres

    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()

    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs  # this allocation could be optimized; leave it for now
    x = prediction[0]  # image index, image inference
    x = x.transpose(0, -1)[xc[0]]  # confidence
    # If none remain process next image
    if not x.shape[0]:
        return output

    # Detections matrix nx6 (xyxy, conf, cls)
    box, cls, mask = x.split([4, int(nc), int(nm)], 1)
    box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
    conf, j = cls.max(1, keepdim=True)
    x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
    # Check shape
    n = x.shape[0]  # number of boxes
    if not n:  # no boxes
        return output

    x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
    c = x[:, 5:6] * 0  # class offsets are all zero: class-agnostic NMS
    boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
    i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
    i = i[:max_det]  # limit detections
    output[0] = x[i]

    return output
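On the C++ side, a scripted function saved with non_max_suppression.save("nms.pt") can then be loaded and invoked like a module; a minimal sketch (the file name and threshold values are illustrative, and in my experience the saved function is exposed as forward):

#include <torch/script.h>

torch::jit::script::Module nms = torch::jit::load("nms.pt");

std::vector<torch::jit::IValue> inputs;
inputs.emplace_back(prediction);  // [bs, 4 + nc + nm, n] tensor
inputs.emplace_back(0.25);        // conf_thres
inputs.emplace_back(0.45);        // iou_thres
inputs.emplace_back(300);         // max_det
inputs.emplace_back(15);          // nc

auto dets = nms.forward(inputs).toTensorList();
at::Tensor det0 = dets.get(0);    // detections for image 0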

A few details

  • Note that GPU execution is asynchronous, so naive wall-clock timing around GPU ops is misleading
  • You can force synchronization via CUDA_LAUNCH_BLOCKING=1, torch.cuda.synchronize(), or CUDA events, as explained below
  • PyTorch or libtorch will occasionally show abnormally long kernel launch times

Asynchronous execution here means that GPU operations are asynchronous by default: when you call a function that uses the GPU, the operation is enqueued on that device but not necessarily executed right away. This lets more computation proceed in parallel, including work on the CPU or on other GPUs.

In general the asynchrony is transparent to the caller, because each device executes operations in queue order, and PyTorch automatically performs the necessary synchronization when copying data between CPU and GPU or between two GPUs. Computation therefore behaves as if every operation were executed synchronously.

You can force synchronous computation by setting the environment variable CUDA_LAUNCH_BLOCKING=1. This is useful when a GPU error occurs: with asynchronous execution an error is only reported after the operation actually runs, so the stack trace does not show where it was requested.

A consequence of asynchrony is that unsynchronized time measurements are inaccurate. To measure accurately, either call torch.cuda.synchronize() before reading the clock, or use torch.cuda.Event to record timestamps, for example:

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# run some operations here

end_event.record()
torch.cuda.synchronize()  # wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
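The libtorch counterpart is a short sketch away, assuming the ATen wrapper at::cuda::CUDAEvent (which wraps cudaEvent_t):

#include <ATen/cuda/CUDAEvent.h>

// Timing GPU work with CUDA events, mirroring the Python code above.
at::cuda::CUDAEvent start(cudaEventEnableTiming);
at::cuda::CUDAEvent stop(cudaEventEnableTiming);

start.record();
// ... run some ops here ...
stop.record();
stop.synchronize();                   // wait for the event to complete
float ms = start.elapsed_time(stop);  // elapsed milliseconds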

In addition, functions such as to() and copy_() accept an explicit non_blocking argument, which lets the caller skip synchronization when it is not needed. CUDA streams are another exception; they are covered in detail elsewhere.
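In C++ the same non_blocking escape hatch exists on Tensor::to(); a sketch, assuming a pinned host allocation (which is what makes the host-to-device copy truly asynchronous):

// Pinned (page-locked) host tensor so the H2D copy can be non-blocking.
auto host = torch::empty({1, 3, 512, 512},
                         torch::TensorOptions().dtype(torch::kFloat).pinned_memory(true));

// Enqueue the copy without making the CPU wait for the GPU.
auto dev = host.to(torch::TensorOptions().device(torch::kCUDA),
                   /*non_blocking=*/true);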
