Pytorch的量化方法切换到了torchao,本篇基于官方教程简单介绍下torchao的量化使用教程。
使用 TorchAO 实现 GPU 量化
本篇对segment anything 模型进行量化和优化。参考了 segment-anything-fast 仓库时所采取的步骤。
本指南演示了如何应用这些技术来加速模型,尤其是那些使用 Transformer 的模型。为此,我们将重点关注广泛适用的技术,例如使用 torch.compile
进行性能优化和量化,并衡量其影响。
环境
实验环境:
- CUDA 12.1
- A100-PG509-200,功率限制为 330.00 W
不同硬件可能结果不同。
conda create -n myenv python=3.10
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install git+https://github.com/pytorch-labs/ao.git
Segment Anything Model checkpoint:
- 访问 segment-anything checkpoint,并下载
vit_h
checkpoint。或者可以使用wget
(例如,wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>
)。 - 通过编辑以下代码传入该目录路径:
sam_checkpoint_base_path = <path>
import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True
@torch.no_grad()
def benchmark(f, *args, **kwargs):
for _ in range(3):
f(*args, **kwargs)
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
t0 = Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}
def get_sam_model(only_one_block=False, batchsize=1):
sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
model = sam.image_encoder.eval()
image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')
# 使用模型的单个 block
if only_one_block:
model = model.blocks[0]
image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
return model, image
在本教程中,我们将重点量化 image_encoder
,因为它的输入是固定尺寸的,而 prompt encoder
和 mask decoder
的尺寸是dynamic,量化这些模块更为复杂。
我们首先从单个 block 开始,以简化分析。
基准测试
首先,我们来测量模型的基础运行时间:
try:
model, image = get_sam_model(only_one_block, batchsize)
fp32_res = benchmark(model, image)
print(f"模型的基础 fp32 运行时间为 {fp32_res['time']:0.2f}ms,峰值内存为 {fp32_res['memory']:0.2f}GB")
except Exception as e:
print("无法运行 fp32 模型:", e)
模型的基础 fp32 运行时间为 198.00ms,峰值内存为 8.54GB。
使用 bfloat16 提升性能
通过将模型转换为 bfloat16 格式,直接就有性能提升。我们选择 bfloat16 而不是 fp16 的原因是它的动态范围与 fp32 相当。bfloat16 和 fp32 具有相同的 8 位指数,而 fp16 只有 4 位。较大的动态范围有助于防止溢出错误及其他可能因量化而出现的问题。
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 block 的运行时间为 {bf16_res['time']:0.2f}ms,峰值内存为 {bf16_res['memory']: 0.2f}GB")
bf16 block 的运行时间为 70.45ms,峰值内存为 5.38GB。
通过此简单的更改,运行时间提高了约 7 倍(从 186.16ms 到 25.43ms)。
使用 torch.compile
进行编译优化
接下来,我们使用 torch.compile
对模型进行编译,看看性能有多大提升:
model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 编译后 block 的运行时间为 {comp_res['time']:0.2f}ms,峰值内存为 {comp_res['memory']: 0.2f}GB")
torch.compile
提供了大约 27% 的性能提升。
量化
接下来,我们将应用量化。对于 GPU,量化主要有三种形式:
- int8 动态量化
- int8 仅权重量化
- int4 仅权重量化
不同的模型或模型中的不同层可能需要不同的量化技术。在此示例中,Segment Anything
模型是计算密集型的,因此我们使用动态量化:
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 量化后 block 的运行时间为 {quant_res['time']:0.2f}ms,峰值内存为 {quant_res['memory']: 0.2f}GB")
通过量化,我们进一步提高了性能,但内存使用显著增加。
内存优化
我们通过融合整数矩阵乘法与后续的重新缩放操作来减少内存使用:
torch._inductor.config.force_fuse_int_mm_with_mul = True
通过这种方式,我们再次提升了性能,且大大减少了内存的增长。
进一步优化
最后,我们还可以应用一些通用的优化措施来获得最终的最佳性能:
- 禁用 epilogue fusion
- 应用坐标下降优化
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
总结
通过本教程,我们了解了如何通过量化和优化技术加速 Segment Anything
模型。在批量大小为 16 的情况下,最终模型量化加速大约为 7.7%。
torchao.quantization.quantize_(model, int8_dynamic_activation_int8_weight())之后的model要怎样才能被aot_compile调用呢
quantize后model的linear层的weight会被替换掉
forward的时候会直接dispatch到对应的实现