一种基于混合增强对比的黑盒源域无监督领域自适应方法技术

技术编号:39310741 阅读:7 留言:0更新日期:2023-11-12 15:56
本发明专利技术属于机器学习下的迁移学习技术领域,公开了一种基于混合增强对比的黑盒源域无监督领域自适应方法,包括知识蒸馏初始化目标模型、混合增强特征对比学习以及随机混合增强矫正,基于黑盒源域模型利用知识蒸馏、互信息熵最大化和早期学习正则化方法对目标模型进行初始化,获取类别原型和类别学习难度阈值,并以此选取目标域样本,最小化混合特征对比损失,随机混合目标域样本及其伪标签作为增强样本,计算交叉熵矫正优化模型。该方法在不使用源域数据与模型参数的情况下进行领域间知识迁移,对无标签的目标域数据进行分类,从更细致的角度优化目标域类间结构以精确识别数据特征,在保证良好泛化性能的同时具备较强的安全性及隐私保护能力。全性及隐私保护能力。全性及隐私保护能力。

【技术实现步骤摘要】
一种基于混合增强对比的黑盒源域无监督领域自适应方法


[0001]本专利技术属于机器学习下的迁移学习
,涉及一种领域自适应模型方法,具体的说是涉及一种基于混合增强对比的黑盒无监督领域自适应方法。

技术介绍

[0002]随着大数据时代的到来,数据产生速度不断加快,数据规模呈现爆发式增长,这使得有能力处理庞大数据集的机器学习技术备受关注。大量数据为机器学习和深度学习提供了更多训练和优化的机会,从而提升了模型的性能和适用性。尽管机器学习在许多领域取得了令人瞩目的成功,但在现实场景中仍然存在着一些限制。传统的机器学习方法通常依赖于大量标记数据来构建模型,以实现较高的分类性能。然而,获取大规模标记数据并不总是容易或实际可行的。这就带来了一个新的挑战,即如何在有限的标记数据条件下训练出具有良好泛化能力的模型,并能够准确地预测未标记数据。
[0003]迁移学习旨在利用已经训练好的模型中的知识和特征,将其应用于新任务或领域中以提升性能。通过迁移学习,我们可以将一个领域中的知识和经验转移到另一个相关或类似的领域,从而节省大量时间和资源。领域自适应是迁移学习的一个分支,它关注的是不同领域之间的知识迁移。在现实场景中,不同领域的数据可能存在领域间差异,例如图像的拍摄环境、文本的语言风格等,这些领域差异会影响模型在目标领域上的性能。领域自适应旨在通过减小不同领域间差异,使模型能够在目标领域上具有较好的泛化能力,其中一种常见的领域自适应方法是无监督领域自适应,它利用目标领域中未标记的数据进行训练。无监督领域自适应通常通过学习领域间的共享特征或对抗性学习来实现,而无需目标领域的标记数据。
[0004]尽管无监督领域自适应取得了显著的成功,但人们对数据隐私的日益关注给这项任务带来了新的挑战。源域和目标域的数据通常储存在不同的设备上并包含私人信息,因此将源域数据暴露给目标域存在一定的风险,换言之,已经标记的源域数据可能无法为目标模型所用,这就使得一些现有的无监督领域自适应方法不再适用,因此便有了无源领域自适应方法,以促进模型迁移并保护源数据的隐私安全。无源领域自适应向未标记的目标域提供训练有素的源模型而非已经标记的源域数据,因此无源领域自适应也称为白盒领域自适应。
[0005]然而在实际应用中,白盒源域模型并不总是能获得的。常见的云服务模型如谷歌云,腾讯云,被封装为应用程序编程接口的形式提供给用户,其中只有模型的输入输出接口可用,模型本身被保存为黑盒接口,这使大量无源领域自适应方法在实践中变得不可用,为此黑盒领域自适应诞生。黑盒领域自适应方法只能使用源域模型的接口访问,在安全性提高的同时也给领域自适应任务带来了不小的挑战,无法获得源模型输出的样本特征使解决域偏移问题变得困难,源模型接口信噪比的不确定也使伪标签变得不可靠。

技术实现思路

[0006]为了解决上述技术问题,本专利技术提供了一种基于混合增强对比的黑盒源域无监督领域自适应方法,该方法在基于知识蒸馏模型的基础上,增加了改进的混合特征对比模块、早期学习正则化模块和随机混合增强模块,帮助学习源域和目标域间共享类的知识和目标域私有类的知识,有效地提高了目标模型的预测准确率。
[0007]为了达到上述目的,本专利技术是通过以下技术方案实现的:
[0008]本专利技术是一种基于混合增强对比的黑盒无监督领域自适应方法,包括如下步骤:
[0009]步骤1、将每个目标域样本输入黑盒源域,获得源域预测,代表样本属于源域中每个类的概率。根据源域预测计算每个类别的原型样本和学习难度阈值;
[0010]步骤2、将每个目标域样本输入目标模型,计算目标模型输出的互信息熵和与源域预测的相对熵作为蒸馏损失;
[0011]步骤3、计算并存储每个样本与类原型样本特征之间的距离作为非线性预测,增加早期学习正则化项,配合蒸馏损失初始化目标模型,迭代更新样本特征以保留模型训练早期的易学习特征;
[0012]步骤4、根据步骤1得到的源域预测以及步骤3得到的非线性预测计算伪标签,根据步骤1得到的学习难度阈值为目标样本筛选置信的非同类样本,将两者按相等比例混合增强后重新获得特征充当混合负样本;
[0013]步骤5、根据类原型特征及步骤4得到的混合负样本特征计算混合增强对比损失,目的是使得每个目标样本与类原型近,与其他类原型和混合负样本远;
[0014]步骤6、随机选择目标样本对按0.25和0.75的比例进行混合数据增强,根据其在目标模型的输出和其混合后的伪标签计算交叉熵;
[0015]步骤7、整体损失计算梯度,反向传播,迭代更新网络参数、类原型特征、学习难度阈值直至损失收敛,对目标域数据样本进行预测得到预测标签,与目标域数据样本的真实标签比较,对于每一类计算出该类的平均分类准确率作为度量结果。
[0016]进一步的,在步骤1中通过黑盒源域模型的输出计算每个类的原型样本和学习难度阈值,如下所示,
[0017][0018][0019]其中表示目标域样本;f
sk
表示源域模型预测第k类的概率;为超参数。
[0020]进一步的,在步骤2中构造了蒸馏损失,通过最小化蒸馏损失来更新目标模型,蒸馏损失由相对熵和互信息熵组成,定义如下:
[0021][0022][0023]L
warm
=L
kd

L
im
[0024]其中D
kl
表示相对熵,f
t
表示目标模型,h(p)=


i
p
i
logp
i
表示自信息熵。
[0025]进一步的,在步骤3中通过早期学习正则化项来正则化模型训练过程,保留模型早
期记忆的具有正确标签的干净样本,防止噪声数据影响。储存器用于记录每个样本的非线性预测,并通过动量策略基于新的预测进行更新,非线性预测、动量策略和早期学习正则化项定义如下:
[0026][0027][0028][0029]其中l2()为L2范式,σ为softmax函数,表示类原型样本,o
i
表示样本在当前模型的中的非线性预测,β为超参数。
[0030]进一步的,目标模型的线性和非线性预测均有其局限性,在步骤4中综合考量两者获得伪标签,通过伪标签为目标样本筛选置信的非同类样本进行混合增强,定义如下:
[0031][0032][0033]其中X
i
表示与第i个样本拥有相同伪标签且置信度大于学习难度阈值的目标样本集合,将第i个样本与集合中的每个样本进行混合增强,获得增强后样本特征作为混合增强对比负样本,混合增强定义如下:
[0034][0035]Mix
λ
(a,b)=λa+(1

λ)b
[0036]进一步的,在步骤5中,最小化目标域样本的infoNCE损失函数,其中样本特征作为锚点,类原型特征作为正样本,其他类原型特征及混合增强样本特征作为负样本,同时为了减本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于混合增强对比的黑盒源域无监督领域自适应方法,其特征在于:具体步骤如下:步骤1、将每个目标域样本输入黑盒源域,获得源域预测,代表样本属于源域中每个类的概率,根据源域预测计算每个类别的原型样本和学习难度阈值;步骤2、将每个目标域样本输入目标模型,计算目标模型输出的互信息熵和与源域预测的相对熵作为蒸馏损失;步骤3、计算并存储每个目标域样本与类原型样本特征之间的距离作为非线性预测,增加早期学习正则化项,配合蒸馏损失初始化目标模型,迭代更新所有目标域样本特征以保留目标模型训练早期的易学习特征;步骤4、根据步骤1得到的源域预测以及步骤3得到的非线性预测计算伪标签,根据步骤1得到的学习难度阈值为目标样本筛选置信的非同类样本,将两者按相等比例混合增强后重新获得特征充当混合负样本特征;步骤5、根据类原型样本特征及步骤4得到的混合负样本特征计算混合增强对比损失,使得每个目标样本与类原型近,与其他类原型和混合负样本远;步骤6、随机选择目标样本对按0.25和0.75的比例进行混合数据增强,根据其在目标模型的输出和其混合后的伪标签计算交叉熵;步骤7、整体损失计算梯度,反向传播,迭代更新网络参数、类原型特征、学习难度阈值直至损失收敛,对目标域样本进行预测得到预测标签,与目标域样本的真实标签比较,对于每一类计算出该类的平均分类准确率作为度量结果。2.根据权利要求1所述的一种基于混合增强对比的黑盒源域无监督领域自适应方法,其特征在于:在步骤1中根据源域预测计算每个类别的原型样本和学习难度阈值,具体表示为:为:其中表示目标域样本,f
sk
表示源域模型预测第k类的概率,为超参数,C
k
表示原型样本,Φ
k
表示学习难度阈值。3.根据权利要求1所述的一种基于混合增强对比的黑盒源域无监督领域自适应方法,其特征在于:步骤2中构造了蒸馏损失,通过最小化蒸馏损失来更新目标模型,蒸馏损失由相对熵和互信息熵组成,定义如下:相对熵和互信息熵组成,定义如下:L
warm
=L
kd

L
im
其中D
kl
表示相对熵,f
t
表示目标模型,h(p)=


i
p
i
logp
i
表示自信息熵,L
kd
表示知识蒸馏损失,E表示经验风险,X
t
表示所有的目标域样本,x
t
表示当前目标域样本,L
im
表示互信息最大化损失,L
warm
表示总的蒸馏损失。4.根据权利要求1所述的一种基于混合增强对比的黑盒源域无监督领域自适应方法,
其特征在于:在步骤3中通过早期学习正则化项来正则化模型训练过程,保留模型早期记忆的具有正确标签的干净样本,防止噪声数据影响,储存器用于记录每个样本的非线性预测,并通过动量策略基于新的预测进行更新,非线性预测、动量策略和早期学习正则化项定义如下:期学习正则化项定义如下:期学习正则化项定义如下:...

【专利技术属性】
技术研发人员:汪云云华子毅
申请(专利权)人:南京邮电大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1