模型训练方法、图像分类方法、装置、设备及存储介质制造方法及图纸

技术编号:39065537 阅读:15 留言:0更新日期:2023-10-12 19:58
本申请实施例公开了一种模型训练方法、图像分类方法、装置、设备及存储介质,其中,所述模型训练方法包括:获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数至少包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。训练得到的。训练得到的。

【技术实现步骤摘要】
模型训练方法、图像分类方法、装置、设备及存储介质


[0001]本申请涉及但不限于人工智能
,尤其涉及一种模型训练方法、图像分类方法、装置、设备及存储介质。

技术介绍

[0002]深度学习广泛应用于工业视觉,其在复杂场景的表现明显优于传统图像处理算法。但是随着训练数据的不断增加,模型在不同数据集上迁移不可避免地会存在遗忘性问题,即在新数据集上训练深度学习分类模型,训练得到的新深度学习分类模型虽能够精确识别新数据特征,遗忘了在旧数据上学习到的知识的问题。
[0003]目前为解决模型遗忘性问题,主要采用蒸馏方法和模型组合方法。蒸馏往往会牺牲模型的准确率,其最主要的原因是:当旧模型的输出和新模型的输出偏差非常大的时候,通过蒸馏方法强行让他们一致,往往会得到负面的结果。模型组合的方式则会增加推理成本,推理时间变长。

技术实现思路

[0004]有鉴于此,本申请实施例至少提供一种模型训练方法、图像分类方法、装置、设备及存储介质。
[0005]本申请实施例的技术方案是这样实现的:第一方面,本申请实施例提供一种模型训练方法,所述方法包括:获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
[0006]第二方面,本申请实施例提供一种图像分类方法,所述方法包括:获取待分类的图像数据集;通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于上述第一方面所述的模型训练方法进行训练得到的。
[0007]第三方面,本申请实施例提供一种模型训练装置,所述装置包括:样本获取模块,用于获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;模型训练模块,用于基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
[0008]第四方面,本申请实施例提供一种图像分类装置,所述装置包括:数据获取模块,用于获取待分类的图像数据集;图像分类模块,用于通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于上述第一方面所述的模型训练方法进行训练得到的。
[0009]第五方面,本申请实施例提供一种计算机设备,包括存储器和处理器,所述存储器存储有可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述第一方面或第二方面方法中的部分或全部步骤。
[0010]第六方面,本申请实施例提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述第一方面或第二方面方法中的部分或全部步骤。
[0011]本申请实施例中,在利用原始数据集训练得到第一分类模型的基础上,获取包括原始数据集中的至少一个原始样本的样本数据集对第二分类模型进行训练,在训练过程中通过计算第一分类模型和第二分类模型针对同一样本输出的类别得分的差异得到差异抑制损失,通过在损失函数中增加差异抑制损失,惩罚新旧模型对于同一样本输出的类别得分变化,从而使得第二分类模型在精确识别新数据特征的同时,保持在旧数据上的识别精度。
[0012]应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本公开的技术方案。
附图说明
[0013]此处的附图被并入说明书中并构成本说明书的一部分,这些附图示出了符合本申请的实施例,并与说明书一起用于说明本申请的技术方案。
[0014]图1为本申请实施例提供的模型训练方法的一种流程示意图;图2为本申请实施例提供的模型训练方法的另一种流程示意图;图3为本申请实施例提供的模型训练方法的再一种流程示意图;图4为本申请实施例提供的图像分类方法的可选的流程示意图;图5为本申请实施例提供的一种模型训练装置的组成结构示意图;图6为本申请实施例提供的一种图像分类装置的组成结构示意图;图7为本申请实施例提供的一种计算机设备的硬件实体示意图。
具体实施方式
[0015]为了使本申请的目的、技术方案和优点更加清楚,下面结合附图和实施例对本申请的技术方案进一步详细阐述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
[0016]在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解,“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
[0017]所涉及的术语“第一/第二/第三”仅仅是区别类似的对象,不代表针对对象的特定
排序,可以理解地,“第一/第二/第三”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
[0018]除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的
的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请的目的,不是旨在限制本申请。
[0019]在对本申请实施例进行进一步详细说明之前,先对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释。
[0020]模型蒸馏,旨在把一个大模型或者多个模型全体学到的知识迁移到另一个轻量级单模型上,方便部署。即用小模型去学习大模型的预测结果,而不是直接学习训练集中的标签(label)。
[0021]模型组合(bagging),指通过结合几个模型降低泛化误差的技术。主要想法是分别训练几个不同的模型,然后让所有模型表决测试样例的输出。
[0022]模型的遗忘性问题,指的是在新数据集上训练深度学习分类模型,训练得到的新深度学习分类模型虽能够精确识别新数据特征,遗忘了在旧数据上学习到的知识的问题。
[0023]logit(指类别得分)为深度学习中一种表示模型输出的方式,通常是指模型输出的未经过softmax(归一化)函数处理的原始数值,也就是各个类别的置信度得分(score),不一定归一化或具有概率(probability)的意义。
[0024]在训练过程中,模型通常使用logits值作为损失函数的输入,然后通过softmax函数将其转化为概率分布,最终用于计算损失。在测试或推理阶段,通常使用softmax函数将logits转化为概率分布,以便根据得分最高的类别做出预测。
[002本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数至少包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。2.根据权利要求1所述的方法,其特征在于,所述第二分类模型至少包括全连接层,所述基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型,包括:将所述样本数据集中的目标样本输入所述第二分类模型,得到所述全连接层输出的第二类别得分;基于所述第二类别得分,利用所述目标损失函数确定所述第二分类模型的学习损失值;基于所述学习损失值对所述第二分类模型的网络参数进行反向传播更新;响应于满足收敛条件,确定所述第二分类模型为所述图像分类模型。3.根据权利要求2所述的方法,其特征在于,所述基于所述第二类别得分,基于所述目标损失函数确定所述第二分类模型的学习损失值,包括:确定所述第一分类模型针对所述目标样本输出的第一类别得分;基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;基于所述差异抑制损失,确定所述学习损失值。4.根据权利要求2所述的方法,其特征在于,所述目标损失函数还包括拟合损失,所述拟合损失用于表征所述第二分类模型的预测类别与样本标签之间的差异;所述第二分类模型还包括所述全连接层之后的归一化层;所述基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型,还包括:将所述样本数据集中的目标样本输入所述第二分类模型,得到所述归一化层输出的第二预测类别;其中,所述第二预测类别是所述归一化层对所述第二类别得分处理得到的;基于所述第二类别得分和所述第二预测类别,利用所述目标损失函数确定所述第二分类模型的学习损失值。5.根据权利要求4所述的方法,其特征在于,所述基于所述第二类别得分和所述第二预测类别,基于所述目标损失函数确定所述第二分类模型的学习损失值,包括:确定所述第一分类模型针对所述目标样本输出的第一类别得分;基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;基于所述第二预测类别和所述样本数据集的样本标签,确定所述拟合损失;对所述拟合损失和所述差异抑制损失进行加权求和,得到所述学习损失值。6.根据权利要求5所述的方法,其特征在于,所述基于所述目标样本对应的所述第二类
别得分和所述第一类别得分,确定所述差异抑制损失,包括:针对所述目标样本,确定所述第二类别得分和所述第一类别得分之间的变化距离;基于所...

【专利技术属性】
技术研发人员:请求不公布姓名
申请(专利权)人:摩尔线程智能科技北京有限责任公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1