模型训练方法、装置、电子设备和计算机可读存储介质制造方法及图纸

技术编号:38315428 阅读:10 留言:0更新日期:2023-07-29 08:57
本申请的实施例提供了一种模型训练方法、装置、电子设备和计算机可读存储介质,该方法包括:获取图像样本;利用教师模型得到图像样本的第一图像特征、第一定位特征以及第一分类特征;利用学生模型得到图像样本的第二图像特征、第二定位特征以及第二分类特征;基于第一图像特征、第一定位特征、第一分类特征、第二图像特征、第二定位特征以及第二分类特征构建损失函数;根据损失函数,采用边训练边蒸馏的方式训练学生模型。本申请中学生模型可以同时学习教师模型logits信息和feature信息,丰富了学生模型的信息量,并且logits信息和feature信息同时学习,能够相互补充、相互促进,从而能够优化学生模型的训练,提升学生模型的准确性。性。性。

【技术实现步骤摘要】
模型训练方法、装置、电子设备和计算机可读存储介质


[0001]本申请的实施例涉及机器学习
,尤其涉及一种模型训练方法、装置、电子设备和计算机可读存储介质。

技术介绍

[0002]知识蒸馏是将一个网络的知识转移到另一个网络,两个网络可以是同构或者异构。其做法是先训练一个teacher模型,然后使用这个teacher模型的输出和数据的真实标签去训练student模型。知识蒸馏可以用来将网络从大网络转化成小网络,并保留接近于大网络的性能,以此解决模型在边缘端的部署硬件不足的问题。
[0003]在目标检测场景下,需要同时完成目标分类和目标定位,其特征图相当于树干,包含了树叶中的所有信息,下游的各个任务相当于树叶,logits蒸馏能将teacher树叶的信息直接传递给student,feature蒸馏能将teacher树干的信息传递给student,但是由于树干的信息是高纬度抽象信息,和logits蒸馏直接传递的信息又有所不同,基于此,希望将两种训练模式应用到目标检测训练中来,以提升目标检测模型训练精度。

技术实现思路

[0004]未解决上述技术问题,本申请的实施例提供了一种模型训练方法、装置、电子设备和计算机可读存储介质。
[0005]在本申请的第一方面,提供了一种模型训练方法,包括:获取图像样本;利用教师模型得到所述图像样本的第一图像特征、第一定位特征以及第一分类特征;利用学生模型得到所述图像样本的第二图像特征、第二定位特征以及第二分类特征;基于所述第一图像特征、所述第一定位特征、所述第一分类特征、所述第二图像特征、所述第二定位特征以及所述第二分类特征构建损失函数;根据所述损失函数,采用边训练边蒸馏的方式训练所述学生模型。
[0006]在一种可能的实现方式中,利用教师模型得到所述图像样本的第一图像特征、第一定位特征以及第一分类特征,包括:将所述图像样本输入所述教师模型,得到所述教师模型的fpn输出的特征作为所述第一图像特征、所述教师模型的定位head输出的特征作为第一定位特征、所述教师模型的分类head输出的特征作为所述第一分类特征。
[0007]在一种可能的实现方式中,利用学生模型得到所述图像样本的第二图像特征、第二定位特征以及第二分类特征,包括:将所述图像样本输入所述学生模型,得到所述学生模型的fpn输出的特征作为所述第二图像特征、所述学生模型的定位head输出的特征作为第二定位特征、所述学生模型
的分类head输出的特征作为所述第二分类特征。
[0008]在一种可能的实现方式中,基于所述第一图像特征、所述第一定位特征、所述第一分类特征、所述第二图像特征、所述第二定位特征以及所述第二分类特征构建损失函数,包括:根据所述第一图像特征和所述第二图像特征,采用l2损失函数,得到第一损失函数:,其中,F
t
为第一图像特征,F
s
为第二图像特征;根据所述第一定位特征和所述第二定位特征,采用交叉熵损失函数,得到第二损失函数:,其中,为第一定位特征,为第二定位特征;根据所述第一分类特征和所述第二分类特征,采用交叉熵损失函数,得到第三损失函数:,其中,为第一分类特征,为第二分类特征;根据所述第一损失函数、所述第二损失函数以及所述第三损失函数,得到所述损失函数:,其中,为学生模型训练时的定位损失函数,为学生模型训练时的分类损失函数。
[0009]在本申请的第二方面,提供了一种模型训练装置,包括:样本获取模块,用于获取图像样本;第一训练模块,用于利用教师模型得到所述图像样本的第一图像特征、第一定位特征以及第一分类特征;第二训练模块,用于利用学生模型得到所述图像样本的第二图像特征、第二定位特征以及第二分类特征;第三训练模块,用于基于所述第一图像特征、所述第一定位特征、所述第一分类特征、所述第二图像特征、所述第二定位特征以及所述第二分类特征构建损失函数;第四训练模块,用于根据所述损失函数,采用边训练边蒸馏的方式训练所述学生模型。
[0010]在一种可能的实现方式中,所述第一训练模块具体用于:将所述图像样本输入所述教师模型,得到所述教师模型的fpn输出的特征作为所述第一图像特征、所述教师模型的定位head输出的特征作为第一定位特征、所述教师模型
的分类head输出的特征作为所述第一分类特征。
[0011]在一种可能的实现方式中,所述第二训练模块具体用于:将所述图像样本输入所述学生模型,得到所述学生模型的fpn输出的特征作为所述第二图像特征、所述学生模型的定位head输出的特征作为第二定位特征、所述学生模型的分类head输出的特征作为所述第二分类特征。
[0012]在一种可能的实现方式中,所述第三训练模块具体用于:根据所述第一图像特征和所述第二图像特征,采用l2损失函数,得到第一损失函数:,其中,F
t
为第一图像特征,F
s
为第二图像特征;根据所述第一定位特征和所述第二定位特征,采用交叉熵损失函数,得到第二损失函数:,其中,为第一定位特征,为第二定位特征;根据所述第一分类特征和所述第二分类特征,采用交叉熵损失函数,得到第三损失函数:,其中,为第一分类特征,为第二分类特征;根据所述第一损失函数、所述第二损失函数以及所述第三损失函数,得到所述损失函数:,其中,为学生模型训练时的定位损失函数,为学生模型训练时的分类损失函数。
[0013]在本申请的第三方面,提供了一种电子设备,包括存储器和处理器,所述存储器上存储有计算机程序,所述处理器执行所述计算机程序时实现如第一方面中任一项所述的方法。
[0014]在本申请的第四方面,提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现如第一方面中任一项所述的方法。
[0015]在本申请实施例提供的模型训练方法、装置、电子设备和计算机可读存储介质中,获取图像样本,利用教师模型得到所述图像样本的第一图像特征、第一定位特征以及第一分类特征,利用学生模型得到所述图像样本的第二图像特征、第二定位特征以及第二分类特征,基于第一图像特征、第一定位特征、第一分类特征、第二图像特征、第二定位特征以及第二分类特征构建损失函数,根据损失函数,采用边训练边蒸馏的方式训练学生模型,由此可知,本申请中学生模型可以同时学习教师模型logits信息和feature信息,丰富了学生模
型的信息量,并且logits信息和feature信息同时学习,能够相互补充、相互促进,从而能够优化学生模型的训练,提升学生模型的准确性。同时,采用边训练边蒸馏的方式来训练学生模型,在优化蒸馏损失的同时一起优化学生模型的训练损失,使得训练和蒸馏互相补充,能够提升学生模型的上限,使得学生模型大于等于教师模型。
[0016]应当理解,
技术实现思路
部分中所描述的内容并非旨在限定本申请的实施例的关键或重要特征,亦非用于限制本申请的范围。本申请的其它特征将通过以下的描述变得容易理解。
附图说明
[0017]结合附图并参考以下详细说明,本申请各实施例本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,包括:获取图像样本;利用教师模型得到所述图像样本的第一图像特征、第一定位特征以及第一分类特征;利用学生模型得到所述图像样本的第二图像特征、第二定位特征以及第二分类特征;基于所述第一图像特征、所述第一定位特征、所述第一分类特征、所述第二图像特征、所述第二定位特征以及所述第二分类特征构建损失函数;根据所述损失函数,采用边训练边蒸馏的方式训练所述学生模型。2.根据权利要求1所述的方法,其特征在于,利用教师模型得到所述图像样本的第一图像特征、第一定位特征以及第一分类特征,包括:将所述图像样本输入所述教师模型,得到所述教师模型的fpn输出的特征作为所述第一图像特征、所述教师模型的定位head输出的特征作为所述第一定位特征、所述教师模型的分类head输出的特征作为所述第一分类特征。3.根据权利要求1所述的方法,其特征在于,利用学生模型得到所述图像样本的第二图像特征、第二定位特征以及第二分类特征,包括:将所述图像样本输入所述学生模型,得到所述学生模型的fpn输出的特征作为所述第二图像特征、所述学生模型的定位head输出的特征作为第二定位特征、所述学生模型的分类head输出的特征作为所述第二分类特征。4.根据权利要求1所述的方法,其特征在于,基于所述第一图像特征、所述第一定位特征、所述第一分类特征、所述第二图像特征、所述第二定位特征以及所述第二分类特征构建损失函数,包括:根据所述第一图像特征和所述第二图像特征,采用l2损失函数,得到第一损失函数:,其中,F
t
为第一图像特征,F
s
为第二图像特征;根据所述第一定位特征和所述第二定位特征,采用交叉熵损失函数,得到第二损失函数:,其中,为第一定位特征,为第二定位特征;根据所述第一分类特征和所述第二分类特征,采用交叉熵损失函数,得到第三损失函数:,其中,为第一分类特征,为第二分类特征;根据所述第一损失函数、所述第二损失函数以及所述第三损失函数,得到所述损失函数:,
其中,为学生模型训练时的定位损失函数,为学生模型训练时的分类损失函数。5.一种模型训练装置,其特征在于,包括:样本获取模块,用于获取图像样本;第一训练模块...

【专利技术属性】
技术研发人员:倪华健杨德城林亦宁
申请(专利权)人:上海闪马智能科技有限公司杭州闪马智擎科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1