一种模型蒸馏训练方法、装置、电子设备及存储介质制造方法及图纸

技术编号:31911334 阅读:403 留言:0更新日期:2022-01-15 12:51
本申请提供一种模型蒸馏训练方法、装置、电子设备及存储介质,用于改善对长尾类别的样本识别正确率提高十分有限的问题。该方法包括:获取包括长尾类别的训练数据集,并使用多种数据增强手段对训练数据集进行数据增强,获得多个数据集合;使用多个数据集合分别对多个老师模型进行不同种类的损失优化训练,获得训练后的多个老师模型,一个老师模型是使用一个数据集合进行一个种类的损失优化训练获得的;根据准确率从多个老师模型中选择出预设数量的老师模型;获取预设数量的学生模型,使用老师模型对学生模型进行蒸馏训练,获得预设数量蒸馏后的学生模型;从预设数量蒸馏后的学生模型中筛选出准确率最高的学生模型。型中筛选出准确率最高的学生模型。型中筛选出准确率最高的学生模型。

【技术实现步骤摘要】
一种模型蒸馏训练方法、装置、电子设备及存储介质


[0001]本申请涉及神经网络、深度学习、数据增强和知识蒸馏的
,具体而言,涉及一种模型蒸馏训练方法、装置、电子设备及存储介质。

技术介绍

[0002]长尾类别,又被称为少样本类别,是指模型的训练数据集中样本数量较少的类别。具体例如:如果外卖平台的评价作为训练数据集,那么该训练数据集中的好评条数通常远大于差评条数,因此,此处的差评类别就可以理解为上述的长尾类别。
[0003]目前,在神经网络模型的训练过程中,当训练神经网络模型所用的训练数据集中特定类样本数量较少时,该特定类样本的识别正确率就会比其它类别的识别正确率低很多,这种现象被称为数据不均衡问题。为了增加长尾类别的样本识别正确率,通常的做法是增加长尾类别的数据条数或者减少非长尾类别的数据条数,具体例如:人工采集更多长尾类别的样本作为训练数据等等,但是,在实践的过程中发现这种做法对长尾类别的样本识别正确率提高十分有限,有时甚至没有提高。

技术实现思路

[0004]本申请实施例的目的在于提供一种模型蒸馏训练方法、装置本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种模型蒸馏训练方法,其特征在于,包括:获取包括长尾类别的训练数据集,并使用多种数据增强手段对所述训练数据集进行数据增强,获得多个数据集合;使用所述多个数据集合分别对多个老师模型进行不同种类的损失优化训练,获得训练后的多个老师模型,一个所述老师模型是使用一个所述数据集合进行一个种类的损失优化训练获得的;根据准确率从所述多个老师模型中选择出预设数量的老师模型;获取所述预设数量的学生模型,使用所述老师模型对所述学生模型进行蒸馏训练,获得所述预设数量蒸馏后的学生模型;从所述预设数量蒸馏后的学生模型中筛选出准确率最高的学生模型。2.根据权利要求1所述的方法,其特征在于,在所述从所述预设数量蒸馏后的学生模型中筛选出准确率最高的学生模型之后,还包括:获取待处理数据;使用所述准确率最高的学生模型对所述待处理数据进行分类预测,获得分类结果。3.根据权利要求1所述的方法,其特征在于,所述训练数据集是文本数据集;所述使用多种数据增强手段对所述训练数据集进行数据增强,包括:使用同义词替换、回译、动态遮掩、随机插入、随机交换和/或随机删除的数据增强手段对所述文本数据集进行数据增强。4.根据权利要求1所述的方法,其特征在于,所述训练数据集是图像数据集和/或视频数据集;所述使用多种数据增强手段对所述训练数据集进行数据增强,包括:使用图像缩放、图像旋转、水平翻转和垂直翻转的数据增强手段对所述图像数据集和/或所述视频数据集进行数据增强。5.根据权利要求1所述的方法,其特征在于,所述老师模型包括:第一嵌入层、第一转换器层和第一预测层,所述学生模型包括:第二嵌入层、第二转换器层和第二预测层;所述使用所述老师模型对所述学生模型进行蒸馏训练,包括:利用推土机距离EMD对所述数据集合中的数据标签、所述第一预测层的输出与所述第二预测层的输出进行计算,获得第一蒸馏损失,并分别计算所述第一转换器层的输出与所述第二转换器层的输出之间的第二蒸馏损失,以及所述第一嵌入层的输出与所述第二嵌入层的输出...

【专利技术属性】
技术研发人员:胡加明李健铨吴相博刘小康
申请(专利权)人:鼎富智能科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1