轻量级模型训练方法、图像处理方法、装置及电子设备制造方法及图纸

技术编号:35907475 阅读:41 留言:0更新日期:2022-12-10 10:46
本公开提供了一种轻量级模型训练方法、图像处理方法、装置及电子设备。涉及计算机技术领域,尤其涉及机器学习、计算机视觉、图像处理等技术领域。具体方案为:获取第e轮迭代采用的第一增广概率、第二增广概率和目标权重;分别基于第一增广概率和第二增广概率对数据集进行数据增广,得到第一数据集和第二数据集;基于第一数据集得到学生模型的第一输出值和教师模型的第二输出值,基于第二数据集得到第三输出值和第四输出值;确定蒸馏损失函数;确定真值损失函数;确定目标损失函数;基于目标损失函数对学生模型进行训练,在e小于E的情况下确定第e+1轮迭代应采用的第一增广概率或目标权重。根据本公开,能提高轻量级模型的训练精度。度。度。

【技术实现步骤摘要】
轻量级模型训练方法、图像处理方法、装置及电子设备


[0001]本公开涉及计算机
,尤其涉及机器学习、计算机视觉、图像处理等


技术介绍

[0002]知识蒸馏是一种常用的模型压缩方法,它指的是使用精度更高的教师模型去指导学生模型的训练过程,从而使得学生模型也能获得与教师模型相似的精度指标。知识蒸馏本身是带有正则化的效果,对于过拟合的模型采用知识蒸馏的方法进行训练,可以在一定程度上防止模型过拟合。但是对于一些轻量级模型来说,使用知识蒸馏策略进行训练,可能会带来进一步的欠拟合现象,从而导致模型精度变差,甚至不收敛。

技术实现思路

[0003]本公开提供了一种轻量级模型训练方法、图像处理方法、装置及电子设备。
[0004]根据本公开的第一方面,提供了一种轻量级模型训练方法,包括:
[0005]获取第e轮迭代采用的第一增广概率、第二增广概率和目标权重,目标权重是蒸馏项损失值的权重,e为不大于E的正整数,E为最大迭代轮数,E为大于1的正整数;
[0006]分别基于第一增广概率和第二增广概率对数据集进行数本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种轻量级模型训练方法,包括:获取第e轮迭代采用的第一增广概率、第二增广概率和目标权重,所述目标权重是蒸馏项损失值的权重,所述e为不大于E的正整数,所述E为最大迭代轮数,所述E为大于1的正整数;分别基于所述第一增广概率和所述第二增广概率对数据集进行数据增广,得到第一数据集和第二数据集;基于所述第一数据集得到学生模型的第一输出值和教师模型的第二输出值,以及基于所述第二数据集得到所述学生模型的第三输出值和所述教师模型的第四输出值,所述学生模型是轻量级模型;基于所述第一输出值和所述第二输出值确定蒸馏损失函数,以及基于所述第三输出值和所述第四输出值确定真值损失函数;基于所述蒸馏损失函数和所述真值损失函数确定目标损失函数;基于所述目标损失函数对所述学生模型进行训练,以及在所述e小于所述E的情况下,确定第e+1轮迭代应采用的第一增广概率或目标权重。2.根据权利要求1所述的方法,还包括:获取最大增广概率;基于所述最大增广概率、所述最大迭代轮数和所述第一增广概率,确定所述第二增广概率。3.根据权利要求2所述的方法,其中,确定第e+1轮迭代应采用的第一增广概率,包括:基于所述最大增广概率、所述最大迭代轮数和所述第e轮的所述第一增广概率,确定所述第e+1轮迭代应采用的第一增广概率。4.根据权利要求1所述的方法,还包括:获取最大目标权重;其中,确定第e+1轮迭代应采用的目标权重,包括:基于所述最大目标权重、所述最大迭代轮数和所述第e轮的所述目标权重,确定所述第e+1轮迭代应采用的目标权重。5.根据权利要求1所述的方法,其中,所述基于所述蒸馏损失函数和所述真值损失函数确定目标损失函数,包括:在所述目标权重不小于最大目标权重或者所述蒸馏损失函数不小于所述真值损失函数的情况下,将所述蒸馏损失函数确定为所述目标损失函数;否则,将所述真值损失函数确定为所述目标损失函数。6.根据权利要求1所述的方法,其中,所述基于所述第一输出值和所述第二输出值确定蒸馏损失函数,包括:根据下述公式确定所述蒸馏损失函数:l1=(a+a
dft
*2/E)*L
dist
(o1s,o1t)+(1

a

a
dft
*2/E)*L
gt
(o1s,gt);其中,l1表示所述蒸馏损失函数,L
dist
(1s,o1t)表示根据所述第一输出值和所述第二输出值确定出的蒸馏项损失值,L
gt
(o1s,gt)表示根据所述第一输出值和真值确定出的真值项损失值,a表示所述目标权重,a
dft
表示最大目标权重,E表示所述最大迭代轮数,gt表示所述真值,o1s表示所述第一输出值,o1t表示所述第二输出值。
7.根据权利要求1所述的方法,其中,所述基于所述第三输出值和所述第四输出值确定真值损失函数,包括:根据下述公式确定所述真值损失函数:l2=a*L
dist
(o2s,o2t)+(1

a)*L
gt
(o2s,gt);其中,l2表示所述真值损失函数,L
dist
(o2s,o2t)表示根据所述第三输出值和所述第四输出值确定出的蒸馏项损失值,L
gt
(o2s,gt)表示根据所述第三输出值和真值确定出的真值项损失值,a表示所述目标权重,gt表示所述真值,o2s表示所述第三输出值,o2t表示所述第四输出值。8.一种图像处理方法,包括:接收目标场景下的待处理图像;将所述待处理图像输入学生模型,获取所述学生模型输出的对所述待处理图像的处理结果;其中,所述学生模型采用根据权利要求1至7任一项所述的轻量级模型训练方法得到。9.根据权利要求8所述的方法,所述接收目标场景下的待处理图像,包括下述至少之一:获取图像分类场景下的待处理图像;获取图像识别场景下的待处理图像;获取目标检测场景下的待处理图像。10.一种轻量级模型训练装置,包括:第一获取模块,用于获取第e轮迭代采用的第一增广概率、第二增广概率和目标权重,所述目标权重是蒸馏项损失值的权重,所述e为不大于E的正整数,所述E为最大迭代轮数,所述E为大于1的正整数;数据增广模块,用于分别基于所述第一增广概率和所述第二增广概率对数据集进行数据增广,得到第一数据集和第二数据集;预测模块,用于基于所述第一数据集得到学生模型的第一输出值和...

【专利技术属性】
技术研发人员:郭若愚杜宇宁李晨霞赖宝华马艳军
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1