一种基于半监督学习的领域泛化图像分类方法技术

技术编号:34398519 阅读:20 留言:0更新日期:2022-08-03 21:35
本发明专利技术公开了一种基于半监督学习的领域泛化图像分类方法,获取图像数据集,对图像进行预处理;构建包括双分支领域对抗网络和集成网络的网络框架;双分支领域对抗网络包含两个结构相同的领域对抗网络,每个领域对抗网络由一个特征提取器、一个分类器和一个领域判别器组成;集成网络由一个和领域对抗网络相同的特征提取器和一个分类器组成;使用经过预处理的图像数据集对步骤两个领域对抗网络进行训练,更新和保存无标签数据的伪标签;使用训练好的模型对目标域图像进行测试,得到分类结果。本发明专利技术构建的双分支领域对抗网络的结构框架同时实现了半监督学习和领域泛化的目标,提升了模型在目标域上的测试精度,具有较高的使用价值。值。值。

【技术实现步骤摘要】
一种基于半监督学习的领域泛化图像分类方法


[0001]本专利技术属于机器学习和人工智能领域,具体涉及一种基于半监督学习的领域泛化图像分类方法。

技术介绍

[0002]机器学习通过将经验数据提供给机器学习模型进行参数的学习,而训练好的模型可被应用于对新的测试数据提供预测结果。因此在机器学习中,假如测试集合和训练集合存在分布不一致的问题(domain shift),会导致模型在测试阶段精度下降。为了缓解这一现象,领域泛化(domain generalization)技术通过在训练阶段利用多个源领域联合训练,提升模型的通用性和泛化性,使得模型可以泛化到未知的目标域上。领域泛化需要多个有标签的源域,但是在实际的应用中,获取标签往往是最困难的。因此在实际的领域泛化场景中,如果源域由一个或多个有标签的领域和多个完全无标签的领域组成,则需要专利技术基于半监督学习的领域泛化算法。
[0003]领域泛化近年来受到了广泛关注,现有的领域泛化技术主要从以下几个角度解决领域偏移问题:(1)基于数据扩充或图像生成技术增强训练数据的多样性;(2)基于领域对抗训练等方式提取与领域无关的特征表示;(3)基于集成学习、元学习等学习策略提升模型泛化性。
[0004]在传统半监督学习中,只有少量的数据是有标签的,其主要目的是利用大量的无标签数据提升模型的性能。伪标签、熵最小化和一致性正则化是三类有效的方法。但是在领域泛化的问题中,如何在分布不一致的各源域之间进行半监督学习是一个难题。另外,因为目标领域也和源域之间存在着分布差异,如何在半监督学习的基础上提升模型的泛化性是半监督领域泛化的总目标。
[0005]本专利技术针对传统半监督方法在解决领域泛化问题中的不足,提出了领域差异下的半监督学习方法和一个基于半监督学习的领域泛化图像分类框架,使领域泛化技术适用于更多的标签不足的实际场景中。

技术实现思路

[0006]专利技术目的:本专利技术为了解决源域不完全有标签的领域泛化图像分类问题,提出一种基于半监督学习的领域泛化图像分类方法,有效利用了无标签领域样本,同时提升了模型的泛化性,提高了模型在目标域上的测试精度。
[0007]技术方案:本专利技术所述的一种基于半监督学习的领域泛化图像分类方法,具体包括以下步骤:
[0008](1)获取图像数据集,对图像进行预处理;
[0009](2)构建包括双分支领域对抗网络和集成网络的网络框架;所述双分支领域对抗网络包含两个结构相同的领域对抗网络,每个领域对抗网络由一个特征提取器、一个分类器和一个领域判别器组成;所述集成网络由一个和领域对抗网络相同的特征提取器和一个
分类器组成;
[0010](3)使用步骤(1)中经过预处理的图像数据集对步骤(2)中的两个领域对抗网络进行训练,更新和保存无标签数据的伪标签;
[0011](4)使用训练好的模型对目标域图像进行测试,得到分类结果。
[0012]进一步地,所述步骤(1)实现过程如下:
[0013]获取含有多个源域的图像数据集;半监督领域泛化任务中,给定有标签源域的集合S
l
={D1,...,D
m
}和无标签源域的集合S
u
={D1,...,D
n
},S
l
中的第j个领域可以表示为每个样本由输入x、类别标签y和领域标签z组成,其中n
j
是领域中样本个数;S
u
中的第j个领域可以表示为样本不包含类别标签;目标域其中x
i
是第i个输入图像,n
t
是目标域的样本个数;
[0014]将预先获取的数据集图像裁剪为统一大小的图像,并将图像的R、G和B三个通道分别做标准化,其中均值和方差使用ImageNet的均值和方差。
[0015]进一步地,所述步骤(2)实现过程如下:
[0016]将预处理的图像输入到特征提取器,特征提取网络输出高维的特征向量,特征向量分别输入到分类器和领域判别器,最终输出对样本类别的预测和对领域的预测;两个领域对抗网络分别表示为P
θ1
(y,z|x)和P
θ2
(y,z|x),其中θ1和θ2为网络参数,y为类别值,z为领域值;另外,用于预测伪标签的网络由P
θ1
(y,z|x网络集成所得,即网络不参与训练,通过P
θ1
(y,z|x)网络的参数更新自身参数,集成的网络表示为P
θ1

(y,z|x),其中θ1

为网络参数。
[0017]进一步地,所述步骤(3)包括以下步骤:
[0018](31)获取到含有多个源域的图像数据集;
[0019](32)对有标签源域和无标签源域进行差值,扩充领域之间的混合样本,如下:
[0020][0021][0022][0023]其中,为无标签样本的伪标签,λ服从Beta分布;领域间插值后得到新的训练样本
[0024](33)利用混合样本训练P
θ2
(y,z|x),利用原始的有标签样本和带有伪标签的无标签样本训练P
θ1
(y,z|x),假设领域对抗网络中的特征提取器,分类器和领域判别器分别表示为F
g
、F
c
和F
d
,P
θ2
(y,z|x)和P
θ1
(y,z|x)的分类损失分别为:
[0025][0026][0027]其中,N
p
是带有伪标签的无标签样本的数量,N
l
是有标签样本的数量,表示交叉熵损失函数;领域对抗的损失分别为:
[0028][0029][0030]其中,N代表样本总数;对无标签样本采用熵损失来训练P
θ1
(y,z|x),损失函数如下:
[0031][0032]其中,H表示熵函数,N
u
是无标签样本中还没有标记伪标签样本的数量;
[0033]另外,对混合样本约束特征一致性来训练P
θ2
(y,z|x),损失函数如下:
[0034][0035]其中,表示均方误差函数,N
l
是有标签样本的数量,N
p
是带有伪标签的无标签样本的数量,λ表示样本混合插值操作中超参数的值;
[0036](34)每次对全部样本做一次前向传播和后向传播后,进行一次伪标签的更新和预测,预测伪标签的网络使用由P
θ1
(y,z|x)通过指数移动平均集成得到的网络,指数移动平均过程的算法公式如下:
[0037]θ

t+
1=βθ1

t
+(1

β)θ1
t+1
[0038]其中,β是一个超参数,θ1

是P
θ1

(y,z|x)的网络本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于半监督学习的领域泛化图像分类方法,其特征在于,包括以下步骤:(1)获取图像数据集,对图像进行预处理;(2)构建包括双分支领域对抗网络和集成网络的网络框架;所述双分支领域对抗网络包含两个结构相同的领域对抗网络,每个领域对抗网络由一个特征提取器、一个分类器和一个领域判别器组成;所述集成网络由一个和领域对抗网络相同的特征提取器和一个分类器组成;(3)使用步骤(1)中经过预处理的图像数据集对步骤(2)中的两个领域对抗网络进行训练,更新和保存无标签数据的伪标签;(4)使用训练好的模型对目标域图像进行测试,得到分类结果。2.根据权利要求1所述的基于半监督学习的领域泛化图像分类方法,其特征在于,所述步骤(1)实现过程如下:获取含有多个源域的图像数据集;半监督领域泛化任务中,给定有标签源域的集合和无标签源域的集合和无标签源域的集合中的第j个领域可以表示为每个样本由输入x、类别标签y和领域标签z组成,其中n
j
是领域中样本个数;中的第j个领域可以表示为样本不包含类别标签;目标域其中x
i
是第i个输入图像,n
t
是目标域的样本个数;将预先获取的数据集图像裁剪为统一大小的图像,并将图像的R、G和B三个通道分别做标准化,其中均值和方差使用ImageNet的均值和方差。3.根据权利要求1所述的基于半监督学习的领域泛化图像分类方法,其特征在于,所述步骤(2)实现过程如下:将预处理的图像输入到特征提取器,特征提取网络输出高维的特征向量,特征向量分别输入到分类器和领域判别器,最终输出对样本类别的预测和对领域的预测;两个领域对抗网络分别表示为P
θ1
(y,z|x)和P
θ2
(y,z|x),其中θ1和θ2为网络参数,y为类别值,z为领域值;另外,用于预测伪标签的网络由P
θ1
(y,z|x)网络集成所得,即网络不参与训练,通过P
θ1
(y,z|x)网络的参数更新自身参数,集成的网络表示为P
θ1

(y,z|x),其中θ1

为网络参数。4.根据权利要求1所述的基于半监督学习的领域泛化图像分类方法,其特征在于,所述步骤(3)包括以下步骤:(31)获取到含有多个源域的图像数据集;(32)对有标签源域和无标签源域进行差值,扩充领域之间的混合样本,如下:(32)对有标签源域和无标签源域进行差值,扩充领域之间的混合样本,如下:(32)对有标签源域和无标签源域进行差值,扩充领域之间的混合样本,如下:其中,为无标签样本的伪标签,λ服从Beta分布;领域间插值后得到新的训练样本(33)利用混合样本训练P
θ2
(y...

【专利技术属性】
技术研发人员:史颖欢汪睿琪凌彤高阳冯华
申请(专利权)人:中和智同北京科技有限公司
类型:发明
国别省市:

相关技术
    暂无相关专利
网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1