学生模型训练方法、装置、设备及存储介质制造方法及图纸

技术编号:35147100 阅读:16 留言:0更新日期:2022-10-05 10:24
本发明专利技术公开了一种学生模型训练方法、装置、设备及存储介质。该方法包括:在一次迭代周期中,将样本图像分别输入至学生网络和相应的教师网络;其中,学生网络和所述教师网络用于对样本图像进行目标检测;根据样本图像的标签数据和学生网络的预测结果,确定第一损失值,以及根据样本图像的标签数据和教师网络的预测结果,确定第二损失值;根据第一损失值和第二损失值,确定知识蒸馏对学生网络的第一蒸馏作用程度;基于第一蒸馏作用程度,确定学生网络对应的目标学生损失值;根据目标学生损失值调整所述学生网络的网络参数。本发明专利技术实施例提高了训练得到的学生模型的检测准确度。高了训练得到的学生模型的检测准确度。高了训练得到的学生模型的检测准确度。

【技术实现步骤摘要】
学生模型训练方法、装置、设备及存储介质


[0001]本专利技术涉及深度学习
,尤其涉及一种学生模型训练方法、装置、设备及存储介质。

技术介绍

[0002]随着深度学习的不断发展,目标检测技术在近几年飞速发展,目标检测模型不断复杂。目标检测模型在检测性能方面不断提高,但检测过程的耗时不断增加。基于上述问题,提出了基于蒸馏的方法,用学生模型(较为简单的目标检测模型)通过蒸馏方法学习教师模型(复杂的目标检测模型)的网络黑盒知识。
[0003]但现有技术中学生模型通过蒸馏方法学习教师模型的过程中,存在学生模型的检测能力受限于教师模型的检测能力,导致学生模型对教师模型完全依赖,限制了学习模型在训练过程中准确度的提高。

技术实现思路

[0004]本专利技术提供了一种学生模型训练方法、装置、设备及存储介质,以提高训练得到的学生模型的检测准确度。
[0005]根据本专利技术的一方面,提供了一种学生模型训练方法,该方法包括:
[0006]在一次迭代周期中,将样本图像分别输入至学生网络和相应的教师网络;其中,所述学生网络和所述教师网络用于对所述样本图像进行目标检测;
[0007]根据所述样本图像的标签数据和所述学生网络的预测结果,确定第一损失值,以及根据所述样本图像的标签数据和所述教师网络的预测结果,确定第二损失值;
[0008]根据所述第一损失值和所述第二损失值,确定知识蒸馏对所述学生网络的第一蒸馏作用程度;
[0009]基于所述第一蒸馏作用程度,确定所述学生网络对应的目标学生损失值;
[0010]根据所述目标学生损失值调整所述学生网络的网络参数。
[0011]根据本专利技术的另一方面,提供了一种学生模型训练装置,该装置包括:
[0012]样本图像输入模块,用于在一次迭代周期中,将样本图像分别输入至学生网络和相应的教师网络;其中,所述学生网络和所述教师网络用于对所述样本图像进行目标检测;
[0013]损失值确定模块,用于根据所述样本图像的标签数据和所述学生网络的预测结果,确定第一损失值,以及根据所述样本图像的标签数据和所述教师网络的预测结果,确定第二损失值;
[0014]第一作用程度确定模块,用于根据所述第一损失值和所述第二损失值,确定知识蒸馏对所述学生网络的第一蒸馏作用程度;
[0015]目标学生损失值确定模块,用于基于所述第一蒸馏作用程度,确定所述学生网络对应的目标学生损失值;
[0016]网络参数调整模块,用于根据所述目标学生损失值调整所述学生网络的网络参
数。
[0017]根据本专利技术的另一方面,提供了一种电子设备,所述电子设备包括:
[0018]至少一个处理器;以及
[0019]与所述至少一个处理器通信连接的存储器;其中,
[0020]所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行本专利技术任一实施例所述的学生模型训练方法。
[0021]根据本专利技术的另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现本专利技术任一实施例所述的学生模型训练方法。
[0022]本实施例方案根据样本图像的标签数据和学生网络的预测结果,确定第一损失值,以及根据样本图像的标签数据和教师网络的预测结果,确定第二损失值;根据所述第一损失值和第二损失值,确定知识蒸馏对学生网络的第一蒸馏作用程度;基于第一蒸馏作用程度,确定学生网络对应的目标学生损失值;根据目标学生损失值调整学生网络的网络参数。上述方案通过考虑在学生网络和教师网络的训练过程中,教师网络对学生网络知识蒸馏作用程度,实现了在训练过程中动态调整学生网络的目标学生损失值,从而使得学生网络的检测能力不完全依赖于教师网络,提高了对训练得到的学生模型的检测准确度。
[0023]应当理解,本部分所描述的内容并非旨在标识本专利技术的实施例的关键或重要特征,也不用于限制本专利技术的范围。本专利技术的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0024]为了更清楚地说明本专利技术实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本专利技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0025]图1是根据本专利技术实施例一提供的一种学生模型训练方法的流程图;
[0026]图2是根据本专利技术实施例二提供的一种学生模型训练方法的流程图;
[0027]图3是根据本专利技术实施例三提供的一种学生模型训练方法的流程图;
[0028]图4是根据本专利技术实施例四提供的一种学生模型训练装置的结构示意图;
[0029]图5是实现本专利技术实施例的学生模型训练方法的电子设备的结构示意图。
具体实施方式
[0030]为了使本
的人员更好地理解本专利技术方案,下面将结合本专利技术实施例中的附图,对本专利技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本专利技术一部分的实施例,而不是全部的实施例。基于本专利技术中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本专利技术保护的范围。
[0031]需要说明的是,本专利技术的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用
的数据在适当情况下可以互换,以便这里描述的本专利技术的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
[0032]实施例一
[0033]图1为本专利技术实施例一提供了一种学生模型训练方法的流程图,本实施例可适用于基于知识蒸馏学生模型和教师模型共同学习的情况,该方法可以由学生模型训练装置来执行,该学生模型训练装置可以采用硬件和/或软件的形式实现,该学生模型训练装置可配置于电子设备中。如图1所示,该方法包括:
[0034]S110、在一次迭代周期中,将样本图像分别输入至学生网络和相应的教师网络;其中,学生网络和教师网络用于对样本图像进行目标检测。
[0035]其中,学生网络可以是较为简单的目标检测模型,教师网络可以是较为复杂的目标检测模型,具体可以由相关技术人员进行预先设定。例如,学生网络可以是ResNet

34,相应的,教师网络可以是ResNet

50。其中,教师网络可以是采用已训练好的教师模型对教师网络进行权重赋值后的网络模型;学生网络可以是未进行权重赋值的网络模型。在知识蒸馏的过程中,教师网络对本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种学生模型训练方法,其特征在于,包括:在一次迭代周期中,将样本图像分别输入至学生网络和相应的教师网络;其中,所述学生网络和所述教师网络用于对所述样本图像进行目标检测;根据所述样本图像的标签数据和所述学生网络的预测结果,确定第一损失值,以及根据所述样本图像的标签数据和所述教师网络的预测结果,确定第二损失值;根据所述第一损失值和所述第二损失值,确定知识蒸馏对所述学生网络的第一蒸馏作用程度;基于所述第一蒸馏作用程度,确定所述学生网络对应的目标学生损失值;根据所述目标学生损失值调整所述学生网络的网络参数。2.根据权利要求1所述的方法,其特征在于,所述基于所述第一蒸馏作用程度,确定所述学生网络对应的目标学生损失值,包括:根据所述第一蒸馏作用程度,确定所述学生网络对应的第一蒸馏权重值;确定所述学生网络训练产生的第一蒸馏损失值;根据所述第一损失值、所述第一蒸馏损失值和所述第一蒸馏权重值,确定所述学生网络对应的目标学生损失值。3.根据权利要求2所述的方法,其特征在于,所述根据所述第一损失值、所述第一蒸馏损失值和所述第一蒸馏权重值,确定所述学生网络对应的目标损失值,包括:基于所述第一蒸馏权重值,调整所述第一蒸馏损失值;根据所述第一损失值与调整后的第一蒸馏损失值的加和,得到所述学生网络对应的目标学生损失值。4.根据权利要求1所述的方法,其特征在于,所述学生网络的预测结果包括第一分类预测值和第一回归预测值;相应的,根据所述样本图像的标签数据和所述学生网络的预测结果,确定第一损失值,包括:根据所述第一分类预测值和所述标签数据中的类别标签值,确定所述第一分类损失值;以及,根据所述第一回归预测值和所述标签数据中的位置标签值,确定所述第一回归损失值;根据所述第一分类损失值和所述第一回归损失值,生成所述第一损失值。5.根据权利要求1所述的方法,其特征在于,所述教师网络的预测结果包括第二分类预测值和第二回归预测值;相应的,根据所述样本图像的标签数据和所述教师网络的预测结果,确定第二损失值,包括:根据所述第二分类预测值和所述标签数据中的类别标签值,确定所述第二分类损失值;以及,根据所述第二回归预测值和所述标签数据中的位置标签值,确定所述第二回归损失值;根据所述第二分类损失值和所述第二回归损失值,生成所述第二损失值。6.根据权利要求2所述的方法,其特征在于,所述确定所述学生网络训练产生的第一蒸
馏损失值,包括:确定所述学生网络的网络层进行特征提取后得到的第一预测特征值,以及确定所述教师网络的网络层进行特征...

【专利技术属性】
技术研发人员:李林超王威周凯张腾飞
申请(专利权)人:浙江啄云智能科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1