SwapPrompt(论文解读): Test-Time Prompt Adaptation for Vision-Language Models
测试自适应 (TTA) 是无监督域自适应(UDA)中一种特殊且实用的设置,它允许源域中的预训练模型去适应另一个目标域中的未标记测试数据。为了避免计算密集型的骨干网络微调过程,因此利用预训练视觉语言模型(例CLIP、CoOp)zero-shot的泛化潜力,仅对未见测试域的运行时提示进行调整。然而,现有的解决方案尚未充分利用预训练模型的表征能力,因为它们只关注基于熵的优化,其性能远低于监督提示适应方法
2023(Neural IPS)
摘要
测试自适应 (TTA) 是无监督域自适应(UDA)中一种特殊且实用的设置,它允许源域中的预训练模型去适应另一个目标域中的未标记测试数据。为了避免计算密集型的骨干网络微调过程,因此利用预训练视觉语言模型(例CLIP、CoOp)zero-shot的泛化潜力,仅对未见测试域的运行时提示进行调整。
然而,现有的解决方案尚未充分利用预训练模型的表征能力,因为它们只关注基于熵的优化,其性能远低于监督提示适应方法,例如CoOp。
本文提出了SwapPrompt,可以有效地利用自监督对比学习来促进测试时提示适应。SwapPrompt 采用双重提示范式,即在线提示和目标提示,目标提示从在线提示中取平均值以保留历史信息。此外,SwapPrompt 应用了交换预测机制,该机制利用预训练模型的表征能力,通过对比学习来增强在线提示。具体来说,就是使用在线提示与输入图像的增强视图一起,来预测类别分配,该类别分配由目标提示与同一图像的另一种增强视图生成。
SwapPrompt可以部署在视觉语言模型上,无需额外要求,实验结果表明,该算法在ImageNet和其他9个数据集上实现了最先进的测试时自适应性能。研究还表明,SwapPrompt甚至可以在监督提示适应方法下实现相当的性能。
Introduction
当训练数据与测试数据存在分布差异时,深度神经网络的泛化性能可能会受到影响。领域适应的关键是构建模型,通过将源域知识迁移到目标域来调整数据分布的变化,这在训练阶段同时需要源域和目标域数据。然而,在实际场景中,通常只有经过源域训练后的模型可用,而没有访问源域数据的权限,或者没有权限更改原始的训练过程。为了解决这个问题,目前已经提出了测试时自适应(TTA),它仅利用未标记的测试数据流使模型适应目标域数据/不可见数据。现有的工作已经开发了诸如熵最小化,类原型,图像生成和自监督训练等技术,这些技术已经显示出卓越的性能。
尽管传统基于模型的 TTA 方法效果不错,但它们通常依赖于对模型主干参数的计算密集型调整。随着视觉语言预训练模型(CLIP 、CoOp、CoCoOp)的出现,情况将变得更糟,这些模型具有大量参数且难以优化。因此,探索高效的技术很必要,这些技术保证主干网络固定,仅微调一小部分参数,以便在测试期间使模型适应新的领域。预训练的视觉语言模型是在大量图像-文本对上训练,引入了强大的范式,为解决这个问题提供了新的见解。一种简单的方法是利用预训练视觉语言模型的zero-shot能力,利用下游任务中带标记数据进行微调来区分测试数据的各个领域,但是TTA场景中无带标签的下游数据,所以不合适。Shu提出了测试时提示调整(TPT)来解决测试时标签稀缺问题。然而,直接通过最小化熵调整特定于实例的提示,它可能会导致模型过度信任风险(即,为错误结果产生高置信度)。
本文提出了 SwapPrompt,这是一种新颖的测试时提示适应(TTA)方法。
解释图1:与以前的方法不同,SwapPrompt 在测试域中利用自监督对比学习策略,该策略由两个关键组成:指数移动平均 (EMA) 提示和提示交换预测机制。(1)EMA 机制采用双重提示范式:目标提示和在线提示。我们优化在线提示,同时通过慢移动平均过程逐步更新目标提示,该过程结合过去的信息以提高稳定性和有效性。(2)提示交换预测机制的灵感来源于无监督学习方法SwAV 。根据图像的增强视图和在线提示,SwapPrompt 预测同一图像的增强视图的类别分配。这使在线提示能够学习更多的表征知识。交换预测策略背后的基本原理是同一图像的两中不同增强视图应该具有相似的类别分配。利用对比表示学习方法来产生更好的决策边界。
除了自监督对比学习的损失函数外,我们还采用了传统的交叉熵损失,如 CLIP 和 CoOp,它使用 CLIP 生成的高置信度的伪标签来调整提示。本文方法也可用于在线测试时场景,其中测试数据以小批量流的形式到达。本文将对所有测试数据执行的操作分解为多个小批量。本文在各种测试时适应基准上评估了本文方法,包括ImageNet和基于它的四个自然分布偏移数据集,以及九个细粒度分类数据集。实验结果表明,该方法实现了最先进的测试时适应性能。
本文贡献:
(1)本文提出了SwapPrompt,一种新颖的测试时提示适应方法,它采用自监督对比学习策略,使提示能够更好地适应下游图像分类任务。
(2)本文首次将无监督表征学习应用于预训练视觉语言模型的快速适应。本文引入了 EMA 提示和提示交换预测策略,使提示能够从预训练模型的强大表征能力中学习更多的知识。
(3)本文对 ImageNet 及其 4 个变体以及其他 9 个图像分类数据集进行了广泛的实验。实证评估表明,本文方法明显优于当前的 TPT 方法,甚至可以在大多数数据集上与监督提示适应方法竞争。
Methodology
3.1 Preliminary and Problem Definition
本文专注于预训练视觉语言模型的测试时提示适应,其中模型在源域上训练,在目标域上进行测试。CLIP(“A photo of [CLS]”)表现出强大的zero-shot泛化能力。然而,这些手工制作的提示并不能完全提取CLIP从大规模、多样化的预训练数据集中学到的丰富知识。优化的提示可以进一步提高 CLIP 获得目标域知识的能力。目前已经有一些关于监督目标域数据的相关工作,包括CoOp方法,本文方法的一部分也融入了CoOp的思想。然后,我们将简要回顾 CoOp 并定义本文中使用的符号。
Context Optimization (CoOp)
CoOp是一种基于CLIP的提示调整的方法。与CLIP类似,CoOp包含一个文本编码器g()和一个图像编码器f()。
目标域数据集:
其中xi表示第i个输入的图像,yi表示相应的类标签,N代表样本个数,其中C表示类别数。
t代表连续的可学习向量,{t;c}表示传递到文本编码器的第c个类别的输入。
zi = f (xi)表示图像编码器输出得到的特征编码,wc = g({t; c})表示文本编码器输出的特征编码。
对于第c个类别中图像xi的预测概率为:
对于所有训练数据,CoOp使用以上公式求所有类的预测概率,并最小化交叉熵损失以调整提示。
Test-Time Prompt Adaptation.
在测试时场景中,来自目标域中带有标签的数据不可以得到,因此不能像CoOp那样优化提示。
因此使用没有标签数据的目标域数据集:。
测试时间提示调整的目标可以表示为:
其中L表达交叉熵损失函数,其目的就是设计无监督提示调整方法,促进提示t与目标域数据兼容,为CLIP获取等多的目标域知识。
3.2 Overview of SwapPrompt
SwapPrompt包含两个重要部分:EMA提示和提示交换预测。介绍图2:SwapPrompt采用双重提示范式:在线提示和目标提示两种。通过应用EMA策略,两种提示相互学习以适应目标域。此外,与以往一个提示对应一个图像的框架不同,SwapPrompt采用提示交换预测机制,为提示自适应建立自监督表征学习,其目标是为统一图像的两种不同的增强分配相似的类别。
3.2.1 Exponential Moving Average Prompt
SwapPrompt的目标就是学习一种可以应用于测试数据的在线提示to,使用目标提示tt引导在线提示to的更新。该设计的动机是:给定目标提示tt,我们可以通过预测目标提示tt生成的表征知识来训练一个新的、有改进的在线提示to。通过反复使用后续的在线提示作为新的目标提示进一步训练,我们就可以得到一个质量随时间提高的提示序列。实际中,SwapPrompt使用在线提示的缓慢移动的指数平均值作为新的目标提示,从而在每个训练step下进行以下的更新(ε ∈ [0, 1],表示目标提示的衰减率):
3.2.2 Prompt Swapped Prediction
在自监督学习领域中,跨视图预测在许多现存工作中广泛应用。这些方法通常将预测问题投射到一个表征空间,然后使用来自同一图像的不同增强视图来学习表征。假设一张图像的不同增强视图在表征空间中相对接近。
与直接强制在表示空间中图像特征之间的一致映射不同,SwAV将一组图像特征进行聚类,使用聚类中心作为原型,并匹配同一图像的不同增强版本到这些原型上,以计算其聚类分配。SwAV通过比较它们的聚类分配,而非图像特征,在多个图像视觉之间进行对比学习。
受到SwAV的启发,本文将一张图像的增强视图分配给原型,从而获得软类别分配,然后使用同一张图片的另一个增强视图来预测其类别分配。这种方法很适合CLIP,因为它有自然的原型:文本编码器输出的文本特征。
(1)具体来说,对于测试数据集Dtest中的所有类别Y ∈ {1, 2, . . . , C},目标提示tt会为文本编码器形成C种输入:
文本编码器会根据输入为不同的类别生成C种文本特征:
由于在大规模的图像文本对种进行有监督的对比预训练,CLIP生成的文本特征与相同类别的图像特征相似度很高,与不同类别的图像特征相似度比较低。因此,这C个文本特征可以作为一组高质量的原型。
(2)对于一个图像xi,本文使用两种不同的图像增强方法A1A2来获得两种不同的增强视图A1(xi)和A2(xi)。图像编码器生成相应的图像特征和。
Eq.1:
(3)通过在文本特征原型与对应类别的两个增广图像视图特征计算Eq1得到相应的类别分配:
其中:
(4)类似,在线提示to产生的预测表示为:
(5)文本建立了图像xi的提示交换预测损失函数(其中l表示衡量预测结果和类别分配之间差异的函数):
其中损失函数:
其中在线提示to通过Eq7被优化,而目标提示tt不会被该损失函数更新。
3.2.3 Prompt Optimization by Pseudo Label
除了使用自监督表征学习方法来优化提示,本文也使用了一组带有标签的数据来优化在线提示,与CoOp类似。但是与CoOp不同,CoOp中目标域的数据标签可用,在测试场景中测试数据的标签不可用。因此,本文首先使用手工提示( A photo of class)对测试数据进行推断,得到它们的伪标签,然后使用Eq6和交叉熵损失:
3.3 Algorithm Workflow
在使用Eq7、Eq9和伪标签训练之前,我们需要进行数据选择来过滤掉潜在的噪声伪标签。具体来说,我们首先使用CLIP和手工设计的提示来获得测试数据的伪标签和分类置信度。然后,对于每个类,只选择置信度排名最高的前K个测试数据。这些被选择的测试数据组成了调整集Dadapt,这是Dtest的一个子集。对于调整集中的测试数据,使用以下函数进行提示调整:
当目标图像以小批量流到达时,即测试数据是在线的,我们无法对整个测试数据集的置信度进行排序。然而,我们仍然可以对小批量进行基于置信度的排序,以选择置信度最高的前k个( k < K)测试数据,同时保持训练过程的其余部分不变。当新的小批量测试数据到来时,对可用的测试数据再次进行置信度排序,得到新的D adapt,以便及时适应。
Conclusion
在本文中,我们研究了一种新的测试时提示适应方法SwapPrompt,以学习适应预训练视觉语言模型的测试域提示。具体来说,我们维护一个在线提示和一个由 EMA 更新的目标提示,它们相互交互和学习。本文设计了一种交换预测机制来训练在线提示,使其能够预测目标提示在不同增强视图下对同一图像的类别分配。SwapPrompt 可以很容易地部署在视觉语言模型的测试时。在各种数据集上进行了大量的实证实验,以验证SwapPrompt的有效性和优越的性能。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)