【笔记】TinyBERT(EMNLP2019)

两阶段蒸馏:预训练阶段+finetune阶段

在这里插入图片描述

设计3种损失函数分布来适应bert的不同层级的损失计算

  1. embedding 层输出
  2. 来自 transformer 层的隐藏层和注意力矩阵
  3. 输出 logits 的预测层

1. 知识蒸馏的设计

可以将网络的任何一层称为行为函数( f f f , behavior function),KD就是利用小模型( S S S, student)学习大模型( T T T, teacher)。知识蒸馏的数学表示:
L K D = ∑ x ∈ X L ( f S ( x ) , f T ( x ) ) \mathcal{L}_{KD} = \sum_{x \in \mathcal{X}}L(f^S(x), f^T(x)) LKD=xXL(fS(x),fT(x))

对于Transformer层的蒸馏任务而言,需要学习的就是1)多头自注意力层(Mulit-head attention)、2)全连接前馈网络(fully feed-forward network)以及3)其他中间表示(例如注意力矩阵)。

因此研究的关键在于如何定义有效的1)行为函数和2)损失函数,包括在预训练和finetune阶段。

2. Methods

在这里插入图片描述

2.1 Transformer 蒸馏

设「学生模型」共 M M M 层,「学生模型」 N N N 层。

(一)学生层与教师层的对应关系

公式 n = g ( m ) n=g(m) n=g(m) 表示「学生模型」第 m m m 层映射至「教师模型」第 n n n 层,特别的, 0 = g ( 0 ) 0=g(0) 0=g(0) 表示学生的embedding层映射至教师embedding层; N + 1 = g ( M + 1 ) N+1=g(M+1) N+1=g(M+1) 表示预测层相对应。

此知识蒸馏任务可公式化为
L model = ∑ x ∈ X ∑ m = 0 M + 1 λ m L layer ( f m S ( x ) , f g ( m ) T ( x ) ) \mathcal{L}_\text{model} = \sum_{x \in \mathcal{X}}\sum_{m=0}^{M+1} \lambda_{m} L_\text{layer}(f_m^S(x), f_{g(m)}^T(x)) Lmodel=xXm=0M+1λmLlayer(fmS(x),fg(m)T(x))
(二)具体Transmformer层的蒸馏对象

基于attention、基于hidden states,如Figure 2所示。

  1. Attention based distillation

原因:BERT的attention层能够捕获丰富的语言学知识,包括句法(syntax)和共现关系(coreference information),这些都是自然语言理解的基础。

更明确地说,就是用学生模型学习教师模型的 attention matrices :
L attn = 1 h ∑ i = 1 h MSE ( A i S , A i T ) \mathcal L_\text{attn} = \frac{1}{h}\sum_{i=1}^{h}\text{MSE}(A_i^S, A_i^T) Lattn=h1i=1hMSE(AiS,AiT)

式(3)是注意力矩阵拟合的损失函数表达式。而且论文之间对矩阵 A i A_i Ai 进行拟合,而不是对 s o f t m a x ( A i ) softmax(A_i) softmax(Ai) 实验表明前者的性能更佳且收敛速度更快。

  1. Hidden states based distillation

L hidn = M S E ( H S W h , H T ) \mathcal{L}_\text{hidn} = MSE(H^SW_h, H^T) Lhidn=MSE(HSWh,HT)

其中 H S ∈ R l × d ′ H^S\in\mathbb{R}^{l \times d'} HSRl×d H T ∈ R l × d H^T \in \mathbb{R}^{l \times d} HTRl×d 分别是学生模型和教师模型的Transformer FFN的隐藏层参数, W h W_h Wh 是一个可学习的线形层,用来将 S S S 对齐至 T T T

  1. 嵌入层蒸馏

L embd = M S E ( E S W e , E T ) \mathcal{L}_\text{embd} = MSE(E^SW_e, E^T) Lembd=MSE(ESWe,ET)

和隐藏层类似。

  1. 预测/输出层蒸馏

L pred = C E ( z T / t , z S / t ) \mathcal{L}_\text{pred} = CE(z^T/t, z^S/t) Lpred=CE(zT/t,zS/t)

z S z^S zS z T z^T zT 是学生和教师模型的 logits 向量,并对它进行soft,除以 temperature - t t t 。实验表明, t = 1 t=1 t=1 效果最好(没加更好?)。

因此最后的蒸馏任务损失函数就是以上4个的组合:
L layer = { L embd m = 0 L hidn + L attn M ≥ m ≥ 0 L pred m = M + 1 \mathcal{L}_\text{layer} = \begin{cases} \mathcal{L}_\text{embd} & m=0\\ \mathcal{L}_\text{hidn} + \mathcal{L}_\text{attn} & M\ge m \ge 0\\ \mathcal{L}_\text{pred} & m=M+1\\ \end{cases} Llayer=LembdLhidn+LattnLpredm=0Mm0m=M+1
(三)TinyBERT Learning

实验设计了两阶段的蒸馏(学习)任务:通用蒸馏 + 特定任务蒸馏,如图1所示。

  1. General Distillation

通用蒸馏,即在预训练阶段进行蒸馏,它能帮助「学生模型」学习到丰富的embedding知识,有助于提升模型的泛化能力。

预训练阶段的蒸馏任务损失函数(7)不包含公式中的预测层损失函数 L pred \mathcal{L}_\text{pred} Lpred。 旨在让「学生模型」学习模型的中间结构。并且初步的实验表明,在预训练阶段加入预测层损失函数并不能提升下游任务性能。

  1. Task-specifific Distillation

研究表明现复杂模型在特定领域的任务中存在 **over-parametrization(过度参数化)**的问题,这会造成模型过拟合,泛化性变差。所以,一些参数量小的模型或许能够达到和原来的BERT差不多的效果。

实验中使用了一个 finetuned的BERT模型 + 数据增强来蒸馏TinyBERT。

3. 总结

TinyBERT设计了不同阶段的损失函数,包括对BERT的Embedding层、预测层以及中间层;以及设计了两阶段的蒸馏任务。

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐