当前位置: 首页 > 专利查询>浙江大学专利>正文

一种基于Transformer的多源无数据域自适应方法和系统技术方案

技术编号:38373985 阅读:9 留言:0更新日期:2023-08-05 17:36
本发明专利技术公开了一种基于Transformer的多源无数据域自适应方法和系统,属于迁移学习技术领域。每个源域训练一个包括特征提取器和分类器的模型;为每个源域分配一个权重系数,初始化老师模型和学生模型;在目标域数据集上,计算硬伪标签和软类别正交伪标签;联合信息最大化损失、基于硬伪标签和软类别正交伪标签的交叉熵损失,迭代更新权重系数和学生模型,再更新教师模型,直至学生模型和权重系数收敛;最后将收敛后的学生模型的知识蒸馏给目标域模型。本发明专利技术在特征提取器中引入视觉Transformer网络,有效地聚焦于图像中的物体,结合软类别正交伪标签,解决了多源无数据域自适应中源域模型对目标域数据泛化能力不高的问题。问题。问题。

【技术实现步骤摘要】
一种基于Transformer的多源无数据域自适应方法和系统


[0001]本专利技术属于迁移学习
,尤其涉及一种基于Transformer的多源无数据域自适应方法。

技术介绍

[0002]为了解决深度学习模型在应用于不可见的目标域时由于域偏移而导致的性能下降问题,无监督域自适应应运而生,以将知识从完全标记的源域转移到未标记的目标域。无监督域自适应在各种各样的应用中均取得了巨大成功,如图像分类、目标检测、语义分割等。
[0003]域自适应中的一个常见策略是通过匹配分布统计矩来最小化域之间的分布差异,或者采用对抗训练的方式,以使学习到的源域和目标域特征彼此无法区分。然而,现有方法大多假设源数据仅从单个域中提取,但更多实际场景的应用中,标记样本通常从多个域中收集,例如,不同的天气或照明条件、不同的视觉提示、不同的模态等。因此,出现了多源域自适应方法,通过将知识从多个不同的源域转移到一个未标记的目标域来完成域自适应。
[0004]但是,现有的多源域自适应方法需要在严格的条件下才行得通,即当将源域适配到目标域时,源域数据始终可用。然而,由于数据隐私政策和存储限制,这样的条件可能会使它们在现实应用中行不通。事实上,解决多源无数据域自适应问题的一个简单方案是单独适应每个源域模型,并直接对这些源域模型的预测进行平均,但它无法探索每个源域对目标域的贡献。到目前为止,针对多源无数据域自适应的方案很少,并且主要依赖一些基本约束(例如信息最大化、标签一致性)来对齐源域和目标域的分布。但是,当执行目标域自适应时,他们都忽略了源域模型的泛化能力。由于多源无数据域自适应的性能主要取决于源域模型在目标域上的初始精度,如果源域模型不能很好地泛化,生成的伪标签在一开始就不可避免地包含显著的噪声,这将严重降低自适应性能。

技术实现思路

[0005]为了克服现有技术的缺陷,解决多源无数据域自适应中源域模型对目标域数据泛化能力不高的问题,本专利技术提出了一种基于Transformer的多源无数据域自适应方法和系统。
[0006]本专利技术是通过下述技术方案实现的:
[0007]第一方面,本专利技术提供了一种基于Transformer的多源无数据域自适应方法,包括以下步骤:
[0008]步骤1,构建包括特征提取器和分类器的模型,所述的特征提取器由卷积神经网络和Transformer网络构成;
[0009]步骤2,为每个源域随机初始化一个模型,然后在每个源域数据集上,利用标签平滑交叉熵损失基于随机梯度下降训练模型至收敛;
[0010]步骤3,将各个源域对应的训练好的模型传递至目标域;目标域为每个源域分配一
个可学习的权重系数,利用每个源域模型的参数初始化一个老师模型和一个学生模型;
[0011]步骤4,在目标域数据集上,计算目标域样本类别的特征中心;
[0012]步骤5,计算老师模型中的特征提取器输出的特征向量与每个特征中心的距离,将最近的特征中心对应的类别确定为硬伪标签;并基于特征向量跟特征中心间的距离计算软类别正交伪标签;
[0013]步骤6,联合计算信息最大化损失、基于硬伪标签和软类别正交伪标签的交叉熵损失,利用总损失迭代更新每个源域的权重系数和学生模型;再根据更新后的学生模型,更新对应的教师模型;
[0014]步骤7,重复步骤4至步骤6,直至学生模型和权重系数收敛;
[0015]步骤8,初始化一个目标域模型,并将收敛后的学生模型的知识蒸馏给目标域模型,利用蒸馏训练后的目标域模型实现对目标域数据的分类。
[0016]进一步的,所述的标签平滑交叉熵损失的具体形式为:
[0017][0018]其中,L
ce
为标签平滑交叉熵损失,为第j个源域数据集,表示第j个源域数据集中的样本,K为样本的类别数;表示平滑后的第k类样本标签的值,y
k
表示样本标签的one

hot编码中的第k个值,η表示标签平滑的程度;σ(
·
)表示softmax操作,表示第j个源域模型,E(
·
)表示取均值操作。
[0019]进一步的,全部源域的权重系数之和为1。
[0020]进一步的,初始化的老师模型和学生模型中的分类器参数冻结,在后续训练过程中不再更新。
[0021]进一步的,所述的步骤4包括:
[0022](4.1)基于每个老师模型,对目标域数据集的每个类别k产生特征中心:
[0023][0024]其中,D
T
表示目标域数据集,x
T
表示目标域数据集中的样本,σ
k
(
·
)表示softmax操作后的结果中第k个值,表示第j个老师模型,表示第j个老师模型中的分类器,表示第j个老师模型中的特征提取器;表示第j个老师模型生成的类别k的特征中心;
[0025](4.2)对每个老师模型产生的特征中心进行聚合:
[0026][0027]其中,表示类别k的特征中心,α
j
表示第j个源域的权重系数;
[0028](4.3)获取每个目标域样本的初始伪标签:
[0029][0030][0031]其中,表示目标域样本的初始伪标签,表示加权聚合后的老师模型的特征提取器,表示L2范数的平方;
[0032](4.4)利用初始伪标签的one

hot编码重新加权平均老师模型中的特征提取器输出的特征向量,更新特征中心:
[0033][0034][0035][0036]其中,1(
·
)表示当括号中的等式成立为1,否则为0,表示目标域样本的更新后的伪标签,表示更新后的类别k的特征中心,表示更新后的由第j个老师模型生成的类别k的特征中心。
[0037]进一步的,软类别正交伪标签的计算公式为:
[0038][0039]其中,ξ(
·
,
·
)表示用于计算向量间距离的函数,τ表示用于控制软类别正交伪标签的平滑程度的温度参数,γ表示用于控制硬伪标签类别的占比系数,表示软类别正交伪标签,表示更新后的类别k的特征中心,K表示样本类别数,表示加权聚合后的老师模型的特征提取器,表示目标域数据集中样本x
T
的硬伪标签。
[0040]进一步的,所述的基于硬伪标签和软类别正交伪标签的交叉熵损失的计算公式为:
[0041][0042][0043]其中,L
pl
表示基于硬伪标签的交叉熵损失,L
ssco
是基于软类别正交伪标签的交叉熵损失,表示目标域数据集D
T
中样本x
T
的硬伪标签,1(
·
)表示当括号中的等式成立是为1,否则为0;K表示样本类别数,σ
k
(
·
)表示softmax操作后的结果中第k个值,θ
Stu
(
·...

【技术保护点】

【技术特征摘要】
1.一种基于Transformer的多源无数据域自适应方法,其特征在于,包括以下步骤:步骤1,构建包括特征提取器和分类器的模型,所述的特征提取器由卷积神经网络和Transformer网络构成;步骤2,为每个源域随机初始化一个模型,然后在每个源域数据集上,利用标签平滑交叉熵损失基于随机梯度下降训练模型至收敛;步骤3,将各个源域对应的训练好的模型传递至目标域;目标域为每个源域分配一个可学习的权重系数,利用每个源域模型的参数初始化一个老师模型和一个学生模型;步骤4,在目标域数据集上,计算目标域样本类别的特征中心;步骤5,计算老师模型中的特征提取器输出的特征向量与每个特征中心的距离,将最近的特征中心对应的类别确定为硬伪标签;并基于特征向量跟特征中心间的距离计算软类别正交伪标签;步骤6,联合计算信息最大化损失、基于硬伪标签和软类别正交伪标签的交叉熵损失,利用总损失迭代更新每个源域的权重系数和学生模型;再根据更新后的学生模型,更新对应的教师模型;步骤7,重复步骤4至步骤6,直至学生模型和权重系数收敛;步骤8,初始化一个目标域模型,并将收敛后的学生模型的知识蒸馏给目标域模型,利用蒸馏训练后的目标域模型实现对目标域数据的分类。2.根据权利要求1所述的基于Transformer的多源无数据域自适应方法,其特征在于,所述的标签平滑交叉熵损失的具体形式为:其中,L
ce
为标签平滑交叉熵损失,为第j个源域数据集,表示第j个源域数据集中的样本,K为样本的类别数;表示平滑后的第k类样本标签的值,y
k
表示样本标签的one

hot编码中的第k个值,η表示标签平滑的程度;σ(
·
)表示softmax操作,表示第j个源域模型,E(
·
)表示取均值操作。3.根据权利要求1所述的基于Transformer的多源无数据域自适应方法,其特征在于,全部源域的权重系数之和为1。4.根据权利要求1所述的基于Transformer的多源无数据域自适应方法,其特征在于,初始化的老师模型和学生模型中的分类器参数冻结,在后续训练过程中不再更新。5.根据权利要求1所述的基于Transformer的多源无数据域自适应方法,其特征在于,所述的步骤4包括:(4.1)基于每个老师模型,对目标域数据集的每个类别k产生特征中心:其中,D
T
表示目标域数据集,x
T
表示目标域数据集中的样本,σ
k
(
·
)表示softmax操作后的结果中第k个值,表示第j个老师模型,表示第j个老师模型中
的分类器,表示第j个老师模型中的特征提取器;表示第j个老师模型生成的类别k的特征中心;(4.2)对每个老师模型产生的特征中心进行聚合:其中,表示类别k的特征中心,α
j
表示第j个源域的权重系数;(4.3)获取每个目标域样本的初始伪标签:(4.3)获取每个目标域样本的初始伪标签:其中,表示目标域样本的初始伪标签,表示加权聚合后的老师模型的特征提取器,表示L2范数的平方;(4.4)...

【专利技术属性】
技术研发人员:吴超李港李皓
申请(专利权)人:浙江大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1