基于一致性训练的自监督领域自适应深度学习方法技术

技术编号:24123846 阅读:18 留言:0更新日期:2020-05-13 03:56
本发明专利技术公开了一种基于一致性训练的自监督领域自适应深度学习方法。该方法首先构建一个数据增强变换集合,对每一个变换定义一个标签。针对源域样本和其对应的类别标签,构建分类任务;对源域和目标域样本应用所述数据增强变换,通过最小化预测该变换类别的误差,构建自监督学习任务;针对源域和目标域样本,通过最小化变换后的样本和原始样本在分类任务上的输出的KL散度(Kullback‑Leibler Divergence),构建一致性训练任务;构建一个多任务学习网络,将所述的分类、自监督学习和一致性训练任务进行联合训练。该方法无需对目标域样本进行标注,能有效地学习目标域特征表示,提升目标域上样本分类和识别的性能。本申请还公开了一种领域自适应深度学习可读存储介质,同样具有上述有益效果。

Adaptive deep learning method based on consistency training in self supervised domain

【技术实现步骤摘要】
基于一致性训练的自监督领域自适应深度学习方法
本专利技术属于新一代信息
,具体涉及领域自适应深度学习方法及可读存储介质。
技术介绍
机器学习特别是深度学习模型通常需要大量的标注样本来进行监督学习,比如图像、文本等的分类和识别需要收集大量的样本,同时还需要标注每一个样本的对应的类别。当模型在标注数据上训练完成之后,将其应用到测试数据上。当测试数据与训练数据具有相同的分布时,监督学习是一种非常有效的方法。然而实际应用中通常会出现测试数据与训练数据分布不同的情况,从而使得模型在测试数据集上的性能急剧下降。领域自适应(domainadaptation)是解决上述由于训练和测试数据分布差异引起模型性能下降问题的一类技术方法。通常将训练数据集称为源领域,测试数据集称为目标领域。源领域的数据是带有标注信息的,而目标领域的数据通常是没有标注信息的。领域自适应技术旨在将源领域的监督信息迁移到目标领域,提升目标领域上任务的性能。目前基于深度神经网络的领域自适应学习大多数是通过领域对抗训练来学习跨领域不变的特征表示,从而提升目标领域上的任务的性能的。然而领域对抗训练需要优化一对相互对抗的目标函数,训练过程的收敛比较困难,很难得到最优的模型。
技术实现思路
本专利技术要解决的技术问题是领域对抗训练时优化一对相互对抗的目标函数,训练过程的收敛困难,难以获取适合的模型。本专利技术为解决上述技术问题,提供基于一致性训练的自监督领域自适应深度学习方法,该方法提供一种非对抗式的训练方法,以提高目标领域上任务的性能,具体的技术方案如下:S1:构建一个多任务学习深度神经网络,包含一个参数为θe的特征提取网E,参数为θm主分类网M,以及参数为θp的数据增强变换预测网P;S2:将源域样本xs和其类别标签y组成分类任务训练集Ds={(xs,y)|y∈[0,C]},其中C是类别数;S3:构建一组数据增强变换集合G={g(x,r)|r∈[0,R)},每一个数据增强变换g(x,r)对应一个变换类别标签r;S31:对源域样本xs应用数据增强变换得到源域自监督训练集Ds*={(xs*,r)|xs*=g(xs,r)},以及源域一致性训练集Dcs={(xs,xs*)|xs*∈Ds*};S32:对目标域样本xt应用数据增强变换得到目标域自监督训练集Dt*={(xt*,r)|xt*=g(xt,r)},以及目标域一致性训练集Dct={(xt,xt*)|xt*∈Dt*};S4:将步骤S31中的源域自监督训练集Ds*={(xs*,r)|xs*=g(xs,r)}和步骤S32中的目标域自监督训练集Dt*={(xt*,r)|xt*=g(xt,r)}合并得到总的自监督训练集D*=Ds*∪Dt*;S5:将步骤S31中的源域一致性训练集Dcs={(xs,xs*)|xs*∈Ds*}和步骤S32中的目标域一致性训练集Dct={(xt,xt*)|xt*∈Dt*}合并得到总的一致性训练集Dc=Dcs∪Dct;S6:针对步骤S2中分类任务训练集Ds以及步骤S1中的特征提取网E和主分类网M,构建有监督学习任务,其训练损失函数为:S7:针对步骤S4中自监督学习训练集D*以及步骤S1中的特征提取网E和数据据增强变换预测网P,构建自监督学习任务,其训练损失函数为:S8:针对步骤S5中一致性训练集Dc以及步骤S1中的特征提取网E和主分类网M,构建一致性学习任务,通过KL散度(Kullback-LeiblerDivergence)距离构建其训练损失函数:其中DKL为KL散度距离;S9:将步骤S6、S7以及S8中的损失函数加权求和,得到总的训练损失函数:Ltotal=LM+λ1LP+λ2LC(4)其中λ1和λ2为加权系数,可通过交叉验证选取合适的值;S10:通过最小化步骤S9中的损失函数Ltotal,得到训练后优化的参数θe、θp以及θm;S11:对目标域测试样本,使用步骤S10中优化后的参数,通过公式得到其预测的样本类别,实现深度学习模型在目标域上的领域自适应。本专利技术还提供一种可读存储介质,该可读存储介质上存储有程序,当该程序被处理器执行时能够实现步骤S1-S11的基于一致性训练的自监督领域自适应深度学习方法。相对于现有技术,本专利技术的有效收益如下:1、本专利技术提供的领域自适应深度学习方法,通过数据增强来构建一致性训练和自监督训练,通过多任务学习框架联合源领域标注样本的监督学习来学习适应目标领域的特征表示,从而实现领域自适应。2、本专利技术该不依赖人工标注来构建目标领域训练集,通过目标域样本的一致性训练和自监督学习,建立适应目标领域任务的特征表示,从而提高目标领域上任务的性能。3、本专利技术还提供一种领域自适应深度学习可读存储介质,该可读存储介质上存储有程序,当该程序被处理器执行时同样具有上述有益效果。附图说明图1是本专利技术实施例的基于一致性训练的自监督领域自适应深度学习训练过程的流程示意图。具体实施方式以下结合说明书附图和图像分类领域自适应学习实例对本专利技术作进一步的详细描述,但并不因此而限制本专利技术的保护范围。图1给出了本专利技术实施例的基于一致性训练的自监督领域自适应深度学习训练流程示意图。以图像分类领域自适应学习主要包括以下步骤:S1:构建一个多任务学习深度神经网络,包含一个参数为θe的特征提取网E,参数为θm图像分类网M,以及参数为θp的数据据增强变换预测网P;本实施例中S1中的数据增强变换采用图像旋转操作。S2:将源域图像xs和其类别标签y组成分类任务训练集Ds={(xs,y)|y∈[0,C]},其中C是图像类别数目;S3:构建一组基于图像旋转的数据增强变换集合G={g(x,r)|r∈[0,R)},每一个数据增强变换g(x,r)对应一个变换类别标签r,本实例采用三种不同角度旋转(即R=3),分别为90°、180°和270°旋转,对应的变换标签为0,1和2;S31:对源域图像xs应用图像旋转数据增强变换得到源域自监督训练集Ds*={(xs*,r)|xs*=g(xs,r)},以及源域一致性训练集Dcs={(xs,xs*)|xs*∈Ds*};S32:对目标域图像xt应用图像旋转数据增强变换得到目标域自监督训练集Dt*={(xt*,r)|xt*=g(xt,r)},以及目标域一致性训练集Dct={(xt,xt*)|xt*∈Dt*};S4:将步骤S31中的源域自监督训练集Ds*={(xs*,r)|xs*=g(xs,r)}和步骤S32中的目标域自监督训练集Dt*={(xt*,r)|xt*=g(xt,r)}合并得到总的自监督训练集D*=Ds*∪Dt*;S5:将步骤S31中的源域一致性训练集Dcs={(xs,xs*)|xs*∈Ds*}和步骤S32中的目标域一致性训练集Dct={(xt,xt*)|xt*∈Dt*}合并得到总的一本文档来自技高网...

【技术保护点】
1.基于一致性训练的自监督领域自适应深度学习方法,其特征在于,该方法包括:/nS1:构建一个多任务学习深度神经网络,包含一个参数为θ

【技术特征摘要】
1.基于一致性训练的自监督领域自适应深度学习方法,其特征在于,该方法包括:
S1:构建一个多任务学习深度神经网络,包含一个参数为θe的特征提取网E,参数为θm主分类网M,以及参数为θp的数据增强变换预测网P;
S2:将源域样本xs和其类别标签y组成分类任务训练集Ds={(xs,y)|y∈[0,C]},其中C是类别数;
S3:构建一组数据增强变换集合G={g(x,r)|r∈[0,R)},每一个数据增强变换g(x,r)对应一个变换类别标签r;
S31:对源域样本xs应用数据增强变换得到源域自监督训练集以及源域一致性训练集
S32:对目标域样本xt应用数据增强变换得到目标域自监督训练集以及目标域一致性训练集
S4:将步骤S31中的源域自监督训练集和步骤S32中的目标域自监督训练集合并得到总的自监督训练集
S5:将步骤S31中的源域一致性训练集和步骤S32中的目标域一致性训练集合并得到总的一致性训练集Dc=Dcs∪Dct;
S6:针对步骤S2中分类任务训练集Ds以及步骤S1中的特征提取网E和主分类网M,构建有监督学习任务,其训练损失函数为:



S7:针对步骤...

【专利技术属性】
技术研发人员:许娇龙肖良朱琪聂一鸣
申请(专利权)人:中国人民解放军军事科学院国防科技创新研究院
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1