【技术实现步骤摘要】
基于知识蒸馏的模型训练方法、相关设备及可读存储介质
[0001]本申请涉及神经网络
,更具体的说,是涉及一种基于知识蒸馏的模型训练方法、相关设备及可读存储介质。
技术介绍
[0002]知识蒸馏是基于教师
‑
学生的模型压缩方式,通过引入大规模教师模型以诱导小规模学生模型的训练,实现知识迁移。传统的基于知识蒸馏的模型训练方法,是先训练一个教师模型,然后使用教师模型的最终输出和训练样本的标注标签去训练学生模型,使得学生模型不仅可以从训练样本中学习如何判断正确样本的类别,还可以从教师模型中学习类间关系。
[0003]但是,在一些场景中,教师模型和学生模型均包含中间网络层,模型的最终输出与中间网络层的输出相关。比如,在基于传统的知识蒸馏方法对基于流式端到端语音识别模型进行训练的场景中,教师模型和学生模型均包含编码器、解码器和联合网络层,其中,编码器和解码器作为中间网络层,联合网络层的输出作为模型的最终输出,联合网络层的输出与编码器和解码器的输出相关。在这些场景中,仅仅使用教师模型的最终输出和训练样本的标注标签去训练学生模型,可能会导致学生模型中间网络层的输出不能与教师模型中间网络层的输出相似,最终可能导致学生模型的最终输出无法向教师模型的最终输出靠近。
[0004]因此,如何对传统的基于知识蒸馏的模型训练方法进行优化,成为本领域技术人员亟待解决的技术问题。
技术实现思路
[0005]鉴于上述问题,本申请提出了一种基于知识蒸馏的模型训练方法、相关设备及可读存储介质。具体方案 ...
【技术保护点】
【技术特征摘要】
1.一种基于知识蒸馏的模型训练方法,其特征在于,所述方法包括:获取预先训练好的教师模型、待训练的学生模型、训练数据以及所述训练数据的标注标签;所述教师模型和所述学生模型均包括中间网络层;以所述训练数据为训练样本,以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标,对所述待训练的学生模型进行训练,得到训练好的学生模型。2.根据权利要求1所述的方法,其特征在于,所述以所述训练数据为训练样本,以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标,对所述待训练的学生模型进行训练,得到训练好的学生模型,包括:以所述学生模型中间网络层的输出分布趋近于所述教师模型中间网络层的输出分布,所述学生模型的最终输出分布趋近于所述教师模型的最终输出分布,且所述学生模型的最终输出趋近于所述训练数据的标注标签为训练目标构建学生模型损失函数;基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型。3.根据权利要求2所述的方法,其特征在于,所述构建的学生模型损失函数包括第一损失项,第二损失项和第三损失项,其中,所述第一损失项用于表征所述教师模型的中间网络层的输出分布与所述学生模型的中间网络层的输出分布之间的误差,所述第二损失项用于表征所述教师模型的最终输出分布与所述学生模型的最终输出分布之间的误差;所述第三损失项用于表征所述学生模型的最终输出与所述训练数据的标注标签之间的误差。4.根据权利要求3所述的方法,其特征在于,所述构建的学生模型损失函数,包括:第一学生模型损失函数;所述第一学生模型损失函数包括所述第一损失项,所述第二损失项和所述第三损失项,其中,所述第一损失项,所述第二损失项和所述第三损失项的系数相同。5.根据权利要求4所述的方法,其特征在于,所述基于构建的学生模型损失函数对所述学生模型进行训练,得到训练好的学生模型,包括:将所述训练数据输入到所述教师模型和所述学生模型,基于所述第一学生模型损失函数对所述学生模型的参数进行迭代优化,直至训练结束,得到第一学生模型;确定所述第一学生模型为训练好的学生模型;或者,将所述训练数据输入到所述教师模型和所述第一学生模型,基于所述第一学生模型损失函数对所述第一学生模型的参数进行迭代优化,直至训练结束,得到训练好的学生模型。6.根据权利要求3所述的方法,其特征在于,所述构建的学生模型损失函数,包括:第二学生模型损失函数和第三学生模型损失函数;其中,所述第二学生模型损失函数包括所述第一损失项;所述第三学生模型损失函数包括所述第二损失项和所述第三损失项,所述第一损失项、所述第二损失项和所述第三损失项的系数相...
【专利技术属性】
技术研发人员:唐海桃,王智国,方昕,李永超,
申请(专利权)人:科大讯飞股份有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。