一种面向长尾分布的视觉-语言模型提示学习框架

专利2025-07-09  39


本发明属于长尾分布领域,具体涉及一种面向长尾分布的视觉-语言模型提示学习框架。


背景技术:

1、通常,从现实世界收集的数据集会呈现长尾分布。长尾分布是指在数据集中存在少数几个类别或数值具有非常高的频率(头部),而大多数类别或数值则具有较低的频率(尾部)的分布形式。在长尾分布中,头部的类别或数值通常被称为“热门”,而尾部的类别或数值则被称为“长尾”。这种分布形式在很多现实世界的数据集中都很常见,比如销售额、网站访问量、商品评价等等。对于传统的机器学习算法来说,如果没有额外处理来减轻训练集倾斜性的方法,模型的效果往往会表现较差。目前关于长尾分布的数据集有cifar10-lt、places-lt和imagenet-lt。

2、为了解决不平衡数据问题,研究人员提出一系列算法,包括采样策略的调整(对较低频率的类别进行上采样),数据增强以创建合成样本(合成较低频率的样本),基于样本权重的重新加权方案(分配给较低频率的样本更大的权重)以及知识蒸馏技术。

3、近期,基于大模型的方法在零样本任务上取得的令人印象深刻的成就,因此,当前研究的焦点在于选择合适的预训练大模型,并根据下游任务的数据该大模型中的文本编码器或者视觉编码器。在这些大模型中,clip是一个备受关注的文本-视觉大模型,在零样本学习设置中能够高效地解码文本和视觉数据的高级语义信息,展现了其强大的泛化能力。

4、在将预训练的基础模型的知识转移到下游任务时,提示工程已被证明是有效的。其基本思想是学习一个合适的文本上下文作为提示,以围绕主文本构建模型的文本编码器的输入。与手工制作的提示不同,coop中提出的可微分提示在单词嵌入空间中自动学习,极大地加速了下游任务的提示调整过程。除了像coop那样自动学习提示的思想外,最近的proda工作通过在文本编码器的输出嵌入上估计高斯均值和协方差,为clip模型训练了固定的提示集。vl-ltr致力于使用clip模型解决不平衡下游数据集的学习,考虑在从头开始训练学生视觉-语言网络时从clip模型中蒸馏知识。然而,vl-ltr要求用户从网络收集额外的文本语料库,而不仅仅是仅使用下游长尾学习任务中提供的数据作为提示。


技术实现思路

1、为解决上述问题,本发明公开了一种面向长尾分布的视觉-语言模型提示学习框架,能够同时生成多个提示来描述语义,从而建立了一个强大的集成学习算法,使得模型能够充分学习训练样本数量稀缺的尾部类别。

2、为达到上述目的,本发明的技术方案如下:

3、一种面向长尾分布的视觉-语言模型提示学习框架,包括以下步骤:

4、步骤一:获取视觉-语言大模型以及其预训练后的模型参数文件,准备目标长尾分布数据集并对其进行预处理;所述视觉-语言大模型包含图像编码器,文本编码器,提示分布生成器;

5、步骤二:在数据集中采样图像和类别的文本描述,通过提示分布生成器生成提示分布,并从该分布中采样提示集,得到提示嵌入向量集合;

6、步骤三:将采样得到图像输入到图像编码器,得到图像嵌入向量;

7、步骤四:将提示嵌入向量和类别文本描述输入到文本编码器,得到文本编码嵌入向量;

8、步骤五:使用对比学习的方法,选择合适的损失函数和正则项对图像编码器,提示分布生成器进行训练;

9、步骤六:测试的时候将测试图像输入到图像编码器得到图像嵌入向量;将不同类别对应的提示分布生成器生成的提示集和类别描述输入到文本编码器得到文本嵌入向量,通过对比得到测试图像的类别。

10、进一步的,所述步骤一中所使用的视觉-语言大模型为contrastive languageimage pre-training(clip),采用wit数据集进行预训练。下游任务所需数据集为cifar10-lt、places-lt和imagenet-lt。

11、进一步的,所述步骤二中所使用的提示分布生成器是一个可学习的模块,并且可以被分成三个类别:类别通用的提示分布生成器,类别特定的提示分布生成器,和条件类别特定的提示分布生成器,公式表示为:

12、类别通用的提示生成器:

13、

14、类别特定的提示分布生成器:

15、

16、条件类别特定的提示分布生成器:

17、

18、其中,gagn,gspec,gcondi分别表示三种提示分布生成器,θ表示生成器中的可学习参数,y表示类别,c表示类别的数量、wy表示类别y的可学习的条件参数向量。表示正态分布,μ,σ表示正态分布的均值和方差。texty表示类别y的文本描述。训练或者测试时,首先选取指定某一类别的的提示分布生成器,再由该提示分布生成器生成一个提示分布,最后从提示分布中采样k个提示,表示类别y的第k个提示向量。

19、进一步的,所述步骤三所采用的图像编码器为resnet50,由四个残差结构组成。也可以是vision transformer,由12个vit单元级联构成。

20、进一步的,所述步骤四中所采用的文本编码器为text transformer,由12个bert单元级联构成。

21、进一步的,所示步骤五的损失函数和正则项数学公式为:

22、损失函数:

23、

24、其中,l表示损失函数,y表示x的类别。πy一般在实践中设置为类别先验概率的倒数,τ表示将是一个额外的自由参数,可用于调整πy的偏移。f(x)是一个向量,第y’个分量可以表示为:

25、

26、此外

27、v(x)=fimg(x),#(14)

28、

29、其中,fimg和ftxt分别表示图像编码器和文本编码器,x表示输入图像,v(x)表示图像x嵌入向量,表示类别y第k个文本嵌入向量(这里的y也需要变成y’),h(·)表示把文本嵌入为向量的映射函数。

30、正则项:

31、

32、其中,表示kl散度。该正则项的目的是希望所产生的提示分布为标准正态分布,以此保证所采样的提示的多样性。

33、本发明的有益效果为:

34、(1)本发明针对长尾问题提出了一种提示学习算法,该算法确定了一套生成任意数量提示的过程,显示的为每个类别设置了多个提示来描述语义,从而建立了一个强大的集成学习算法,使得模型能够充分学习训练样本数量稀缺的尾部类别。

35、(2)本发明为提示分布生成器的生成的提示分布进行了正态分布限制,这保证了所采样的提示的多样性。

36、(3)本发明不要求用户从网络收集额外的文本语料库,仅使用下游长尾学习任务中提供的数据来微调网络和提示分布生成器。



技术特征:

1.一种面向长尾分布的视觉-语言模型提示学习框架,其特征在于,包括以下步骤:

2.根据权利要求1所述的面向长尾分布的视觉-语言模型提示学习框架,其特征在于,步骤一中所使用的视觉-语言大模型为contrastive language image pre-training,采用wit数据集进行预训练;下游任务用到的数据集为cifar10-lt、places-lt和imagenet-lt。

3.根据权利要求1所述的面向长尾分布的视觉-语言模型提示学习框架,其特征在于,步骤二中所述的提示分布生成器是一个可学习的模块,并且被分成三个类别:类别通用的提示分布生成器,类别特定的提示分布生成器,和条件类别特定的提示分布生成器,公式表示为:

4.根据权利要求1所述的面向长尾分布的视觉-语言模型提示学习框架,其特征在于,步骤三所采用的默认图像编码器为resnet50,由四个残差结构组成;图像编码器是visiontransformer,由12个vit单元级联构成。

5.根据权利要求1所述的面向长尾分布的视觉-语言模型提示学习框架,其特征在于,步骤四中所采用的文本编码器为text transformer,由12个bert单元级联构成。

6.根据权利要求1所述的面向长尾分布的视觉-语言模型提示学习框架,其特征在于,步骤五的损失函数和正则项数学公式为:


技术总结
本发明公开了一种面向长尾分布的视觉‑语言模型提示学习框架,首先使用类特定或者类通用的提示分布生成器生成一个提示分布,然后从该分布采样一定数量的提示,和类别文本一起输入至文本编码器得到文本编码向量。随后将图片输入到图像编码器得到图像编码向量,利用对比损失指导文本编码向量和图像编码向量的在语义对齐上的训练。通过将提示的学习形式化为一个变分问题,该框架能够同时生成多个提示来描述类别,从而建立了一个强大的集成学习算法,使得模型能够充分学习训练样本数量稀缺的尾部类别。实证研究表明,所提出的提示学习框架有助于将预训练的视觉‑语言模型成功应用于数据长尾分布的下游视觉识别任务中。

技术研发人员:方鹏飞,李文倩
受保护的技术使用者:东南大学
技术研发日:
技术公布日:2024/11/11
转载请注明原文地址: https://tieba.8miu.com/read-15689.html

最新回复(0)