You can experiment with all of this hands-on using a cling C++ notebook.

/// IValue (Interpreter Value) is a tagged union over the types
/// supported by the TorchScript interpreter. IValues contain their
/// values as an IValue::Payload, which holds primitive types
/// (int64_t, bool, double, Device) and Tensor as values,
/// and all other types as a c10::intrusive_ptr. In order to
/// optimize performance of the destructor and related operations by
/// making the Tensor and c10::intrusive_ptr paths generate the
/// same code, we represent a null c10::intrusive_ptr as
/// UndefinedTensorImpl::singleton(), not nullptr.



at::Tensor tensor_image = torch::from_blob(, {1, 3, image.rows, image.cols}, at::kByte);
tensor_image =;

std::unique_ptr<float[]> outputData(new float[1*17*96*72]);
auto res_point = torch::from_blob(outputData.get(), {input_shape[0],input_shape[1],input_shape[2],input_shape[3]});

Here, I assume that holds 8-bit byte values. The to(at::kFloat) will convert the 8-bit values into 32-bit floating points, just as if you wrote static_cast<float>(b) where b is a byte – just in case that wasn't clear. If already holds floats, you can just write at::kFloat in place of at::kByte and skip the conversion, of course. What's super important to know is that from_blob does not take ownership of the data! It only interprets the data as a tensor, but doesn't store the data itself. It's easy to fix this if you want to, by calling .clone() on the tensor, since that will incur a copy of the data such that the resulting tensor will indeed own its data (which means the original cv::Mat can be destroyed and the cloned tensor will live on).

On the other side, it's actually easier. You can use data_ptr<T>() to access a tensor's underlying data through a T*. For example, data_ptr<float>() would give you a float*. If you want a more raw void* because you're dumping the raw data somewhere else, there's also a data_ptr() method that gives you a raw, untyped pointer.

Let me know if this helps.

// Suppose the incoming data comes from the GPU
void* input_buffer = ...;
auto options = torch::TensorOptions().device(at::kCUDA);
auto detections = torch::from_blob(input_buffer, {input_shape[0],input_shape[1],input_shape[2]}, options);


cv::Mat image = cv::imread("/home/lll/Pictures/test.jpg");

torch::Tensor image_tensor = torch::from_blob(, {image.rows, image.cols, 3}, torch::kByte);

image_tensor = image_tensor.permute({2, 0, 1}).toType(torch::kFloat).div_(255);
image_tensor = image_tensor.unsqueeze(0);
// image_tensor =;  (the right-hand side was lost from the original, likely a device/dtype conversion such as .to(torch::kCUDA))
image_tensor = image_tensor.contiguous();  // necessary: permute() leaves the tensor non-contiguous

namespace F = torch::nn::functional;
image_tensor = F::interpolate(
                image_tensor,
                F::InterpolateFuncOptions()
                    .size(std::vector<int64_t>({512, 512})));
image_tensor = image_tensor.mul(0.5).add(0.5).mul(255);
image_tensor = image_tensor.squeeze(0).permute({1, 2, 0}).toType(torch::kByte).to(torch::kCPU);

cv::Mat test_mat(512, 512, CV_8UC3);
std::memcpy((void *), image_tensor.data_ptr(), image_tensor.numel() * sizeof(uint8_t));
cv::imshow("test", test_mat);
cv::waitKey(0);  // without waitKey the window is never drawn


    class Container(torch.nn.Module):
        def __init__(self, my_values):
            super().__init__()  # must run before setattr on an nn.Module
            for key in my_values:
                setattr(self, key, my_values[key])

    my_values = {
        'res': prediction.cpu(),
    }

    container = torch.jit.script(Container(my_values))"")  # save path lost from the original

torch::jit::script::Module container = torch::jit::load("/data/yanzong/code/data/debug/");
auto res = container.attr("res").toTensor();