一种新的领域自适应学习方法技术

技术编号:24171307 阅读:72 留言:0更新日期:2020-05-16 03:02
提本发明专利技术属于新一代信息技术领域自适应学习技术领域,提出了一种新的领域自适应学习方法,该方法面向图像分类任务,不需要复杂的对抗学习,而是通过一个目标领域图像旋转预测的辅助分类任务和目标领域无标记样本插值后预测结果一致性约束,构建多任务学习模型,最终学习得到适用于目标领域的特征和适用于目标领域数据分布的分类模型。本发明专利技术不依赖目标领域样本标注的情况下学习得到适用于目标领域的分类器,大大降低了在测试数据分布发生变化时手工标注样本的压力。本发明专利技术结合目标领域样本的插值一致性先验和目标领域无标记样本的旋转角度预测这一辅助任务进行多任务学习,既能学习到适用于目标领域的特征,又能确保分类边界在目标领域数据分布中处在合适位置,能够有效提高目标领域的分类性能。

A new domain adaptive learning method

【技术实现步骤摘要】
一种新的领域自适应学习方法
本专利技术涉及领域自适应深度学习领域,特别是面向图像分类的领域自适应深度学习方法。
技术介绍
目前大多数深度学习方法采用监督学习,通过手工标记大量的样本进行模型训练。但是,手工标注样本十分耗费体力,成本高昂。此外,标记的训练样本和真实的测试样本之间很有可能存在分布不同的问题,这种情况下,训练的模型在测试数据上的性能往往会急剧下降。领域自适应学习就是一种为了解决由于训练数据和测试数据的分布不同导致机器学习性能下降而提出了一种迁移学习方法。领域自适应学习利用源领域的标注数据学习得到目标领域依然适用的模型。根据目标领域数据是否有标注信息,领域自适应学习可以分为有监督领域自适应学习、半监督领域自适应学习和无监督领域自适应学习。无监督领域自适应学习由于完全不依赖目标领域数据标注信息而应用更加广泛。近年来,深度学习快速发展并在计算机视觉领域取得空前成功。最近提出的领域自适应学习方法也大都采用深度神经网络模型,这些深度领域自适应学习方法可以分为两类,一类是基于最小化差异(discrepancy)的方法,这些方法通过最小化源领域和目标领域的特征之间的差异实现领域不变特征学习。另一类方法是基于对抗学习(adversariallearning)的方法,这类方法通过最小最大化博弈,学习一个领域判别器实现对源领域和目标领域的鉴别,同时学习一个特征提取器(生成器)迷惑之前的领域判别器,当最小最大化优化达到均衡时可以实现领域特征的对齐。这两类方法存在的问题是优化目标和训练过程较为复杂。自监督学习是近年来发展迅速的一类机器学习方法,它通过设置不依赖手工标注的辅助任务,学习得到适用于下游任务的特征。文献(RevisitingSelf-SupervisedVisualRepresentationLearning,AlexanderKolesnikov,CVPR2019)证明,自监督学习是一种有效的特征学习方法。基于自监督学习的思想,文献(Self-SupervisedDomainAdaptationforComputerVisionTasks,JiaolongXu,IEEEAccess2019(7):156694-156706)和专利(201910139916.8)提出一种自监督领域自适应学习方法,利用目标领域的图像旋转预测这一辅助任务学习适用于目标领域的特征,能够有效提升模型在目标领域数据上的性能。文献(Self-EnsemblingForVisualDomainAdaptation,GeoffFrench,ICLR18)提出了一种基于自集成的领域自适应方法,这种方法利用训练过程中不同迭代获得的模型的参数均值(自集成)作为教师模型,同时对目标领域的无标记样本进行随机增广,利用增广后样本在教师模型和学生模型上的预测的一致性作为监督信号,以学习适用于目标领域的模型。上述文献记载的方法挖掘了目标领域无标记样本的自监督信息,通过辅助任务构建多任务学习系统,能够学习到适用于目标领域的特征,但是这两种方法没有显式地考虑目标领域的聚类假设,也就是在目标领域的数据分布中,相近的样本很可能具有相同的类别,导致学习得到的分类边界可能存在不合理现象。
技术实现思路
本专利技术的目的是解决自监督领域自适应学习方法中欠缺对目标领域的聚类假设,从而造成学习得到的分类边界不合理的技术问题。为达到上述目的解决上述技术问题,本专利技术提出一种新的领域自适应学习方法,该方法的技术方案包括如下步骤:S1.准备源领域有标记样本集Ds(x,y)和目标领域无标记样本集Dt(x);S2.构建主任务深度卷积神经网络分类模型fθ(x),该模型由特征提取网络和分类网络级联而成,即其中θ={θv,θc}为可训练的参数;S3.构建四类旋转预测辅助分类模型该模型与主任务分类模型共享特征提取网络g及其参数θv,辅助分类网络级联在特征提取网络后端,用于对图像的旋转角度进行预测;S4.构建由源领域有监督分类任务、目标领域无标记样本插值一致性任务和目标领域样本旋转预测任务组成的多任务学习模型并在Ds(x,y)和Dt(x)上进行训练,以获取主任务分类模型的最优参数θ*;S4.1.确定训练的迭代次数T、移动平均系数α、[0,1]之间的随机分布Q;S4.2.初始化网络参数Θ={θv,θc,θa},初始化主任务分类模型θ={θv,θc}的移动均值:θ′∶=θ;S4.3.利用随机梯度下降法进行迭代,更新模型参数;S4.3.1.从源领域样本集Ds(x,y)中采集小批量样本S4.3.2.计算源领域小批量样本的主任务分类损失其中损失函数可采用交叉熵损失进行计算。S4.3.3.从目标领域数据集Dt(x)采样两组无标记小批量样本S4.3.4.利用主分类网络的均值教师模型计算目标领域样本的伪标记S4.3.5.从随机分布Q中采样插值系数λ;S4.3.6.计算样本和预测的插值,样本插值的计算方法为:预测的伪标记插值结果为:S4.3.7.计算插值一致性损失具体可采用均值平方误差进行计算。S4.3.8.将目标领域的样本集中的样本随机旋转0°、90°、180°或270°,构建辅助分类任务样本集表示四种不同的旋转角度;S4.3.9.计算目标领域样本的辅助分类任务损失其中损失函数可采用交叉熵损失进行计算。S4.3.10.根据主任务分类损失、插值一致性损失和辅助任务分类损失计算总损失:其中和为权重参数;进一步的,和设置为常数或函数。S4.3.11计算总损失L相对模型参数Θ的梯度;S4.3.12.更新主任务分类模型参数的移动平均值θ′∶=αθ′+(1-α)θ;(12)S4.3.13.利用随机梯度下降法更新模型参数Θ;S5.利用最优测试模型对目标领域的图像进行分类。与现有技术相比,本专利技术有效收益在于:(1)本专利技术在不依赖目标领域样本标注的情况下学习得到适用于目标领域的分类器,大大降低了在测试数据分布发生变化时手工标注样本的压力。(2)本专利技术结合目标领域样本的插值一致性先验和目标领域无标记样本的旋转角度预测这一辅助任务进行多任务学习,既能学习到适用于目标领域的特征,又能确保分类边界在目标领域数据分布中处在合适位置,能够有效提高目标领域的分类性能。附图说明图1是本专利技术的流程示意图;图2是本专利技术中主任务分类模型和辅助任务分类模型示意图;图3是本专利技术中多任务学习的损失函数示意图。具体实施方式本专利技术提出一种新的领域自适应学习方法,该方法面向图像分类任务,不需要复杂的对抗学习,而是通过一个目标领域图像旋转预测的辅助分类任务和目标领域无标记样本插值后预测结果一致性约束,构建多任务学习模型,最终学习得到适用于目本文档来自技高网
...

【技术保护点】
1.一种新的领域自适应学习方法,其特征在于,包含如下步骤:/nS1.准备源领域有标记样本集D

【技术特征摘要】
1.一种新的领域自适应学习方法,其特征在于,包含如下步骤:
S1.准备源领域有标记样本集Ds(x,y)和目标领域无标记样本集Dt(x);
S2.构建主任务深度卷积神经网络分类模型fθ(x),该模型由特征提取网络和分类网络级联而成,即其中θ={θv,θc}为可训练的参数;
S3.构建四类旋转预测辅助分类模型该模型与主任务分类模型共享特征提取网络g及其参数θv,辅助分类网络级联在特征提取网络后端,用于对图像的旋转角度进行预测;
S4.构建由源领域有监督分类任务、目标领域无标记样本插值一致性任务和目标领域样本旋转预测任务组成的多任务学习模型并在Ds(x,y)和Dt(x)上进行训练,以获取主任务分类模型的最优参数θ*;
S4.1.确定训练的迭代次数T、移动平均系数α、[0,1]之间的随机分布Q;
S4.2.初始化网络参数Θ={θv,θc,θa},初始化主任务分类模型θ={θv,θc}的移动均值:θ′∶=θ;
S4.3.利用随机梯度下降法进行迭代,更新模型参数;
S4.3.1.从源领域样本集Ds(x,y)中采集小批量样本
S4.3.2.计算源领域小批量样本的主任务分类损失



S4.3.3.从目标领域数据集Dt(x)采样两组无标记小批量样本
S4.3.4.利用主分类网络的均值教师模型计算目标领域样本的伪标记



S4.3.5.从随机分布Q采样插值系数λ;
S4.3.6.计算样本和预测的插值,样本插值的计算方法为:



预测的伪标记插值结果为:

<...

【专利技术属性】
技术研发人员:肖良许娇龙商尔科赵大伟朱琪聂一鸣戴斌
申请(专利权)人:中国人民解放军军事科学院国防科技创新研究院
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1