基于知识蒸馏的模型训练方法、相关设备及可读存储介质技术

技术编号:34458975 阅读:28 留言:0更新日期:2022-08-06 17:13
本申请公开了一种基于知识蒸馏的模型训练方法、相关设备及可读存储介质。在获取教师模型、学生模型、训练数据以及训练数据的标注标签之后;以训练数据为训练样本,以学生模型中间网络层的输出分布趋近于教师模型中间网络层的输出分布,学生模型的最终输出分布趋近于教师模型的最终输出分布,且学生模型的最终输出趋近于训练数据的标注标签为训练目标,对待训练的学生模型进行训练,得到训练好的学生模型。由于在训练过程中,同时利用教师模型中间网络层的输出和最终输出指导学生模型的学习,能够使学生模型中间网络层的输出与教师模型中间网络层的输出尽可能接近,从而保证了学生模型的最终输出与教师模型的最终输出也尽可能接近。可能接近。可能接近。

【技术实现步骤摘要】
基于知识蒸馏的模型训练方法、相关设备及可读存储介质


[0001]本申请涉及神经网络
,更具体的说,是涉及一种基于知识蒸馏的模型训练方法、相关设备及可读存储介质。

技术介绍

[0002]知识蒸馏是基于教师

学生的模型压缩方式,通过引入大规模教师模型以诱导小规模学生模型的训练,实现知识迁移。传统的基于知识蒸馏的模型训练方法,是先训练一个教师模型,然后使用教师模型的最终输出和训练样本的标注标签去训练学生模型,使得学生模型不仅可以从训练样本中学习如何判断正确样本的类别,还可以从教师模型中学习类间关系。
[0003]但是,在一些场景中,教师模型和学生模型均包含中间网络层,模型的最终输出与中间网络层的输出相关。比如,在基于传统的知识蒸馏方法对基于流式端到端语音识别模型进行训练的场景中,教师模型和学生模型均包含编码器、解码器和联合网络层,其中,编码器和解码器作为中间网络层,联合网络层的输出作为模型的最终输出,联合网络层的输出与编码器和解码器的输出相关。在这些场景中,仅仅使用教师模型的最终输出和训练样本的标注标签去训练学生模型,可能会导致学生模型中间网络层的输出不能与教师模型中间网络层的输出相似,最终可能导致学生模型的最终输出无法向教师模型的最终输出靠近。
[0004]因此,如何对传统的基于知识蒸馏的模型训练方法进行优化,成为本领域技术人员亟待解决的技术问题。

技术实现思路

[0005]鉴于上述问题,本申请提出了一种基于知识蒸馏的模型训练方法、相关设备及可读存储介质。具体方案如下:
[0006]一种基于知识蒸馏的模型训练方法,所述方法包括:
[0007]获取预先训练好的教师模型、待训练的学生模型、训练数据以及所述训练数据的标注标签;所述教师模型和所述学生模型均包括中间网络层;
[0008]以所述训练数据为训练样本,以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标,对所述待训练的学生模型进行训练,得到训练好的学生模型。
[0009]可选地,所述以所述训练数据为训练样本,以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标,对所述待训练的学生模型进行训练,得到训练好的学生模型,包括:
[0010]以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出
分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标构建学生模型损失函数;
[0011]基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型。
[0012]可选地,所述构建的学生模型损失函数包括第一损失项,第二损失项和第三损失项,其中,所述第一损失项用于表征所述教师模型的中间网络层的输出分布与所述学生模型的中间网络层的输出分布之间的误差,所述第二损失项用于表征所述教师模型的最终输出分布与所述学生模型的最终输出分布之间的误差;所述第三损失项用于表征所述学生模型的最终输出与所述训练数据的标注标签之间的误差。
[0013]可选地,所述构建的学生模型损失函数,包括:第一学生模型损失函数;
[0014]所述第一学生模型损失函数包括所述第一损失项,所述第二损失项和所述第三损失项,其中,所述第一损失项,所述第二损失项和所述第三损失项的系数相同。
[0015]可选地,所述基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型,包括:
[0016]将所述训练数据输入到所述教师模型和所述学生模型,基于所述第一学生模型损失函数对所述学生模型的参数进行迭代优化,直至训练结束,得到第一学生模型;
[0017]确定所述第一学生模型为训练好的学生模型;
[0018]或者,
[0019]将所述训练数据输入到所述教师模型和所述第一学生模型,基于所述第一学生模型损失函数对所述第一学生模型的参数进行迭代优化,直至训练结束,得到训练好的学生模型。
[0020]可选地,所述构建的学生模型损失函数,包括:第二学生模型损失函数和第三学生模型损失函数;
[0021]其中,所述第二学生模型损失函数包括所述第一损失项;所述第三学生模型损失函数包括所述第二损失项和所述第三损失项,所述第一损失项、所述第二损失项和所述第三损失项的系数相同。
[0022]可选地,所述基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型,包括:
[0023]将所述训练数据输入到所述教师模型和所述学生模型,基于所述第二学生模型损失函数对所述学生模型的参数进行迭代优化,直至训练结束,得到第二学生模型;
[0024]将所述训练数据输入到所述教师模型和所述第二学生模型,基于所述第三学生模型损失函数对所述第二学生模型的参数进行迭代优化,直至训练结束,得到训练好的学生模型。
[0025]可选地,所述构建的学生模型损失函数,包括:第四学生模型损失函数和第五学生模型损失函数;
[0026]所述第四学生模型损失函数包括所述第一损失项,所述第二损失项和所述第三损失项,其中,所述第二损失项和所述第三损失项的系数相同,所述第一损失项的系数远大于所述第二损失项和所述第三损失项的系数;
[0027]所述第五学生模型损失函数包括所述第一损失项,所述第二损失项和所述第三损
失项,其中,所述第二损失项的系数和所述第三损失项的系数相同,所述第二损失项和所述第三损失项的系数远大于所述第一损失项的系数。
[0028]可选地,所述基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型,包括:
[0029]将所述训练数据输入到所述教师模型和所述学生模型,基于所述第四学生模型损失函数对所述学生模型的参数进行迭代优化,直至训练结束,得到第三学生模型;
[0030]将所述训练数据输入到所述教师模型和所述第三学生模型,基于所述第五学生模型损失函数对所述第三学生模型的参数进行迭代优化,直至训练结束,得到训练好的学生模型。
[0031]可选地,所述教师模型的中间网络层的输出分布与所述学生模型的中间网络层的输出分布之间的误差,包括:所述教师模型的中间网络层各个子层的输出分布与所述学生模型的中间网络层各个子层的输出分布之间的误差。
[0032]可选地,所述教师模型和所述学生模型均为流式端到端语音识别模型,所述流式端到端语音识别模型包括编码器、解码器和联合网络,所述编码器和解码器为中间网络层,所述联合网络的输出为所述流式端到端语音识别模型的最终输出。
[0033]一种基于知识蒸馏的模型训练装置,所述装置包括:
[00本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的模型训练方法,其特征在于,所述方法包括:获取预先训练好的教师模型、待训练的学生模型、训练数据以及所述训练数据的标注标签;所述教师模型和所述学生模型均包括中间网络层;以所述训练数据为训练样本,以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标,对所述待训练的学生模型进行训练,得到训练好的学生模型。2.根据权利要求1所述的方法,其特征在于,所述以所述训练数据为训练样本,以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标,对所述待训练的学生模型进行训练,得到训练好的学生模型,包括:以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标构建学生模型损失函数;基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型。3.根据权利要求2所述的方法,其特征在于,所述构建的学生模型损失函数包括第一损失项,第二损失项和第三损失项,其中,所述第一损失项用于表征所述教师模型的中间网络层的输出分布与所述学生模型的中间网络层的输出分布之间的误差,所述第二损失项用于表征所述教师模型的最终输出分布与所述学生模型的最终输出分布之间的误差;所述第三损失项用于表征所述学生模型的最终输出与所述训练数据的标注标签之间的误差。4.根据权利要求3所述的方法,其特征在于,所述构建的学生模型损失函数,包括:第一学生模型损失函数;所述第一学生模型损失函数包括所述第一损失项,所述第二损失项和所述第三损失项,其中,所述第一损失项,所述第二损失项和所述第三损失项的系数相同。5.根据权利要求4所述的方法,其特征在于,所述基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型,包括:将所述训练数据输入到所述教师模型和所述学生模型,基于所述第一学生模型损失函数对所述学生模型的参数进行迭代优化,直至训练结束,得到第一学生模型;确定所述第一学生模型为训练好的学生模型;或者,将所述训练数据输入到所述教师模型和所述第一学生模型,基于所述第一学生模型损失函数对所述第一学生模型的参数进行迭代优化,直至训练结束,得到训练好的学生模型。6.根据权利要求3所述的方法,其特征在于,所述构建的学生模型损失函数,包括:第二学生模型损失函数和第三学生模型损失函数;其中,所述第二学生模型损失函数包括所述第一损失项;所述第三学生模型损失函数包括所述第二损失项和所述第三损失项,所述第一损失项、所述第二损失项和所述第三损失项的系数相...

【专利技术属性】
技术研发人员:唐海桃王智国方昕李永超
申请(专利权)人:科大讯飞股份有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1