一种多阶段的无监督域适应因果关系识别方法技术

技术编号:32454275 阅读:11 留言:0更新日期:2022-02-26 08:28
本发明专利技术公开了一种多阶段的无监督域适应因果关系识别方法。本发明专利技术步骤包括:(1)数据集划分;(2)利用自适应对比学习进行预训练;(3)结合知识蒸馏进行对抗学习;(4)多级数据过滤得到种子集;(5)单级数据过滤得到伪标签集合;(6)利用k

【技术实现步骤摘要】
一种多阶段的无监督域适应因果关系识别方法


[0001]本专利技术属于自然语言处理领域,涉及自然语言中的自适应对比学习、结合知识蒸馏的对抗迁移学习、基于特征级别数据增强和一致性损失的自训练策略的无监督域适应因果关系识别方法。

技术介绍

[0002]因果关系识别是推理和决策的基础。自然语言中因果关系的特征投影使机器能够更好地理解周围的环境,为逻辑推理、问答系统等下游任务提供关键线索。由于传统因果关系识别方法缺乏背景知识、难以理解上下文以及缺少逻辑推理能力,随着人类社会进入一个信息大爆炸时代,传统方法越发不能满足人们对智能化推理和决策的需要。识别因果关系需要模型具备背景知识和理解上下文的能力,现有的有监督因果关系识别方法都严重依赖于带注释训练数据的数量和质量,然而实际情况是未标注的数据远多于标注数据,特别是一些特殊领域的文本数据,其标注难度大、数据规模小。因此人们开始关注如何在数据有限的情况下提高其识别能力和泛化能力。而无监督域适应因果关系识别任务旨在通过学习有标签的源域数据,将其中的知识迁移至无标签的目标域中来获得在目标域数据上更好的识别效果。目前,关于因果关系的无监督域适应问题的相关研究还很少,并且现有的方法存在以下弊端:(a)通过拉进源域特征和目标域特征距离的一般无监督域适应方法会产生遗忘灾难和模糊分类边界,即在迁移过程中模型会失去对于该任务的分类能力;(b)现有的方法不具备学习目标域知识的能力,极大地限制了模型能力的提升,当源域知识和数据规模不能满足需求时其识别效果将大幅下降;(c)现有的伪标签方法容易引入噪声,从而极大地影响模型的识别能力。
[0003]通过研究发现,因果关系内部存在着丰富的多样性,其具体表现为因果关系在特征空间中可以细分为更多的子类原型,利用这种多样性可以有效地提高模型的识别能力和泛化能力。从因果关系的定义来看,因果关系可以被继续细分,例如时间层面的因果关系、条件因果关系等。因此,这些子类原型是具有实体意义的。我们利用这一事实进行特征级别的数据增强,进而引入一致性和自训练方法,可以让模型主动学习目标域的知识并进一步明确分类边界。

技术实现思路

[0004]根据上述现有方法缺陷描述,本专利技术提出了一种多阶段融合的无监督域适应因果关系识别方法;其中预训练阶段学习源域数据的知识,对抗阶段将源域知识迁移至目标域,一致性调整阶段结合数据过滤、数据增强和一致性等方法主动学习目标域数据中的特有知识。
[0005]为实现上述目的,本专利技术是通过以下技术方案来实现的:
[0006]步骤1、数据集划分:对源域数据进行三次随机划分,再利用得到的三组源域数据集经过预训练阶段和对抗阶段得到三组目标域模型;
[0007]步骤2、自适应对比学习:自适应对比学习的目的是在预训练阶段获取足够的类间距,并保持适当的类内距离以保留其多样性;
[0008]步骤3、结合知识蒸馏的对抗学习:在基于对抗的一般迁移学习基础上增加知识蒸馏,通过保留模型对于源域数据的分类能力,避免出现遗忘灾难;
[0009]步骤4、多级数据过滤:使用以上三步得到的三组目标域模型经过伪标签分配、投票机制、置信度筛选和不确定性筛选,得到一组相对干净的目标域种子集;
[0010]步骤5、单级数据过滤:在一致性调整阶段每轮训练结束后,使用置信度阈值筛选出种子集以外置信度高于置信度阈值的样本,并分配伪标签;
[0011]步骤6、子类原型获取:对于经多级数据过滤和单级数据过滤得到的数据集,均使用k

means聚类方法获取因果关系和非因果关系的子类原型;
[0012]步骤7、特征级别数据增强和一致性损失:利用得到的原型对输入样本的特征进行特征级别的数据增强,新特征向量和原特征向量的因果类别应该是一致的;
[0013]步骤8、自训练:使用上述步骤4和步骤5得到的伪标签以及其置信度经过映射得到的权重对模型进行训练,直到模型收敛。
[0014]进一步的,所述的步骤1中的数据集划分得到的每个源域数据集由60%的训练集、20%的测试集和20%的验证集组成,且三个源域数据集在划分过程中要保证足够的随机性。
[0015]进一步的,所述的步骤2具体实现如下:
[0016]2‑
1.将输入的自然语言文本划分为tokens,并利用bert编码器模型将文本投影为768维的特征;
[0017]2‑
2.存储生成的特征向量,聚类得到因果关系和非因果关系的类中心,计算出所有样本到类中心的平均距离;
[0018]2‑
3.以平均距离将作为对比损失的超参训练模型,直到模型收敛。进一步的,所述的步骤3中的结合知识蒸馏的对抗学习具体步骤是:
[0019]3‑
1.使用2

3训练的模型分别将源域数据和目标域数据编译为特征向量;
[0020]3‑
2.使用鉴别器鉴别3

1得到特征向量是否来源于目标域,从而计算出对抗损失;
[0021]3‑
3.计算目标域的特征向量和源域特征的相似度作为蒸馏损失;
[0022]3‑
3.使用3

2的对抗损失和3

3的蒸馏损失训练模型,直到模型收敛。
[0023]进一步的,所述的步骤4中的多级数据过滤具体步骤是:
[0024]4‑
1.不同的源域数据集训练得到的三组模型对目标域数据集分配伪标签,三个模型分配的伪标签相同时该样本才会通过投票机制的筛选;
[0025]4‑
2.通过Softmax层的输出得到每个样本的置信度,筛选出置信度最高的20%样本进行下一步筛选;
[0026]4‑
3.在启用dropout层的情况下,模型多次计算同一样本是因果关系的概率,存储下所有输出结果并通过均方差计算出其不确定性,筛选出不确定性小于0.01的样本组成种子集。
[0027]进一步的,所述的步骤7中的特征级别数据增强和一致性损失具体实现如下:
[0028]7‑
1.从步骤6获取子类原型筛选出和输入特征欧式距离最近的原型;
[0029]7‑
2.计算该原型和输入特征的相似度;
[0030]7‑
3.以7

2获得的相似度作为权重,将原型叠加到输入特征,最后经过全连接层的映射得到增强后的特征向量;
[0031]7‑
4.以增强的特征向量和输入特征向量类别的一致性计算出对应的一致性损失。
[0032]进一步的,所述的步骤8中的自训练具体步骤实现如下:
[0033]8‑
1.从Softmax层获得当前模型对当前样本的置信度;
[0034]8‑
2.以多级数据过滤和单级数据过滤得到的伪标签为目标,以步骤8

1得到的置信度为权重计算出交叉熵损失;
[0035]8‑
3.使用交叉熵损失和7

4的一致本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种多阶段的无监督域适应因果关系识别方法,其特征在于该方法包括以下步骤:步骤1、数据集划分:对源域数据进行三次随机划分,再利用得到的三组源域数据集经过预训练阶段和对抗阶段得到三组目标域模型;步骤2、自适应对比学习:自适应对比学习的目的是在预训练阶段获取足够的类间距,并保持适当的类内距离以保留其多样性;步骤3、结合知识蒸馏的对抗学习:在基于对抗的一般迁移学习基础上增加知识蒸馏,通过保留模型对于源域数据的分类能力,避免出现遗忘灾难;步骤4、多级数据过滤:使用以上三步得到的三组目标域模型经过伪标签分配、投票机制、置信度筛选和不确定性筛选,得到一组相对干净的目标域种子集;步骤5、单级数据过滤:在一致性调整阶段每轮训练结束后,使用置信度阈值筛选出种子集以外置信度高于置信度阈值的样本,并分配伪标签;步骤6、子类原型获取:对于经多级数据过滤和单级数据过滤得到的数据集,均使用k

means聚类方法获取因果关系和非因果关系的子类原型;步骤7、特征级别数据增强和一致性损失:利用得到的原型对输入样本的特征进行特征级别的数据增强,新特征向量和原特征向量的因果类别应该是一致的;步骤8、自训练:使用上述步骤4和步骤5得到的伪标签以及其置信度经过映射得到的权重对模型进行训练,直到模型收敛。2.根据权利要求1所述的一种多阶段的无监督域适应因果关系识别方法,其特征在于步骤1中的数据集划分得到的每个源域数据集由60%的训练集、20%的测试集和20%的验证集组成,且三个源域数据集在划分过程中要保证足够的随机性。3.根据权利要求2所述的一种多阶段的无监督域适应因果关系识别方法,其特征在于步骤2具体实现如下:2

1.将输入的自然语言文本划分为tokens,并利用bert编码器模型将文本投影为768维的特征;2

2.存储生成的特征向量,聚类得到因果关系和非因果关系的类中心,计算出所有样本到类中心的平均距离;2

3.以平均距离将作为对比损失的超参训练模型,直到模型收敛。4.根据权利要求3所述的多阶段的无监督域适应因果关系识别方法,其特征在于步骤3中的结合知识蒸馏的对抗学习具体步骤是:3

1.使用2

3训练的模型分别将源域数据和目标域数据编译为特征...

【专利技术属性】
技术研发人员:李建军周云帆俞杰陆奇李胜炎李新付田万勇赵露露惠国宝唐政
申请(专利权)人:杭州电子科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1