一种提升数字图像分类模型泛化能力的方法及系统技术方案

技术编号:33993401 阅读:74 留言:0更新日期:2022-07-02 10:17
本发明专利技术公开了一种提升数字图像分类模型泛化能力的方法及系统,属于计算机视觉与迁移学习技术领域中的适用于数字图像的分类模型及泛化方法,其目的在于解决现有技术中没有在领域泛化中对数字图像使用基于混合样本的数据增强的问题,其通过不同分类器的梯度信息获取数据样本图像的类别相关信息和域相关信息,将数据样本图像的类别信息和其他数据样本图像的域信息相结合,生成新的数据增强样本加入模型训练。该方法将基于混合样本的数据增强运用到了领域泛化任务中,通过混合不同样本的类别信息与域信息生成数据增强样本,生成的数据样本直观并且可解释性高,从而有益于提高模型的鲁棒性和泛化能力。的鲁棒性和泛化能力。的鲁棒性和泛化能力。

【技术实现步骤摘要】
一种提升数字图像分类模型泛化能力的方法及系统


[0001]本专利技术属于计算机视觉与迁移学习
,涉及一种适用于数字图像的分类模型的泛化方法及系统,更具体的是涉及一种基于数据增强的可适用于手写数字图像的泛化。

技术介绍

[0002]由于深度学习对带标签数据的巨大需求以及人工标注在某些领域的低效性,迁移学习、无监督学习是当下研究的热点之一。迁移学习专注于将已经训练好的优秀模型应用到与源领域有一定相似性的目标领域中,从而减少了对新领域带标签数据的巨大需求。
[0003]领域泛化是迁移学习的一个子领域,其目的为训练一个具有较好鲁棒性的模型,通常是分类器,使得这个模型在任意一个训练时不可见的域都能有比较好的表现。在领域泛化任务中,虽然有源领域的数据和标签,但它并没有明确的目标域,更没有任何目标域的数据参与训练。所以领域泛化关注的不是模型在特定目标域的表现,而是在任意目标域的表现。这也使得模型的鲁棒性和泛化性更高,可以在不同的目标场景中重复使用,而不用针对当前目标域重新训练。
[0004]国内外的领域泛化研究中,常使用的数据集包括Digits

DG,PACS, Office

Home等,常用方法主要基于深度神经网络相关技术,具体又可以细分为基于领域对齐的方法、基于元学习的方法、基于数据增强的方法和基于自监督的方法等。其中,基于数据增强的方法又具体可以分为四类,分别为人工数据增强、基于梯度的数据增强、基于模型的数据增强和基于特征的数据增强。人工数据增强主要包括对原图进行对比度变化、亮度变化、旋转图片等等;基于梯度的数据增强借鉴对抗攻击的思路,通过梯度在原图上增加扰动,生成让标签分类器或领域分类器难以分辨类别标签或者域标签的新数据样本图像;基于模型的增强方法,主要包括用神经网络,cycleGAN等来对图片进行不同域之间的风格迁移;基于特征的增强方法则是在特征层面进行变换、融合,来生成新的样本。
[0005]尽管研究者们提出的诸多数据增强方法已经在大量公开数据集中取得了较高的正确率,但现有的数据增强方法中,还存在一些不足。首先,混合样本的数据增强在迁移学习的任务中并不常见,除了Minghao Xu等人将Mixup和领域自适应任务相结合,目前没有了解到其他工作在领域泛化中使用基于混合样本的数据增强。其次,目前大部分用于领域泛化的数据增强方法都较复杂,比如目前在领域泛化数据增强方面效果最优的工作DDAIG使用了专门设计的神经网络生成扰动,再添加到原图片上,从而生成新的数据样本图像,这样增加新网络的设计加大了整体的计算量,并且难以直接和其他分类网络相结合,而本专利的方法更加简洁并且普适。
[0006]综上,通过本专利提出的混合样本的数据增强方法,能够为领域泛化任务提供简单且可解释性高的数据泛化途径,本方法直接使用原本分类网络的梯度信息进行数据增强,新增计算量小并且可以迁移到任何分类网络进行使用。

技术实现思路

[0007]本专利技术的目的在于:为了解决现有技术中没有在领域泛化中对数字图像使用基于混合样本的数据增强的问题,本专利技术提供一种可适用于数字图像(尤其是手写数字)的数字图像分类模型、提升数字图像分类模型的泛化能力的方法及系统,通过不同分类器的梯度信息获取数据样本图像的类别相关信息和域相关信息,将数据样本图像的类别信息和其他数据样本图像的域信息相结合,生成新的数据增强样本加入模型训练,以实现模型更好的鲁棒性与泛化能力。
[0008]本专利技术为了实现上述目的具体采用以下技术方案:一种提升数字图像分类模型泛化能力的方法,具体步骤为:步骤S1,样本获取获取手写的数据样本图像;步骤S2,数据预处理对步骤S1获取的数据样本图像进行预处理,预处理包括图像放缩、图像翻转以及图像裁剪;步骤S3,模型搭建搭建数字图像分类模型,数字图像分类模型包括特征提取模块、类别分类器模块、域分类器模块以及数据增强模块;特征提取模块,采用卷积神经网络CNN或残差神经网络ResNet

101对数据样本图像进行特征提取,其中卷积神经网络CNN使用六层卷积层和四层最大池化,最终得到的特征维度为256维,残差神经网络ResNet

101去掉最后一层全连接层后,最终得到的特征维度为2048维;类别分类器模块,通过第一全连接层将特征提取模块提取出的特征分类到对应的类别;域分类器模块,通过第二全连接层将特征提取模块提取出的特征分类到对应的域;数据增强模块,通过第一全连接层输出的得分计算出样本图像中每个像素点对于类别信息的贡献;通过第二全连接层输出的得分计算出样本图像中每个像素点对于域信息的贡献;通过将数据样本图像的类别信息的贡献大的像素点,与其他数据样本图像的域信息的贡献大的数据点相结合,生成新的增强的数据样本图像并输入至类别分类器模块、域分类器模块;类别信息的贡献大于设置的最小贡献值即为类别信息的贡献大,域信息的贡献大于设置的最小贡献值即为域信息的贡献大;步骤S4,模型预训练使用步骤S2预处理后的数据样本图像对数字图像分类模型中的特征提取模块、类别分类器模块、域分类器模块进行预训练,特征提取模块提取数据样本图像中的特征,并将提取到的特征输入类别分类器模块、域分类器模块,类别分类器模块将特征提取模块提取出的特征分类到对应的类别并定位出数据样本图像中与类别信息有关的像素点,域分类器模块将特征提取模块提取出的特征分类到对应的域并定位出数据样本图像中与域信息有关的像素点,类别分类器模块、域分类器模块定位到的像素点再输入至数据增强模块,生成新的增强的数据样本图像;
步骤S5,模型训练每个类别随机选取一定量的数据样本图像,并加入增强的数据样本图像,组成新的训练集,继续对数字图像分类模型进行训练;数据样本图像、增强的数据样本图像作为输入,通过第一全连接层输出的得分,数据增强模块根据得分计算出样本图像中每个像素点对于类别信息的贡献;通过第二全连接层输出的得分,数据增强模块根据得分计算出样本图像中每个像素点对于域信息的贡献;通过将训练集中数据样本图像的类别信息的贡献大的像素点,与训练集中其他数据样本图像的域信息的贡献大的数据点相结合,生成新的增强的数据样本图像;加入增强的数据样本图像后,数字图像分类模型的整体损失函数为:其中,、分别是类别分类器模块的第一损失函数、域分类器模块的第二损失函数,、分别是、的权重参数,按损失函数对整个数字图像分类模型的特征提取模块、类别分类器模块和域分类器模块的参数进行更新。
[0009]优选地,数据样本图像包括Digit

DG、Office

Home和PACS三个公开数据集;Digit

DG收集了四个不同种类的手写数字识别数据集的子集,分别为MNIST

M、MNIST、SVHN和Synthetic Digits,每个子集均含有25000张训练图片和9000张测试图片;Office

Home包含15588张图片,65个类别,展现在四个域中:Art、Clipart本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种提升数字图像分类模型泛化能力的方法,其特征在于:具体步骤为:步骤S1,样本获取获取手写的数据样本图像;步骤S2,数据预处理对步骤S1获取的数据样本图像进行预处理,预处理包括图像放缩、图像翻转以及图像裁剪;步骤S3,模型搭建搭建数字图像分类模型,数字图像分类模型包括特征提取模块、类别分类器模块、域分类器模块以及数据增强模块;特征提取模块,采用卷积神经网络CNN或残差神经网络ResNet

101对数据样本图像进行特征提取,其中卷积神经网络CNN使用六层卷积层和四层最大池化,最终得到的特征维度为256维,残差神经网络ResNet

101去掉最后一层全连接层后,最终得到的特征维度为2048维;类别分类器模块,通过第一全连接层将特征提取模块提取出的特征分类到对应的类别;域分类器模块,通过第二全连接层将特征提取模块提取出的特征分类到对应的域;数据增强模块,通过第一全连接层输出的得分计算出样本图像中每个像素点对于类别信息的贡献;通过第二全连接层输出的得分计算出样本图像中每个像素点对于域信息的贡献;通过将数据样本图像的类别信息的贡献大的像素点,与其他数据样本图像的域信息的贡献大的数据点相结合,生成新的增强的数据样本图像并输入至类别分类器模块、域分类器模块;类别信息的贡献大于设置的最小贡献值即为类别信息的贡献大,域信息的贡献大于设置的最小贡献值即为域信息的贡献大;步骤S4,模型预训练使用步骤S2预处理后的数据样本图像对数字图像分类模型中的特征提取模块、类别分类器模块、域分类器模块进行预训练,特征提取模块提取数据样本图像中的特征,并将提取到的特征输入类别分类器模块、域分类器模块,类别分类器模块将特征提取模块提取出的特征分类到对应的类别并定位出数据样本图像中与类别信息有关的像素点,域分类器模块将特征提取模块提取出的特征分类到对应的域并定位出数据样本图像中与域信息有关的像素点,类别分类器模块、域分类器模块定位到的像素点再输入至数据增强模块,生成新的增强的数据样本图像;步骤S5,模型训练每个类别随机选取一定量的数据样本图像,并加入增强的数据样本图像,组成新的训练集,继续对数字图像分类模型进行训练;数据样本图像、增强的数据样本图像作为输入,通过第一全连接层输出的得分,数据增强模块根据得分计算出样本图像中每个像素点对于类别信息的贡献;通过第二全连接层输出的得分,数据增强模块根据得分计算出样本图像中每个像素点对于域信息的贡献;通过将训练集中数据样本图像的类别信息的贡献大的像素点,与训练集中其他数据样本图像的域信息的贡献大的数据点相结合,生成新的增强的数据样本图像;
加入增强的数据样本图像后,数字图像分类模型的整体损失函数为:其中,、分别是类别分类器模块的第一损失函数、域分类器模块的第二损失函数,、分别是、的权重参数,按损失函数对整个数字图像分类模型的特征提取模块、类别分类器模块和域分类器模块的参数进行更新。2.如权利要求1所述的一种提升数字图像分类模型泛化能力的方法,其特征在于:数据样本图像包括Digit

DG、Office

Home和PACS三个公开数据集;Digit

DG收集了四个不同种类的手写数字识别数据集的子集,分别为MNIST

M、MNIST、SVHN和Synthetic Digits,每个子集均含有25000张训练图片和9000张测试图片;Office

Home包含15588张图片,65个类别,展现在四个域中:Art、Clipart、Product和Real;PACS总共包含9991张图片,有四个域的数据:Art painting、Cartoon、Photo和Sketch,其中每个域都包含有7个类别的图像。3.如权利要求1所述的一种提升数字图像分类模型泛化能力的方法,其特征在于:特征提取模块提取数据样本图像中的特征时,对于Digit

DG数据集,图像大小缩放为,特征提取模块选用六层卷积层和四层最大池化的卷积神经网络,卷积核大小均为3,最终拉伸为一维向量得到的特征维度为256维;对于Office

Home数据集,图像大小为,特征提取模块选用去掉最后一层全连接层的残差神经网络ResNet

101,输出为2048维的特征向量;对于PACS数据集,图像大小为,特征提取模块同样选用去掉最后一层全连接层的残差神经网络ResNet

101,输出为2048维的特征向量;在训练中,Digit

DG数据集的一批训练数据量为128,Office

Home和PACS数据集的一批训练数据量为32,因此训练时Digit

DG的特征图维度为,Office

Home和PACS数据集的特征图维度为;训练时,Digit

DG数据集训练50个epoch,Office

Home和PACS数据集由于数据量相对较大,只训练30个epoch。4.如权利要求1所述的一种提升数字图像分类模型泛化能力的方法,其特征在于:类别分类器模块采用第一损失函数进行训练,第一损失函数为:其中N为类别数,为符号函数,如果数据样本图像i的真实类别为c,则取1,否则取0;,是数据样本图像i属于类别c的预测概率,其中为数据样本图像i经过特征提取模块与第一全连接层后,得到的增强的数据样本图像i关于类别c的得分;为数据样本图像i经过特征提取模块与第一全连接层后,得到的数据样本图像i关
于类别j的得分;为符号函数,如果增强的数据样本图像i的真实类别为c,则取1,否则取0;,是增强...

【专利技术属性】
技术研发人员:徐行唐嘉翊沈复民申恒涛
申请(专利权)人:成都考拉悠然科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1