基于数据增强的分类模型泛化性的提升方法及相关设备技术

技术编号:39006703 阅读:14 留言:0更新日期:2023-10-07 10:38
本申请涉及人工智能及数字医疗领域,提出一种基于数据增强的分类模型泛化性的提升方法、装置、电子设备及存储介质,所述方法包括:搭建包括原始分支和带有数据增强模块的数据增强分支的初始分类模型;采集带有类别标签的待分类数据;将待分类数据输入原始分支和数据增强分支以获取原始下采样图、原始分类结果、增强下采样图、增强特征图和增强分类结果;基于增强特征图、原始下采样图和增强下采样图构建数据增强损失,基于原始分类结果、增强分类结果和类别标签构建分类损失,将数据增强损失和分类损失之和作为目标损失以训练初始分类模型,得到目标分类模型。本申请能通过数据增强模块增加随机扰动从而提升分类模型的泛化性。性。性。

【技术实现步骤摘要】
基于数据增强的分类模型泛化性的提升方法及相关设备


[0001]本申请涉及人工智能及数字医疗
,尤其涉及一种基于数据增强的分类模型泛化性的提升方法及相关设备。

技术介绍

[0002]分类模型已经广泛应用于金融、数字医疗等多个行业和领域,但在使用分类模型解决实际问题时,最大的问题在于分类模型的泛化性较差,即在训练数据集上训练完成的分类模型在除训练数据集之外的其他数据集上难以获得准确的分类结果,例如,利用第一医学图像数据集训练的医学图像分类模型,在应用于第二医学图像数据集的分类任务时,通常难以获得较准确的分类结果,故需要提升分类模型泛化性,以使得分类模型在解决实际问题可以保持较高的分类精度。
[0003]目前,经常是先对已有的原始数据(例如上述第一医学图像数据集)进行数据增强,并利用增强后的数据和原始数据对分类模型进行训练,以提升分类模型的泛化性,其中,数据增强的常用方法包括:方法一:对数据进行无方向或者随机的数据增强,方法二:基于生成对抗模型直接生成增强后的数据,然而,方法一不能寻找到对分类最有益的增强数据,且方法二需要消耗大量训练资源,上述两种方法均不能快速有效的提升分类模型泛化性。

技术实现思路

[0004]鉴于以上内容,有必要提出一种基于数据增强的分类模型泛化性的提升方法及相关设备,以解决如何快速有效的提升分类模型泛化性这一技术问题,其中,相关设备包括基于数据增强的分类模型泛化性的提升装置、电子设备及存储介质。
[0005]本申请提供基于数据增强的分类模型泛化性的提升方法,所述方法包括:
[0006]S10,搭建初始分类模型,所述初始分类模型包括原始分支和数据增强分支,所述原始分支包括原始编码器和原始分类器,所述数据增强分支包括增强编码器和增强分类器,所述增强编码器为包括至少一个数据增强模块的原始编码器;
[0007]S11,采集多个带有类别标签的待分类数据以作为训练数据集;
[0008]S12,从所述训练数据集中选取任意待分类数据输入所述原始分支,所述原始编码器对所述待分类数据进行多次下采样以获取至少一个原始下采样图,并将最后一个原始下采样图输入所述原始分类器,得到原始分类结果;
[0009]S13,将所述待分类数据输入所述数据增强分支,所述增强编码器对所述待分类数据进行多次下采样以获取至少一个增强下采样图,并基于所述数据增强模块对预先选定的增强下采样图进行随机扰动以获取至少一个增强特征图,将最后一个增强下采样图输入所述增强分类器,得到增强分类结果;
[0010]S14,基于所述增强特征图、所述原始下采样图和所述增强下采样图构建数据增强损失,基于所述原始分类结果、所述增强分类结果和所述待分类数据的类别标签构建分类
损失,将所述数据增强损失和所述分类损失之和作为目标损失;
[0011]S15,依据梯度下降法更新所述初始分类模型以完成一次迭代训练,返回步骤S12,直到所述目标损失的取值小于预设取值时,得到目标分类模型。
[0012]在一些实施例中,所述原始编码器包括多个卷积层,所述原始分类器包括多个全连接层;
[0013]在所述原始编码器的所述多个卷积层中至少一个预设位置处插入数据增强模块以获取所述增强编码器,所述增强分类器包括多个全连接层,所述数据增强模块用于对卷积层输出的下采样图添加随机扰动以实现数据增强;
[0014]所述原始分类器与所述增强分类器的结构相同或不同,且当所述原始分类器与所述增强分类器的结构相同时,所述原始分类器与所述增强分类器的结构网络参数共享或不共享。
[0015]在一些实施例中,所述增强编码器对所述待分类数据进行多次下采样以获取至少一个增强下采样图,并基于所述数据增强模块对预先选定的增强下采样图进行随机扰动以获取至少一个增强特征图,包括:
[0016]对于所述增强编码器中的任一卷积层,基于所述卷积层对输入数据执行下采样处理得到增强下采样图,所述增强下采样图的尺寸小于或等于所述输入数据的尺寸;
[0017]当该卷积层的末端与数据增强模块直接相连时,表示所述增强下采样图为预先选定的增强下采样图,则基于所述数据增强模块对该增强下采样图进行随机扰动以获取增强特征图;
[0018]判断该卷积层是否为所述增强编码器中最后一个卷积层;
[0019]若该卷积层不为所述增强编码器中最后一个卷积层,则将所述增强下采样图或所述增强特征图作为该卷积层的下一个卷积层的输入数据;
[0020]若该卷积层为所述增强编码器中最后一个卷积层,则将所述增强下采样图作为最后一个增强下采样图。
[0021]在一些实施例中,所述基于所述数据增强模块对该增强下采样图进行随机扰动以获取增强特征图包括:
[0022]获取所述增强下采样图中所有特征值;
[0023]依据随机扰动公式对所述增强下采样图中每个特征值进行随机扰动,得到各特征值对应的扰动值,所述随机扰动公式满足关系式:
[0024][0025]其中,x为所述增强下采样图中任一特征值,μ和σ分别所有特征值的均值和方差,γ和β为所述数据增强模块的可学习参数,LDP(x)为特征值x对应的扰动值;
[0026]将所述增强下采样图中所有特征值替换为对应的扰动值,得到该增强下采样图对应的增强特征图。
[0027]在一些实施例中,所述增强下采样图和所述原始下采样图一一对应,相对应的增强下采样图和原始下采样图尺寸相同,所述基于所述增强特征图、所述原始下采样图和所述增强下采样图构建数据增强损失包括:
[0028]对于任意增强特征图,获取所述增强特征图对应的增强下采样图,并将所述增强
下采样图在所述原始分支中对应的原始下采样图作为目标下采样图;
[0029]分别计算所述目标下采样图和所述增强特征图的格拉姆矩阵,并依据格拉姆矩阵计算所述增强特征图的风格偏差,所述风格偏差满足关系式:
[0030][0031]其中,为第i个增强特征图,f
i
为第i个增强特征图对应的目标下采样图,G(f
i
)和分别表示f
i
和的格拉姆矩阵,表示计算表示计算的F范数,为第i个增强特征图的风格偏差;
[0032]基于最后一个原始下采样图和最后一个增强下采样图计算语义偏差,所述语义偏差满足关系式:
[0033][0034]其中,f
*
和分别为最后一个原始下采样图和最后一个增强下采样图,表示计算的2范数,L
sem
为所述语义偏差;
[0035]基于所述语义偏差和每个增强特征图的所述风格偏差构建数据增强损失,所述数据增强损失满足关系式:
[0036][0037]其中,L
sem
为所述语义偏差,K为所有增强特征图的数量,为第i个增强特征图的风格偏差,λ
sem
和λ
spe
为大于0的预设系数,L
zq
为所述数据增强损失。
[003本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于数据增强的分类模型泛化性的提升方法,其特征在于,所述方法包括:S10,搭建初始分类模型,所述初始分类模型包括原始分支和数据增强分支,所述原始分支包括原始编码器和原始分类器,所述数据增强分支包括增强编码器和增强分类器,所述增强编码器为包括至少一个数据增强模块的原始编码器;S11,采集多个带有类别标签的待分类数据以作为训练数据集;S12,从所述训练数据集中选取任意待分类数据输入所述原始分支,所述原始编码器对所述待分类数据进行多次下采样以获取至少一个原始下采样图,并将最后一个原始下采样图输入所述原始分类器,得到原始分类结果;S13,将所述待分类数据输入所述数据增强分支,所述增强编码器对所述待分类数据进行多次下采样以获取至少一个增强下采样图,并基于所述数据增强模块对预先选定的增强下采样图进行随机扰动以获取至少一个增强特征图,将最后一个增强下采样图输入所述增强分类器,得到增强分类结果;S14,基于所述增强特征图、所述原始下采样图和所述增强下采样图构建数据增强损失,基于所述原始分类结果、所述增强分类结果和所述待分类数据的类别标签构建分类损失,将所述数据增强损失和所述分类损失之和作为目标损失;S15,依据梯度下降法更新所述初始分类模型以完成一次迭代训练,返回步骤S12,直到所述目标损失的取值小于预设取值时,得到目标分类模型。2.如权利要求1所述的基于数据增强的分类模型泛化性的提升方法,其特征在于,所述原始编码器包括多个卷积层,所述原始分类器包括多个全连接层;在所述原始编码器的所述多个卷积层中至少一个预设位置处插入数据增强模块以获取所述增强编码器,所述增强分类器包括多个全连接层,所述数据增强模块用于对卷积层输出的下采样图添加随机扰动以实现数据增强;所述原始分类器与所述增强分类器的结构相同或不同,且当所述原始分类器与所述增强分类器的结构相同时,所述原始分类器与所述增强分类器的结构网络参数共享或不共享。3.如权利要求2所述的基于数据增强的分类模型泛化性的提升方法,其特征在于,所述增强编码器对所述待分类数据进行多次下采样以获取至少一个增强下采样图,并基于所述数据增强模块对预先选定的增强下采样图进行随机扰动以获取至少一个增强特征图,包括:对于所述增强编码器中的任一卷积层,基于所述卷积层对输入数据执行下采样处理得到增强下采样图,所述增强下采样图的尺寸小于或等于所述输入数据的尺寸;当该卷积层的末端与数据增强模块直接相连时,表示所述增强下采样图为预先选定的增强下采样图,则基于所述数据增强模块对该增强下采样图进行随机扰动以获取增强特征图;判断该卷积层是否为所述增强编码器中最后一个卷积层;若该卷积层不为所述增强编码器中最后一个卷积层,则将所述增强下采样图或所述增强特征图作为该卷积层的下一个卷积层的输入数据;若该卷积层为所述增强编码器中最后一个卷积层,则将所述增强下采样图作为最后一个增强下采样图。
4.如权利要求3所述的基于数据增强的分类模型泛化性的提升方法,其特征在于,所述基于所述数据增强模块对该增强下采样图进行随机扰动以获取增强特征图包括:获取所述增强下采样图中所有特征值;依据随机扰动公式对所述增强下采样图中每个特征值进行随机扰动,得到各特征值对应的扰动值,所述随机扰动公式满足关系式:其中,x为所述增强下采样图中任一特征值,μ和σ分别为所有特征值的均值和方差,γ和β为所述数据增强模块的可学习参数,LDP(x)为特征值x对应的扰动值;将所述增强下采样图中所有特征值替换为对应的扰动值,得到该增强下采样图对应的增强特征图。5.如权利要求1所述的基于数据增强的分类模型泛化性的提升方法,其特征在于,所述增强下采样图和所述原始下采样图一一对应,相对应的增强下采样图和原始下采样图尺寸相同,所述基于所述增强特征图、所述原始下采样图和所述增强下采样图构建数据增强损失包括:对于任意增强特征图,获取所述增强特征图对应的增强下采样图,并将所...

【专利技术属性】
技术研发人员:郑智琳张道安高良心黄凌云
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1