一种模型训练方法、装置、设备以及存储介质制造方法及图纸

技术编号:33436680 阅读:24 留言:0更新日期:2022-05-19 00:25
本公开提供了一种模型训练方法、装置、设备以及存储介质,涉及人工智能技术领域,尤其涉及深度学习、计算机视觉技术领域,可应用于图像处理、图像检测等场景领域。具体实现方案为:将样本图像输入至特征提取网络,得到所述特征提取网络对应的样本特征图;其中,所述特征提取网络包括老师特征提取网络和学生特征提取网络;根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;根据所述第一蒸馏损失,对所述学生特征提取网络进行训练。能够提高对学生特征提取网络训练的精准性。征提取网络训练的精准性。征提取网络训练的精准性。

【技术实现步骤摘要】
一种模型训练方法、装置、设备以及存储介质


[0001]本公开涉及人工智能
,尤其涉及深度学习、计算机视觉
,可应用于图像处理、图像检测等场景。

技术介绍

[0002]随着人工智能技术的发展,知识蒸馏技术在模型训练过程中的应用越来越广泛。其中,知识蒸馏是一种采用预先训练好的结构复杂的老师模型(Teacher Model)来训练结构简单的学生模型(Student Model),以实现将老师模型的功能赋予学生模型的技术,那么,如何基于知识蒸馏技术,高精度的训练学生模型至关重要。

技术实现思路

[0003]本公开提供了一种模型训练方法、装置、设备以及存储介质。
[0004]根据本公开的一方面,提供了一种模型训练方法,包括:
[0005]将样本图像输入至特征提取网络,得到特征提取网络对应的样本特征图;其中,特征提取网络包括老师特征提取网络和学生特征提取网络;
[0006]根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;
[0007]根据第一蒸馏损失,对学生本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,包括:将样本图像输入至特征提取网络,得到所述特征提取网络对应的样本特征图;其中,所述特征提取网络包括老师特征提取网络和学生特征提取网络;根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;根据所述第一蒸馏损失,对所述学生特征提取网络进行训练。2.根据权利要求1所述的方法,其中,所述根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失,包括:根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及所述至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示;根据所述不同特征提取网络对应的对象关系表示,确定第一蒸馏损失。3.根据权利要求2所述的方法,其中,所述根据不同特征提取网络对应的样本特征图中至少两个目标对象的特征值,以及所述至少两个目标对象之间的类别关系,确定不同特征提取网络对应的对象关系表示,包括:根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定不同特征提取网络对应的样本特征图中每一目标对象的目标关系;根据不同特征提取网络对应的样本特征图中每一目标对象的特征值,以及该目标对象的目标关系,确定不同特征提取网络对应的对象关系表示。4.根据权利要求1所述的方法,还包括:采用类别关系预测网络确定不同特征提取网络对应的样本特征图中不同目标对象属于同一类别的概率值,作为所述不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系。5.根据权利要求4所述的方法,还包括:根据所述第一蒸馏损失,对所述类别关系预测网络进行训练。6.根据权利要求1所述的方法,还包括:根据不同特征提取网络对应的样本特征图中不同像素点之间的像素关系,确定第二蒸馏损失;相应的,根据所述第一蒸馏损失,对所述学生特征提取网络进行训练,包括:根据所述第一蒸馏损失和所述第二蒸馏损失,对所述学生特征提取网络进行训练。7.根据权利要求6所述的方法,还包括:采用图神经网络确定不同特征提取网络对应的样本特征图中不同像素点之间的特征相似度,作为所述不同特征提取网络对应的样本特征图中不同像素点之间的像素关系。8.根据权利要求7所述的方法,还包括:根据所述第二蒸馏损失,对所述图神经网络进行训练。9.根据权利要求1

8中任一项所述的方法,其中,所述学生特征提取网络属于检测模型中的网络;相应的,所述方法还包括:将所述样本图像输入至训练后的学生特征提取网络,得到目标特征图;其中,所述训练后的学生特征提取网络采用权利要求1

8中任一所述的模型训练方法训练得到;
根据所述目标特征图对所述检测模型中的其他网络进行训练;其中,所述其他网络至少包括分类网络和回归网络。10.一种模型训练装置,包括:特征提取模块,用于将样本图像输入至特征提取网络,得到所述特征提取网络对应的样本特征图;其中,所述特征提取网络包括老师特征提取网络和学生特征提取网络;第一损失确定模块,用于根据不同特征提取网络对应的样本特征图中至少两个目标对象之间的类别关系,确定第一蒸馏损失;网络训练模块,用于根据所述第一蒸馏损失,...

【专利技术属性】
技术研发人员:杨馥魁韩钧宇
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1