【技术实现步骤摘要】
学生模型的训练方法、装置及电子设备
[0001]本申请涉及人工智能领域,尤其涉及自然语言处理和深度学习技术等领域,可应用于文本生成、机器翻译以及模型压缩场景下。
技术介绍
[0002]目前,语义通顺度模型可以应用于多个领域,针对机器或其他方式生成的文案,进行通顺度的调整,进而过滤出高品质的符合人们阅读习惯的文本文案。但是,现有的语义通顺度模型的结构较为复杂,参数规模、运算量以及硬件的资源配置和消耗均较大,虽然可以实现高品质的文本文案的过滤输出,但是运算耗时较长,无法实现大规模的应用。
[0003]相关技术中,提出了一种知识蒸馏的方法,将复杂的语义通顺度模型作为教师模型,构建学生模型对其进行学习。但是,目前构建的学生模型仅可以实现对于教师模型最后一层的对齐学习,忽略了中间层,使得学生模型的学习成本较高,同时无法保证最终学生模型的输出结果的精度。
[0004]因此,如何实现对于学生模型对于教师模型的中间层的学习,进而提高模型的精度,是目前需要解决的问题。
技术实现思路
[0005]本申请提出了一种学生模型的训练方法、装置、电子设备、存储介质及计算机程序产品。
[0006]根据本申请的第一方面,提出了一种学生模型的训练方法,包括:将训练样本分别输入学生模型和教师模型中进行训练;获取所述学生模型和所述教师模型在嵌入层上的第一误差;获取所述学生模型和所述教师模型在中间层上的第二误差;获取所述学生模型与所述教师模型在输出层上的损失函数;根据所述第一误差、所述第二误差和所述损失函数,确定所述 ...
【技术保护点】
【技术特征摘要】
1.一种学生模型的训练方法,包括:将训练样本分别输入学生模型和教师模型中进行训练;获取所述学生模型和所述教师模型在嵌入层上的第一误差;获取所述学生模型和所述教师模型在中间层上的第二误差;获取所述学生模型与所述教师模型在输出层上的损失函数;根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。2.根据权利要求1所述的方法,其中,所述获取所述学生模型和所述教师模型在嵌入层上的第一误差,包括:根据所述学生模型中嵌入层输出的第一特征表示,以及所述教师模型中嵌入层输出的第二特征表示,获取所述第一误差。3.根据权利要求2所述的方法,其中,所述嵌入层输出的特征表示包括所述训练样本中每个元素的特征值,其中,所述根据所述学生模型中嵌入层输出的特征表示,以及所述教师模型中嵌入层输出的特征表示,获取所述第一误差,包括:获取所述学生模型中嵌入层输出的所述元素的第一特征值,以及所述教师模型中嵌入层输出的所述元素的第二特征值;针对同一个所述元素,根据所述第一特征值和所述第二特征值,获取所述元素的第一均方误差,并将每个所述元素对应的第一均方误差求和,获取所述第一误差。4.根据权利要求1所述的方法,其中,所述获取所述学生模型和所述教师模型在中间层上的第二误差,包括:根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差。5.根据权利要求4所述的方法,其中,所述根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差,包括:针对每个所述第一中间层,根据所述第一中间层的所述第三特征表示以及所述匹配的第二中间层的所述第四特征表示,获取所述第一中间层的第二均方误差,并将每个所述第一中间层对应的第二均方误差求和,获取所述第二误差。6.根据权利要求1
‑
5任一项所述的方法,其中,所述获取所述学生模型与所述教师模型在输出层上的损失函数,包括:获取所述学生模型中输出层输出的第一预测概率分布,以及所述教师模型中输出层输出的第二预测概率分布;根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数;对所述第一损失函数和所述第二损失函数进行加权,获取所述损失函数。7.根据权利要求6所述的方法,其中,所述根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数,包括:从所述第一预测概率分布中,获取所述训练样本所属的实际类别对应的目标预测概
率,并根据所述目标预测概率,确定所述第一损失函数;针对所有类别中的任一类别,从所述第一预测概率分布中,获取所述任一类别对应的第一预测概率,并从所述第二预测概率分布中,获取所述任一类别对应的第二预测概率;根据所述任一类别的所述第一预测概率和所述第二预测概率,确定所述任一类别的损失值,并将每个所述类别的所述损失值求和,获取所述第二损失函数。8.根据权利要求1
‑
5任一项所述的方法,其中,所述根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,包括:对所述第一误差、所述第二误差和所述损失函数进行加权,获取所述总损失函数。9.根据权利要求1
‑
5任一项所述的方法,其中,所述学生模型的第i个第一中间层与所述教师模型的第j个第二中间层匹配,其中,所述i与所述j之间间隔设定层数,其中,所述i、j均大于或者等于零。10.根据权利要求9所述的方法,其中,初始的所述学生模型中每层的模型参数与所述教师模型中对应层的模型参数相同。11.一种学生模型的训练装置,...
【专利技术属性】
技术研发人员:念天磊,刘丽,阳锋,
申请(专利权)人:北京百度网讯科技有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。