译:《TinyTL:降低内存而非参数进行高效的设备学习》
论文地址:TinyTL: Reduce Memory, Not Parameters for Efficient On-Device Learning
从降低内存角度提升边缘设备迁移学习效率
Motivation
场景
有各种传感器的智能边缘设备已经在我们的日常生活中无处不在,而这些设备每天都在通过传感器收集新的敏感数据,但希望可以在不牺牲隐私的情况下提供高质量的定制服务。这样就给更加高效的AI系统带来了新的挑战,这些系统不仅可以运行推理,而且需要根据新收集的数据不断微调预先训练的模型(如设备上学习)
问题
虽然设备上学习可以实现许多吸引人的应用,但这也是一个极具挑战性的问题。
首先,边缘设备是有内存限制的,例如一个Raspberry Pi 1 Model A只有256MB的内存,这对于模型推理是足够的,但目前而言对于训练是不够的,即使使用轻量神经网络结构(MobileNetV2)。此外,内存还是由设备上的各种应用和操作系统共享的,单个应用程序可能只能分配到总内存中的一小部分,这就使得这一挑战更加重要。
其次,边缘设备是有能源限制的,DRAM访问比片上SRAM访问多消耗两个数量级的能量,片上SRAM无法容纳大的内存需要,因此就必须访问DRAM。具体地,如在批量大小为16的情况下,MobileNetV2的训练内存接近1GB,这远远大于AMD EPYC CPU的SRAM大小,更别提更低级的边缘设备。因此,如果能将训练内存固定在片上SRAM,一定可以大大提高速度和能量效率。
相关工作
有很多有效的推理技术可以减少可训练参数的数量和计算FLOPs,但参数高效或FLOPs高效技术并不能直接节省训练内存。训练内存的瓶颈是激活值而非参数。如下图,比较了ResNet-50和MobileNetV2-1.4,
已有的工作通过降低训练参数数量来解决这个问题,但这不能直接转换成节省内存,因为主要的瓶颈是激活值而是不是参数。在参数大小方面,MobileNetV2-1.4比ResNet-50小4.3倍,但对于训练激活大小,MobileNetV2-1.4与ResNet-50几乎相同(仅1.1倍),导致内存几乎没有减少。因此,减少反向传播中所需的中间激活的大小对于解决设备上训练内存瓶颈是至关重要的。
高效推理技术
- 网络剪枝,移除不重要的单元或通道
- 网络量化,降低参数或激活的bit位数
- 设计轻量网络结构,手工或NAS
降低内存占用
研究人员一直在寻找减少训练内存的方法。
- 重新计算反向过程中丢弃的激活,这种方法以较大的计算开销为代价减少了内存占用
- 分层训练,可以降低内存占用,但会损失精度
- 剪枝激活,构建一个动态稀疏计算图来剪枝训练期间的激活
- 量化,引入新的低精度浮点格式来减少训练激活的bit位数
- 集中于降低峰值推理内存代价,如RNNPool、MemNet
但本文的方法与这些技术是正交的,可以组合使用。
迁移学习
迁移学习是指在大规模数据集上预训练,并将得到的模型广泛用于迁移学习的固定特征提取器,然后只需要微调最后一层。这种方法不需要存储特征提取器的中间激活,因此内存效率很高。但这种方法容易有限,可能导致精度较差,由于在新数据集与原数据集分布区别较大时。微调整个网络可以获得更高的精度,但需要很大的内存占用。
有人提出只更新BN层的参数,可以大大减少可训练参数的数量,但参数效率并不能转化为内存效率,它仍然需要大量的内存来存储BN层的输入激活,并且这种方法的准确性比微调整个网络差很多。也有人对一些层进行微调,但具体选择多少层仍然是特定的。
论文贡献
基于这些,作者提出了Tiny-Transfer-Learning (TinyTL) 来解决这些挑战,贡献如下:
- 提出一种新的迁移学习方法TinyTL,可以将训练存储空间减少一个数量级,从而实现高效的设备上学习。系统性地分析训练了训练时的内存,发型瓶颈来自于权重的更新而非偏置项(bias,假设是ReLU激活)
- 引入轻量残差模块( lite residual module),这是一个内存高效的bias模块,可以在内存开销很小的情况下提升模型容量
- 在迁移学习任务上通过实验证明了TinyTL具有很高的内存效率和有效性
方法简介
理解反向传播中的内存占用
不失一般性,假设神经网络\(\mathcal M\)由一系列层组成: \[ \mathcal M(.)=\mathcal F_{W_n}(\mathcal F_{W_{n-1}}(\cdots \mathcal F_{w_2}(\mathcal F_{w_1}(.))\cdots)) \] 其中,\(W_i\)表示第\(i\)层的参数,\(a_i\)和\(a_{i+1}\)分别表示第\(i\)层的输入和输出激活,\(\mathcal L\)表示损失,在反向传播时,给定\(\frac{\partial \mathcal L}{\partial a_{i+1}}\),则第\(i\)层有两个目标:计算\(\frac{\partial \mathcal L}{\partial a_i}\)和\(\frac{\partial \mathcal L}{\partial W_i}\)
假设第\(i\)层是线性层,则其前向传播是:\(a_{i+1}=a_iW+b\),然后在批量大小为1的情况下,反向过程是: \[ \frac{\partial \mathcal L}{\partial a_i} = \frac{\partial \mathcal L}{\partial a_{i+1}} \frac{\partial a_{i+1}}{\partial a_i}=\frac{\partial \mathcal L}{\partial a_{i+1}} W^T,\quad \frac{\partial \mathcal L}{\partial W}=\textcolor{red}{a_i^T} \frac{\partial \mathcal L}{\partial a_{i+1}},\quad \frac{\partial \mathcal L}{\partial b}=\frac{\partial \mathcal L}{\partial a_{i+1}} \] 从这些公式可以发现,只有在计算权重梯度(\(\frac{\partial \mathcal L}{\partial W}\))时需要主要占用内存的中间激活(\(\{a_i\}\)),而偏置(bias)不用。因此如果只更新偏置,训练内存就可以大大得到节省,这一属性也适用于卷积层和归一化层(如BN、group normalization),因为它们可以被认为是特殊类型的线性层。
如上表,可以发现,非线性激活层(如ReLU、sigmoid、h-swish),sigmoid和h-swish需要保存\(a_i\)来计算\(\frac{\partial \mathcal L}{\partial a_i}\),因此他们是内存不高效的,因此构建在它们之上的激活层也不能有效利用内存,如tanh、swish等。相反地,ReLU和类ReLU形式的激活层,只需要存储一个二进制掩码表示值是否为0,这比存储\(a_i\)小了32倍。
轻量残差学习
基于之前的内存占用分析,一个可能的降低内存代价的方法就是Figure 2b显示的冻结预训练特征提取器的权重(weights),只更新偏置(bias)。但是只更新偏置(bias)适应能力有限。
因此,这里作者引入了轻量残差学习(Figure 2c),利用一类新的广义内存效率高的偏置模块来精炼中间特征图。
公式化的,一个冻结权重只学习偏置的层可以表示为: \[ a_{i+1} = \textcolor{gray}{ \mathcal F_W(a_i) }+b \] 为了在保持较小内存占用的情况下提高模型容量,作者添加一个轻量残差模块: \[ a_{i+1}=\textcolor{gray}{ \mathcal F_W(a_i) }+b+\mathcal F_{W_r}(a_i'=reduce(a_i)) \] 其中\(a_i'=reduce(a_i)\)是减少的激活,根据之前的分析,学习这些轻量残差模块只需要存储很少的激活\(\{a_i'\}\)而不是完整激活\(\{a_i\}\)
实现
将这应用到 mobile inverted bottleneck blocks (MB-block),关键原则是要保证激活小,根据这一原则,提出两个设计维度来降低激活大小:
宽度
广泛使用inverted bottleneck需要大量的通道来补偿深度卷积的小容量,这是参数效率高但激活效率低的。更糟的是,来回将\(1 \times\)通道转成\(6\times\)通道需要两个\(1\times 1\)投影层,这使得总激活增至\(12\times\)。深度卷积(Depthwise convolution)也具有很低的算术强度(如果是256通道,每字节OPs/Byte不到\(1\times 1\)卷积的OPs/Byte的4%),因此存储效率很低,几乎没有重用。
为了解决这一限制,轻量残差模块采用了具有比深度卷积高得多的算术强度的分组卷积,在FLOPs和内存之间提供很好的折衷。,也移除了\(1\times 1\)投影层,降低总通道数到\(\frac{6\times 2+1}{1+1}=6.5\times\)
分辨率
随着分辨率的提高,激活大小呈二次曲线增长。因此,作者通过使用\(2×2\)的平均池化来对输入特征图进行下采样来缩小轻残差模块中的分辨率。然后,通过双线性上采样,对轻量残差模块的输出进行上采样,以匹配主分支的输出特征映射的大小。
结合分辨率和宽度优化,轻量残差模块的激活大约比反向瓶颈小\(22\times 6.5=26\times\)。
讨论
标准化层
TinyTL灵活支持不再的归一化层,包括BG、GN、LN等。BN是在视觉任务广泛使用的一种,但BN需要大的批次才能在训练时准确统计估计,并不适合在通常使用小批次节省内存的边缘设备上学习,设备上学习中数据以流的形式,这需要训练批次为1。而GN则可以处理较小的训练批次,因为GN统计的是针对不同输出独立计算的,虽然在实验中发现小批次的GN性能略逊于大批次的BN的。
特征提取器自适应
TinyTL可以用于不同的主干神经网络,如MobileNetV2、ProxylessNASNets、EfficientNets等。但是,由于由于特征提取器的权重在TinyTL中被冻结,而且发现对所有迁移任务使用相同的主干神经网络是次优的。因此,作者使用了预训练的once-for-all网络来选择TinyTL的主干,以自适应地选择最适合目标传输数据集的专用特征提取器