一种无源域数据的无监督领域适应方法技术

技术编号:37796721 阅读:8 留言:0更新日期:2023-06-09 09:26
本发明专利技术涉及一种无源域数据的无监督领域适应方法,以有标签的源域样本训练模型,得到预训练好的源域模型;以源域模型初始化目标域模型;以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,最小化分布对齐损失,尽可能拉近源域和目标域特征分布空间;基于源域模型的分类器的预测对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失,对目标域样本计算信息最大化损失;以所有损失函数共同训练目标域模型,实现无源域数据的无监督领域适应,纠正部分最初分类器分错的目标域样本,提高分类准确度。提高分类准确度。提高分类准确度。

【技术实现步骤摘要】
一种无源域数据的无监督领域适应方法


[0001]本专利技术涉及计算;推算或计数的
,特别涉及一种机器学习领域的、基于BN层信息和软聚类的无源域数据的无监督领域适应方法。

技术介绍

[0002]近年来,深度神经网络在视觉分类领域取得了非常不错的应用效果,被广泛地运用在各个行业。神经网络表现出卓越性能的一个前提是测试数据与训练数据服从独立同分布,然而,在现实世界中这个条件难以满足,理想情况下是希望模型能在标签丰富的数据集上获得的知识可以转移或者应用到其他未标记的数据上,但即使数据集之间的差异很小,深度网络也难以应用到未知的数据域中,在训练中,影响模型泛化能力的重要因素是来自不同领域数据之间的分布偏移。因此,领域适应就是针对这类问题进行的研究。
[0003]近年来,在该技术问题上取得了巨大的进展,尤其是无监督的领域适应。当我们可以直接访问源域数据集时,可以直接对齐源域和目标域的分布偏移,现有的许多领域适应方法即使对无标签的目标域数据都非常有效。然而,传统的领域适应都是基于源域数据及其标签可用的前提,在一些实际情况下,包括但不限于数据集过大存储困难、共享数据的挑战、数据隐私和其他数据集处理问题,使得源数据不容易获取,只能获取预训练好的模型,这让传统的无监督领域适应模型有了局限性,因此提出了无源领域适应。
[0004]无源领域适应与无监督领域适应的不同在于,无源即不能获取有标签的源域数据,只能用源域数据训练好的模型和无标签的目标域数据进行训练。目前无源领域适应常用的方法有两类:一类从预训练好的模型里挖掘包含源域特征的信息与目标域样本进行训练,微调预训练模型;另一类使用生成模型,利用目标域数据、预训练模型生成含有源域信息的生成样本,用生成样本和目标域样本进行领域适应;但因为没有源域数据,这些方法都没有对源域和目标域显式对齐,只是利用无监督方法进行微调或者生成的样本是类似目标域样本的伪源域样本。

技术实现思路

[0005]本专利技术解决了现有技术中存在的问题,提供了一种无源域数据的无监督领域适应方法。
[0006]本专利技术所采用的技术方案是,一种无源域数据的无监督领域适应方法,所述方法包括以下步骤:
[0007]步骤1:以有标签的源域样本训练模型,得到预训练好的源域模型;
[0008]步骤2:以源域模型初始化目标域模型,包括特征提取器和分类器;
[0009]步骤3:以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,计算分布对齐损失L
BN

[0010]步骤4:基于目标域模型的分类器的预测,对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间
的交叉熵损失L
clu

[0011]步骤5:对目标域样本计算信息最大化损失L
IM
,信息最大化损失包括最小化熵损失和最大化平均熵损失,使样本预测置信度更高,并且避免模型坍塌;
[0012]步骤6:以所述分布对齐损失L
BN
、交叉熵损失L
clu
和信息最大化损失L
IM
共同训练目标域模型,实现无源域数据的无监督领域适应,提高对目标域样本的识别准确率。
[0013]优选地,所述步骤1中,为防止预训练模型在源域数据上过拟合,通过标签平滑后再计算交叉熵损失,以提升模型到目标域的泛化性能,目标函数为,
[0014][0015]其中,f
s
表示预训练好的源域模型,包括特征提取器g
s
和分类器h
s
,满足给定输入x,f
s
(x)=h
s
(g
s
(x));K表示类别数目,k对应任一类别,X
s
为源域样本集;给定q
k
为源域样本x
s
的标签,则是对q
k
平滑后的标签,满足α是平滑系数,0<α<1,一般取0.05≤α≤0.15;
[0016]σ(
·
)表示对某一给定向量的softmax归一化操作,假设给定向量a和温度参数T,用σ
k
表示对某个向量σ(
·
)操作后得到的第k维的值,
[0017][0018]a
k
表示向量a第k维的值,j指向量a第j维,式(1)中T为1。
[0019]优选地,所述目标域模型的分类器固定不变。所述步骤2中,目标域模型f
t
,包括特征提取器g
t
和分类器h
t
,分别初始化为源域模型中的特征提取器和分类器,给定任意输入x,满足f
t
(x)=h
t
(g
t
(x));通过损失函数优化目标域模型的特征提取器,分类器初始化以后冻结不更新。
[0020]优选地,所述步骤3中,BN层的统计信息包括该层每个通道的均值和方差,这些统计信息可以用来近似训练样本的全局特征分布,具体地说,每个BN层的每个通道的数据分布可以用一个高斯分布N(μ,σ2)表示,其中μ、σ2表示高斯分布的均值和方差;以源域模型的每个BN层中每个通道的均值、方差表示的高斯分布与目标域样本对应BN层的当前batch样本的每个通道的均值、方差表示的高斯分布,计算它们之间KL散度的平均值,作为衡量源域和目标域样本特征分布的距离。
[0021]优选地,所述分布对齐损失L
BN
为,
[0022][0023]其中,M表示模型中BN层的总数,C
m
表示第m个BN层的通道总数,和表示源域模型中第m个BN层第cm个通道存储的均值和方差,和表示当前batch经过目标域模型第m个BN层的第c
m
个通道的均值和方差;D
KL
为KL散度;
[0024]最小化损失函数L
BN
,通过最小化该损失函数,用BN层中的均值方差表示的高斯分
布来近似因没有源域样本而无法获取的源域特征分布,实现与目标域特征的分布对齐。
[0025]优选地,所述步骤4包括以下步骤:
[0026]步骤4.1:因为目标域和源域的数据差异,固定的源域分类器对目标域样本的预测存在噪声,使得分类器对难以分辨的目标域样本没有纠正作用,为缓解这一问题,引入基于聚类的软标签对齐损失,软标签生成过程类似模糊聚类。先以目标域模型分类器输出的概率为权重,对提取的特征进行加权平均,初始化簇中心:
[0027][0028]式(3)中,δ
k
表示第k个类的簇中心,f
t
表示目标模型,包括特征提取器g
t
和分类器h
t
,满足给定输入x,f
t
(x)=h
t
(g
t
(x));x
t
...

【技术保护点】

【技术特征摘要】
1.一种无源域数据的无监督领域适应方法,其特征在于:所述方法包括以下步骤:步骤1:以有标签的源域样本训练模型,得到预训练好的源域模型;步骤2:以源域模型初始化目标域模型,包括特征提取器和分类器;步骤3:以源域模型的BN层存储的统计信息近似源域的特征分布,与目标域样本的特征分布显式对齐,计算分布对齐损失L
BN
;步骤4:基于目标域模型的分类器的预测,对目标域样本的特征进行模糊聚类,以聚类隶属度作为目标域样本的软标签,计算软标签与模型分类器对目标域样本的预测之间的交叉熵损失L
clu
;步骤5:对目标域样本计算信息最大化损失L
IM
,信息最大化损失包括最小化熵损失和最大化平均熵损失;步骤6:以所述分布对齐损失L
BN
、交叉熵损失L
clu
和信息最大化损失L
IM
共同训练目标域模型,实现无源域数据的无监督领域适应。2.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤1中,通过标签平滑后再计算交叉熵损失,目标函数为,其中,f
s
表示预训练好的源域模型,包括特征提取器g
s
和分类器h
s
,满足给定输入x,f
s
(x)=h
s
(g
s
(x));K表示类别数目,k对应任一类别,X
s
为源域样本集;给定q
k
为源域样本x
s
的标签,则是对q
k
平滑后的标签,满足α是平滑系数,0<α<1;σ(
·
)表示对某一给定向量的softmax归一化操作,假设给定向量a和温度参数T,用σ
k
表示对某个向量σ(
·
)操作后得到的第k维的值,a
k
表示向量a第k维的值,j指向量a第j维,式(1)中T为1。3.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述目标域模型的分类器固定不变。4.根据权利要求1所述的一种无源域数据的无监督领域适应方法,其特征在于:所述步骤3中,BN层的统计信息包括均值和方差;以源域模型的每个BN层中每个通道的均值、方差表示的高斯分布与目标域样本对应BN层的当前batch样本的每个通道的均值、方差表示的高斯分布,计算KL散度的平均值,作为衡量源域和目标域样本特征分布的距离。5.根据权利要求4所述的一种无源域数据的无监督领域适应方法,其特征在于:所述分布对齐损失L
...

【专利技术属性】
技术研发人员:梅建萍翁烨涛
申请(专利权)人:浙江工业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1