0%

《From Dense to Sparse:Contrastive Pruning for Better Pre-trained Language Model Compression》

译:《从密集到稀疏:对比学习用于更好的预训练语言模型压缩》

论文地址:From Dense to Sparse: Contrastive Pruning for Better Pre-trained Language Model Compression

用于模型剪枝的对比学习框架

Motivation

预训练语言模型(PLMs)在NLP领域已经取得巨大成功,但这些PLMs通常具有大量的参数,并且有计算密集型和资源密集型的特点。因此模型压缩被引入来压缩大规模PLMs,然而,绝大多数之前的方法只考虑了针对下游任务的特定任务知识,而忽视了在剪枝过程中任务无关的知识,这可能会导致灾难性遗忘问题,导致泛化能力较差。

为了解决这些问题,作者提出了他们的方法,对比剪枝(Contrastive Pruning, CAP)。

CAP是在预训练和微调的范式下的,被设计作为一种通用框架的,同时兼容结构化剪枝和非结构化剪枝,统一于对比学习。CAP可以使修剪后的模型从预训练模型中学习任务无关的知识,然后基于特定任务知识微调模型。

CAP的优点:

  1. CAP在修剪后的模型中同时维护任务无关的知识和特定任务的知识,从而减少灾难性遗忘问题,并在剪枝过程中保持模型性能
  2. CAP是基于对比学习这一已经被证明是非常强大的表征学习技术的
  3. CAP是一个框架,而不是一个特定剪枝方法,因此,它是与其它各种剪枝标准正交的,可以被结合使用

方法简介

作者提出的CAP框架总览如上图,通过监督预训练模型和微调模型来剪枝,并在剪枝过程中截取快照以获得不同类型的知识。按照迭代剪枝方法,将预训练模型\(\phi_{pre}\)压缩到期望的稀疏比R%(\(\phi_{pre} \to \phi_{r1} \to \phi_{r2}\to \dots \to \phi_R\)),在每一步可以使用任意剪枝标准。

CAP由三个模块组成:PrCSnCFiC,它们都是基于对比学习的,使用不同的方法构造正样本。

对比学习损失计算

对于样本\(x_i\),由模型\(\phi\)编码成向量表示\(z_i=\phi(x_i)\in\mathbb R^d\),此外,还有N个样本被编码为\(\mathcal S=\{\hat z_j\}_{j=1}^N\),用于与\(z_i\)进行对比,假设\(\mathcal S\)中一个或多个正样本\(\hat z_p \in \mathcal S\)和其它负样本\(\mathcal S \setminus\{\hat z_p\}\),可对比学习目标可定义为: \[ \mathcal L_i = -\frac{1}{||P(i)||}\sum_{\hat z_j\in P(i)}\log \frac{e^{sim(z_i,\hat z_j)/\tau}}{\sum_{k=1}^Ne^{sim(z_i,\hat z_k)/\tau}} \] 其中\(P(i) \subset \mathcal S\),表示\(z_i\)的正样本集,\(sim(z_i,z_j)=\frac{z_i^Tz_j}{||z_i||||z_j||}\)表示余弦相似度函数,\(\tau\)表示温度超参数。

PrC:使用预训练模型的对比学习

在向特定下游任务迁移学习过程中,原始PLM中任务无关知识很容易被丢失,这样就会导致灾难性遗忘问题,因此,引入PrC模块(图中绿色线)可以基于对比学习来维持通用语言知识。

假设样本\(x_i\)被模型\(\phi_r\)编码为\(z_i=\phi_r(x_i) \in \mathbb R^d\),其中\(r\%\)表示的是稀疏率。高层次想法是可以将\(z_i\)与由预训练模型\(\phi_{pre}\)编码的\(\{\hat z_j = \phi_{pre}(x_j)\}_{j=1}^N\)进行对比,使得模型可以正确找出那些语义相似(正例)的样本。这样一来,当前的剪枝模型\(\phi_r\)就可以模仿预训练模型的表示建模能力,从而保留与任务无关的知识。

具体地,对于无监督学习PrC\(\phi_{pre}(x_i)\)作为\(\phi_r(x_i)\)的一个正样本,而\(\{\phi_{pre}(x_j)\}_{j\ne i}\)作为负样本,损失为\(\mathcal L_{unsup}^{PrC}\)

对于监督学习PrC,则是进一步利用数据的句子级标注,例如:MNLI任务中,句子被标记为蕴涵(entailment)、中性(neutral)或矛盾(contradiction),直观地说,是将那些与\(x_i\)有相同标签的句子作为正例,因为他们有相似的语义特征,而其它的则作为负例,正式地表示:定义正例为\(\{\phi_{pre}(x_j)|y_j=y_i\}\),其中\(y_i\)表示\(x_i\)的标签,损失为\(\mathcal L_{sup}^{PrC}\)

因此,PrC最终训练目标为\(\mathcal L^{PrC}=\mathcal L_{unsup}^{PrC} + \mathcal L_{sup}^{PrC}\)

SnC:使用快照的对比学习

剪枝可以一次完成也可以迭代多次,在这篇论文中,作者采用迭代剪枝的方法因为它更适合高度稀疏。迭代剪枝就是一步步迭代移除权重,直到达到期望的稀疏性,而迭代的过程中产生的蹭模型就被称为快照。而之前的研究都只是简单地忽略掉这些快照,作者提出的SnC(图中黄线)则可以基于对比学习从这些快照中学习进而对当前模型剪枝。

具体地,剪枝过程快照集可以用\(\{\phi_{r'}\}_{r'<r}=\{\phi_{r1},\phi_{r2},\cdots \}\)来表示,这些快照可以桥接稀疏模型\(\phi_r\)和密集模型(\(\phi_{pre},\phi_{fine}\))之间的差距,提供具有不同稀疏结构的多样化监督。

对于无监督学习SnC,对于样本\(x_i\)通过当前模型\(\phi_r\)编码成\(\phi_r(x_i)\in\mathbb R^d\),使用相同样本但不同快照编码表示\(\{\phi_{r'}(x_i)|r'<r\}\)作为正样本,而作为\(\{\phi_{r'}(x_j)|j\ne i,r'<r\}\)负样本

对于监督学习SnC,结合标注,将相同标签实例作为正样本。

最后,\(SnC\)训练目标:\(\mathcal L^{SnC}=\mathcal L_{unsup}^{SnC}+\mathcal L_{sup}^{SnC}\)

FiC:使用微调模型的对比学习

为了更好地适应下游任务,剪枝模型\(\phi_r\)也可以从微调模型\(\phi_{fine}\)学习,为此,作者提出FiC模块,从\(\phi_r\)\(\phi_{fine}\)构造对比学习,它与\(PrC\)几乎相同,除了要用\(\phi_{fine}\)替换\(\phi_{pre}\),这样,最后的损失为:\(\mathcal L^{FiC}=\mathcal L_{unsup}^{FiC}+\mathcal L_{sup}^{FiC}\)

使用CAP框架剪枝

综合三种模块,对应的正样本如下:

PrCSnCFiC结合在一起,就可以实现提出的CAP框架了,并且可以灵活地结合不同的剪枝标准,这里分别考虑结构化剪枝与非结构化剪枝。

结构化剪枝

一种广泛使用的结构化剪枝是使用一阶泰勒展开,基于去除后损失\(\mathcal L\)的变化来推导重要性得分。这里将其称为First-order pruning,将其融入CAP,称为CAP-f \[ I_w=|\mathcal L_w-\mathcal L_{w=0}|\approx |\frac{\partial \mathcal L}{\partial w}w| \]

非结构化剪枝

对于非结构化剪枝,这里使用的是基于movement的剪枝方法,其计算参数\(w\)重要性得分方法为: \[ I_w=-\sum_t \frac{\partial \mathcal L^{(t)}}{\partial w}w^{(t)} \] 其中\(t\)是训练步数,基于这种方法反转参数后选择前\(K\)个策略或预定义阈值策略被称为Movement pruning or Soft-movement pruning,将这些与CAP结合,称为CAP-mCAP-soft

最后,剪枝和训练模型使用的最终目标可表示为: \[ \mathcal L=\lambda_1 \mathcal L^{CE} + \lambda_2 \mathcal L^{PrC}+\lambda_3\mathcal L^{SnC}+\lambda_4 \mathcal L^{FiC} \] 其中\(\mathcal L^{CE}\)是针对下游任务的交叉熵损失。

额外的内存开销

使用CAP框架,剪枝模型需要从预训练模型、快照和微调模型中学习,但其实是不需要将所以模型加载到GPU中的,从对比学习损失计算公式可以看出,不需要反向传播梯度,只需要简单地预编码样本并将它们存储到CPU中即可。