0%

模型感知训练

模型感知训练

文章参考学习

模型感知训练也是模型量化中的一部分,它主要是在量化过程中,对网络进行训练,从而让网络参数能更好地适应量化带来的信息损失。这种方式会更加灵活,因此准确性普遍从后训练量化要高,但是,同样地,它也存在缺点,如操作起来不方便。

量化训练过程中的梯度问题

首先回顾一下之前的所说的量化过程: \[ q = round(\frac{r}{S}+Z) \tag{1} \] 但这个\(round\)函数存在一个问题,就是它是不可导的,梯度几乎处处为零,这就会导致反向传播梯度也变成0,进而导致量化训练无法进行。它的函数图像如下:

img
img

解决梯度处处为零的问题

一个简单常用的方法就是Straight Through Estimator(STE),即直接跳过伪量化的过程,避开\(round\),直接将卷积的梯度作为结果传回去。

pytorch中的具体实现通常是使用torch.autograd.Function接口来重定义伪量化的过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
"""
grad_output:后一层传来的梯度
返回的None相当于是forward中的qparam的梯度,这里不需要,因为
qparam只是统计min和max的
"""
return grad_output, None

具体代码实现中的注意事项

代码实现使用的是Github上的量化代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class QConv2d(QModule):
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)

self.qw.update(self.conv_module.weight.data)

x = F.conv2d(x, FakeQuantize.apply(self.conv_module.weight, self.qw),
self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)

if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)

return x

需要注意的是x = F.conv2d(...)这一行,这里需要使用F.conv2d()来调用卷积,而不能使用self.conv_module()来调用,因为Pytorch中规则是torch.nn.Conv2d的weight无法自定义,需要手动设置weight时只能使用torch.nn.function.conv2d

此外,代码中对于输入和输出也进行了伪量化,即x=FakeQuantize.apply(x, self.qx),这在量化训练中是有必要的,可以帮助网络更好地感知量化带来的损失。

总结

量化训练(模型感知训练)部署虽然在准确性上会有所提升,但实际应用时去会比后训练量化麻烦很多!

目前大部分主流推理框架在处理后训练量化时,只需要用户把模型和数据扔进去,就可以得到量化模型,然后直接部署。但很少有框架支持量化训练。