当前位置: 首页 > 专利查询>鹏城实验室专利>正文

基于对比学习的小样本学习鲁棒性提升方法技术

技术编号:34241985 阅读:84 留言:0更新日期:2022-07-24 09:28
本发明专利技术公开了一种基于对比学习的小样本学习鲁棒性提升方法,包括以下步骤:S1、对原始数据集进行预处理,构造对抗数据集和对比数据集;S2、将原始数据集、对抗数据集和对比数据集分别输入预训练语言模型中,得到相应的嵌入表示,并使用对比学习损失函数计算三种嵌入表示之间的距离;S3、根据对比学习损失函数与原模型的损失函数计算模型更新的梯度,以总体损失更小为目标来训练模型。本发明专利技术通过构造对抗和对比数据集为模型鲁棒性学习提供数据支持,使用对比学习目标函数计算损失能够更好地获取原始样本与对抗样本的相似性,也能更好地区分原始样本与对比样本的差异,从而提升模型受到对抗或者对比扰动时的鲁棒性。对抗或者对比扰动时的鲁棒性。对抗或者对比扰动时的鲁棒性。

Small sample learning robustness improvement method based on comparative learning

【技术实现步骤摘要】
基于对比学习的小样本学习鲁棒性提升方法


[0001]本专利技术涉及自然语言处理
,特别是涉及一种基于对比学习的小样本学习鲁棒性提升方法。

技术介绍

[0002]少样本学习是人工智能达到人类智能水平的基础任务。人类智能的一个重要特性是,人类具有从少量样本中进行学习的能力,并且具有极强的泛化性,即所谓举一反三,融会贯通。
[0003]而少样本学习研究的就是如何从少量样本中去学习。正因为如此,少样本学习成为了近年来深度学习领域非常重要的一个前沿研究方向。在自然语言处理领域,小样本学习主要是基于已有的预训练模型例如BERT(Bidirectional Encoder Representationfrom Transformers,双向编码器表示),RoBERTa(Robustly Optimized BERT,强力优化的BERT模型),GPT

3(Generative Pre

training,预训练生成模型)等来处理下游任务。目前主流的方法是对上述预训练模型进行微调(fine

tuning)或者提示调整(prompt tuning)。
[0004]微调(fine

tuning)方法的思想是使预训练语言模型适应下游任务,而不用从头训练新模型,因此不仅能够节省计算资源,而且能有效地利用预训练模型中已有的知识。在过去的许多年,微调(fine

tuning)方法在多个任务都取得了优秀的成绩,所以对特定任务进行预训练语言模型的微调成为了自然语言界的共识。但是,该方法需要充足的有标签的训练数据,并且将预训练语言模型的输出层替换成任务特定的头(head)。另外有很多研究表明,微调(fine

tuning)方法存在的问题是预训练和微调(fine

tuning)之间的目标函数存在显著的形式差别,阻碍了预训练语言模型到下游任务的知识迁移。
[0005]为了弥补微调(fine

tuning)方法的上述短板,提示调整(prompt tuning)提出了让下游任务去适应预训练语言模型。一个提示通常包括两部分:模版和标签词。在提示调整(prompt tuning)中,通过融合输入句子和模板,下游任务可以改写成相应的完形填空的问题。此时,预训练语言模型被用来预测完形填空中的空缺内容,然后模型预测出来的词再进一步映射成标签。与微调(fine

tuning)方法相比,提示调整(prompt tuning)方法不需要大量的带标签的训练数据以及额外的为特定任务设计的神经网络层,在多种小样本下游任务中取得了更好的效果。
[0006]尽管微调(fine

tuning)方法和提示调整(prompt tuning)方法在小样本学习场景下表现优异,但是当遇到对抗或者对比攻击,例如同义词或者反义词替换时,鲁棒性都欠佳。最近有某些研究机构的研究工作尝试对预训练语言模型进行对抗训练,以获得更鲁棒的模型。但是,这些方法大部分都是基于微调(fine

tuning)并且拥有大量训练数据,不适用于低资源(小样本)场景和提示调整(prompt tuning)方法。另外,也有研究为小样本微调(fine

tuning)设计新的目标函数,例如使用对比学习来使同一个类的样本距离更近,反之更远,但是仍然不能抵抗对抗或者对比扰动攻击。

技术实现思路

[0007]为了弥补上述
技术介绍
的不足,本专利技术提出一种基于对比学习的小样本学习鲁棒性提升方法,以解决微调方法和提示调整方法在小样本学习场景下遇到对抗和对比扰动时鲁棒性差的问题。
[0008]本专利技术的技术问题通过以下的技术方案予以解决:
[0009]本专利技术公开了一种基于对比学习的小样本学习鲁棒性提升方法,包括以下步骤:S1、对原始数据集进行预处理,构造对抗数据集和对比数据集;S2、将原始数据集、对抗数据集和对比数据集分别输入预训练语言模型中,得到相应的嵌入表示,并使用对比学习损失函数计算三种嵌入表示之间的距离;S3、根据对比学习损失函数与原模型的损失函数计算模型更新的梯度,以总体损失更小为目标来训练模型。
[0010]在一些实施例中,步骤S1具体包括:S1.1、根据原始数据集构造相应的同义词字典和反义词字典;S1.2、通过查找构造的同义词字典和反义词字典,对原始数据集中的数据进行同义词替换构造对抗数据集,对原始数据集中的数据进行反义词替换构造对比数据集。
[0011]在一些实施例中,步骤S1还包括:S1.3、计算原始数据集中每个原始样本的句子的所有扰动的集合,公式如下:
[0012][0013]式中,x是句子,Perturb是扰动该句子,是扰动x后的句子,即替换了同义词或者反义词后的句子;c
i
是中第i个字或者词,是c
i
的同义词或者反义词,是中第i个字或者词;Perturb(x)原始句子的扰动后的句子是的集合;是指存在一个位置i被扰动就会有一个
[0014]在一些实施例中,步骤S2中,所述得到相应的嵌入表示的具体方法包括:将原始数据集中每个原始样本与其对应的正例构成正例对,将原始数据集中每个原始样本与其对应的负例构成负例对,并通过预训练语言模型得到相应的嵌入表示,计算每个原始样本与对应正例和负例之间的语义相似性。
[0015]进一步地,具有相同语义的两个句子属于正例对,具有不同语义的两个句子属于负例对。
[0016]进一步地,计算两个句子之间的语义相似性的公式如下:
[0017][0018]式中,x
a
和x
b
表示句子a和句子b,exp()是求指数,h为句子x的特征表示,表示句子a经过模型处理后在[CLS]位置的嵌入,表示句子b经过模型处理后在[CLS]位置的嵌入。
[0019]进一步地,步骤S2中,所述使用对比学习损失函数计算三种嵌入表示之间的距离的计算公式如下:
[0020][0021]式中,x是句子,与x的语义相距较远甚至完全相反的句子为x
ant
,与x的语义相距较
近的句子为x
syn
,(x,x
syn
)构成正例对,(x,x
ant
)构成负例对。
[0022]进一步地,步骤S3中,所述总体损失的计算公式如下:
[0023]L=λ1L
ori
+λ2L
cl
[0024]式中,L
ori
表示除L
cl
外的其他模型自带的损失的目标函数,λ1和λ2表示模型学到的权重。
[0025]在一些实施例中,所述方法还包括:S4、通过反向传播机制更新模型参数,迭代至模型收敛,用最优的模型在下游任务的测试集上进行测试,得到预测结果。
[0026]本专利技术还公开了一种计算机可读存储本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于对比学习的小样本学习鲁棒性提升方法,其特征在于,包括以下步骤:S1、对原始数据集进行预处理,构造对抗数据集和对比数据集;S2、将原始数据集、对抗数据集和对比数据集分别输入预训练语言模型中,得到相应的嵌入表示,并使用对比学习损失函数计算三种嵌入表示之间的距离;S3、根据对比学习损失函数与原模型的损失函数计算模型更新的梯度,以总体损失更小为目标来训练模型。2.如权利要求1所述的基于对比学习的小样本学习鲁棒性提升方法,其特征在于,步骤S1具体包括:S1.1、根据原始数据集构造相应的同义词字典和反义词字典;S1.2、通过查找构造的同义词字典和反义词字典,对原始数据集中的数据进行同义词替换构造对抗数据集,对原始数据集中的数据进行反义词替换构造对比数据集。3.如权利要求2所述的基于对比学习的小样本学习鲁棒性提升方法,其特征在于,步骤S1还包括:S1.3、计算原始数据集中每个原始样本的句子的所有扰动的集合,公式如下:式中,x是句子,Perturb是扰动该句子,是扰动x后的句子,即替换了同义词或者反义词后的句子;c
i
是中第i个字或者词,是c
i
的同义词或者反义词,是中第i个字或者词;Perturb(x)原始句子的扰动后的句子是的集合;是指存在一个位置i被扰动就会有一个4.如权利要求1所述的基于对比学习的小样本学习鲁棒性提升方法,其特征在于,步骤S2中,所述得到相应的嵌入表示的具体方法包括:将原始数据集中每个原始样本与其对应的正例构成正例对,将原始数据集中每个原始样本与其对应的负例构成负例对,并通过预训练语言模型得到相应的嵌入表示,计算每个原始样本与对应正例和负例之间的语义相似性。5.如权利要求4所述的基于对比学习的小样本学习鲁棒性提升方法,其特征在于,具有相同语义的两个句子属于正...

【专利技术属性】
技术研发人员:郑海涛阳佳城王晖江勇夏树涛肖喜蒋芳清
申请(专利权)人:鹏城实验室
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1