Pytorch量化之Post Train Static Quantization(训练后静态量化)

news/2024/7/24 6:21:53 标签: pytorch, 人工智能, python

使用Pytorch训练出的模型权重为fp32,部署时,为了加快速度,一般会将模型量化至int8。与fp32相比,int8模型的大小为原来的1/4, 速度为2~4倍。
Pytorch支持三种量化方式:

  • 动态量化(Dynamic Quantization): 只量化权重,激活在推理过程中进行量化
  • 静态量化(Static Quantization): 量化权重和激活
  • 量化感知训练(Quantization Aware Training,QAT): 插入量化算子后进行训练,主要在静态量化精度不满足需求时进行。
    大多数情况下,我们只需要进行静态量化,少数情况下在量化感知训练不满足时使用QAT进行微调。所以本篇只重点讲静态量化,并且理论部分先略过(后面再专门总结),只关注实操。
    注:下面的代码是在pytorch1.10下,后面Pytorch对量化的接口有调整
    官方文档:Quantization — PyTorch 1.10 documentation

动态模式(Eager Mode)与静态模式(fx graph)

Pytorch支持用2种方式量化,一种是动态图模式,也是我们日常使用Pytorch训练所使用的方式,使用这种方式量化需要自己手动修改网络结构,在支持量化的算子前、后插入量化节点,优点是方便调试。静态模式则是由pytorch自动在计算图中插入量化节点,不需要手动修改网络。
网络上大部分的教程都是基于静态模式,这种方式比较大的问题就是需要手动修改网络结构,官方教程里的网络是属于demo型, 其中的QuantStub和DeQuantStub就分别是量化和反量化的节点:

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

Pytorch对于很多网络层是不支持量化的(比如很常用的Prelu),如果我们用这种方式,我们就必须在这些不支持的层前面插入DeQuantStub,然后在支持的层前面插入QuantStub。笔者体验下来,体验很差,个人觉得不太实用,会破坏原来的网络结构。
而静态图模式,我们只需要调用Pytorch提供的接口将原模型转换一下即可,不需要修改原来的网络结构文件,个人认为实用性更强。
image.png

静态模式量化

1. 载入fp32模型,并转成fx graph

其中量化参数有‘fbgemm’和‘qnnpack’两种,前者在x86运行,后者在arm运行。

model_fp32 = torch.load(xxx)
model_fp32_quantize = copy.deepcopy(model_fp32)
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
model_fp32_quantize.eval()
# prepare

model_prepared = quantize_fx.prepare_fx(model_fp32_quantize, qconfig_dict)
model_prepared.eval()

2.读取量化数据,标定(Calibration)量化参数

标定的过程就是使用模型推理量化图片,然后统计权重和激活分布,从而得到量化参数。量化图片一般来源于训练集(几百张左右,根据测试情况调整)。量化图片可以通过Pytorch的Dataloader读取,也可以直接自行实现读图片然后送入网络。

### 使用dataloader读取
for i, (data, label) in enumerate(train_loader):
    data = data.to(torch.device("cpu:0"))
    outputs = model_prepared(data)
    print("calibrating {}".format(i))
    if i > 1000:
        break

3. 转换为量化模型并保存

quantized_model = quantize_fx.convert_fx(model_prepared)
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")

速度测试

量化后的模型使用方法与fp32模型一样:

import torch
import cv2
import numpy as np
torch.set_num_threads(1)

fused_model = torch.jit.load("jit_model.pt")
fused_model.eval()
fused_model.to(torch.device("cpu:0"))

img = cv2.imread("./1.png")
img_fp32 = img.astype(np.float32)
img_fp32 = (img_fp32-127.5) / 127.5
input = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

def speed_test(model, input):
    # warm up
    for i in range(10):
        model(input)

    import time
    start = time.time()
    for i in range(100):
        model(input)
    end = time.time()
    print("model time: ", (end-start)/100)
    time.sleep(10)

# quantized model
quantized_model= torch.jit.load("quantized_model.pt")
quantized_model.eval()
quantized_model.to(torch.device("cpu:0"))

speed_test(fused_model, input)
speed_test(quantized_model, input)

实测fp32模型单核运行120ms, 量化后47ms

结语

本文介绍了fx graph模式下的Pytorch的PTSQ方法,并实测了一个模型,效果还比较不错。
1_995567224_161_79_3_732056265_62005da0d7c1b531a6cf91ea587d312e.jpg


http://www.niftyadmin.cn/n/4926868.html

相关文章

go的make使用

在 Go 语言中,make 是一个用于创建切片、映射(map)和通道(channel)的内建函数。它提供了一种初始化和分配内存的方式,用于创建具有特定长度和容量的数据结构。下面将详细介绍 make 函数的使用方法和各种情况…

nginx负载均衡(nginx结束)

本节主要内容 1、四层,七层代理的配置方法 2、负载均衡的算法 nginx负载均衡:反向代理来实现 反向代理有两种转发方式:1、四层代理 2、七层代理 Nginx的七层代理和四层代理 七层是最常见的反向代理方式,只能配置在nginx配置文…

Vue 本地应用 记事本 v-on v-model v-for使用

新增功能 vue当中如何生成列表结构?使用的指令是v-for,同时要有一个可以生成列表的数据,常用的是数组。记事本里面的内容并不复杂,所以这里使用字符串数组就行了。 获取用户输入的内容使用绑定v-model,双向数据绑定&a…

前沿分享-无创检测血糖RF波

非侵入性血糖仪,利用射频 (RF) 波连续测量血液中的葡萄糖水平。利用射频波技术连续实时监测血液中的葡萄糖水平,使用的辐射要比手机少得多。 大概原理是血液中的葡萄糖是具有介电特性,一般来说就是介电常数。 电磁波波幅的衰减反映了介质对电…

汽车IVI中控开发入门及进阶(十):车载摄像头接口CVBS、AHD和MIPI

文章目录 前言一、CVBS是什么?二、AHD是什么?三、MIPI是什么?前言 汽车电子电气架构正在由传统的分布式架构向域集中式架构转变,也就是将多个应用程序集中在一个域中,正如提到IVI,有些已经开始导入域控,除了一带多的显示屏、一带多的雷达传感器,当然还有一带多的摄像头…

低代码、逻辑、规则、数据分析、协同工具集合,解决企业不同需求

大家好,我是为IT部门兄弟操碎了心的“软件部长”,随着企业IT建设的不断发展,软开企服也在经历了数十年的项目中积累了丰富的经验,为此开始了IT软件的研发之路,之后就一发不可收拾。。。才有了现在出现在市面上的JVS。 …

android开发之Android 自定义滑动解锁View

自定义滑动解锁View 需求如下: 近期需要做一个类似屏幕滑动解锁的功能,右划开始,左划暂停。 需求效果图如下 实现效果展示 自定义view如下 /** Desc 自定义滑动解锁View Author ZY Mail sunnyfor98gmail.com Date 2021/5/17 11:52 *…

运用反射将对象中的一些字段设置为null

之前的一个需求是需要将查询出来的对象,只保留特定值,其余值都设置为null。 最直接的办法是把对象中的属性捞出来,手动设置null;可是对象每增加一个值可能都需要修改一下,很不便利。 我们可以把利用反射的原理来处理…