学生模型的训练方法、装置及电子设备制造方法及图纸

技术编号:30157757 阅读:32 留言:0更新日期:2021-09-25 15:09
本申请提出了一种学生模型的训练方法及装置,涉及人工智能领域,尤其涉及自然语言处理和深度学习技术等领域,可应用于文本生成、机器翻译、模型压缩等场景下,包括将训练样本分别输入学生模型和教师模型中进行训练;获取学生模型和教师模型在嵌入层上的第一误差、在中间层上的第二误差以及在输出层上的损失函数;根据第一误差、第二误差、损失函数,确定学生模型的总损失函数,并基于总损失函数对学生模型的模型参数进行调整,并继续使用下一个训练样本对调整后的学生模型训练,直至训练结束,生成目标学生模型。本申请中,学生模型可以学习到教师模型的中间层的信息,使得学生模型的训练速度加快,优化了模型的训练效果,提高了模型的性能。了模型的性能。了模型的性能。

【技术实现步骤摘要】
学生模型的训练方法、装置及电子设备


[0001]本申请涉及人工智能领域,尤其涉及自然语言处理和深度学习技术等领域,可应用于文本生成、机器翻译以及模型压缩场景下。

技术介绍

[0002]目前,语义通顺度模型可以应用于多个领域,针对机器或其他方式生成的文案,进行通顺度的调整,进而过滤出高品质的符合人们阅读习惯的文本文案。但是,现有的语义通顺度模型的结构较为复杂,参数规模、运算量以及硬件的资源配置和消耗均较大,虽然可以实现高品质的文本文案的过滤输出,但是运算耗时较长,无法实现大规模的应用。
[0003]相关技术中,提出了一种知识蒸馏的方法,将复杂的语义通顺度模型作为教师模型,构建学生模型对其进行学习。但是,目前构建的学生模型仅可以实现对于教师模型最后一层的对齐学习,忽略了中间层,使得学生模型的学习成本较高,同时无法保证最终学生模型的输出结果的精度。
[0004]因此,如何实现对于学生模型对于教师模型的中间层的学习,进而提高模型的精度,是目前需要解决的问题。

技术实现思路

[0005]本申请提出了一种学生模型的训练方法、装置、电子设备、存储介质及计算机程序产品。
[0006]根据本申请的第一方面,提出了一种学生模型的训练方法,包括:将训练样本分别输入学生模型和教师模型中进行训练;获取所述学生模型和所述教师模型在嵌入层上的第一误差;获取所述学生模型和所述教师模型在中间层上的第二误差;获取所述学生模型与所述教师模型在输出层上的损失函数;根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。
[0007]根据本申请的第二方面,提出了一种学生模型的训练装置,包括:输入模块,用于将训练样本分别输入学生模型和教师模型中进行训练;获取模块,用于获取所述学生模型和所述教师模型在嵌入层上的第一误差,以及获取所述学生模型和所述教师模型在中间层上的第二误差,以及获取所述学生模型与所述教师模型在输出层上的损失函数;训练模块,用于根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。
[0008]根据本申请的第三方面,提出了一种电子设备,包括:包括处理器和存储器;其中,所述处理器通过读取所述存储器中存储的可执行程序代码来运行与所述可执行程序代码对应的程序,以用于实现如上述第一方面中任一项所述的学生模型的训练方法。
[0009]根据本申请的第四方面,提出了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上述第一方面中任一项所述的学生模型的训练方法。
[0010]根据本申请的第五方面,提出了一种计算机程序产品,当所述计算机程序产品中的指令处理器执行时实现如上述第一方面中任一项所述的学生模型的训练方法。
[0011]应当理解,本部分所描述的内容并非旨在标识本申请的实施例的关键或重要特征,也不用于限制本申请的范围。本申请的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0012]附图用于更好地理解本方案,不构成对本申请的限定。其中:
[0013]图1为本申请一实施例的学生模型的结构示意图;
[0014]图2为本申请一实施例的学生模型的训练方法的流程示意图;
[0015]图3为本申请另一实施例的学生模型的训练方法的流程示意图;
[0016]图4为本申请另一实施例的学生模型的训练方法的流程示意图;
[0017]图5为本申请另一实施例的学生模型的训练方法的流程示意图;
[0018]图6为本申请另一实施例的学生模型的训练方法的流程示意图;
[0019]图7为本申请一实施例的学生模型的训练装置的结构示意图;
[0020]图8为本申请另一实施例的学生模型的训练装置的结构示意图;
[0021]图9为本申请一实施例的电子设备的示意性框图。
具体实施方式
[0022]以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0023]深度学习(Deep Learning,简称DL),是机器学习(Machine Learning,简称ML)领域中一个新的研究方向,它被引入机器学习使其更接近于最初的目标——人工智能。深度学习是学习样本数据的内在律和表示层次,这些学习过程中获得的信息对诸如文字,图像和声音等数据的解释有很大的帮助。它的最终目标是让机器能够像人一样具有分析学习能力,能够识别文字、图像和声音等数据。深度学习是一个复杂的机器学习算法,在语音和图像识别方面取得的效果,远远超过先前相关技术。
[0024]自然语言处理(Natural Language Processing,NLP)是计算机科学领域与人工智能领域中的一个重要方向。它研究能实现人与计算机之间用自然语言进行有效通信的各种理论和方法。自然语言处理是一门融语言学、计算机科学、数学于一体的科学。自然语言处理主要应用于机器翻译、舆情监测、自动摘要、观点提取、文本分类、问题回答、文本语义对比、语音识别等方面。
[0025]人工智能(Artificial Intelligence,简称AI),是研究使计算机来模拟人生的某些思维过程和智能行为(如学习、推理、思考、规划等)的学科,既有硬件层面的技术,也有软件层面的技术。人工智能硬件技术一般包括计算机视觉技术、语音识别技术、自然语言处理技术以及及其学习/深度学习、大数据处理技术、知识图谱技术等几大方面。
[0026]机器翻译(machine translation),又称为自动翻译,是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程。它是计算语言学的一个分支,是人工智能的目标之一。
[0027]图1为本申请一实施例的学生模型的结构示意图,如图1所示。
[0028]教师模型100包括嵌入层11、多个自注意力层(transformer层)12和输出层13,学生模型20包括嵌入层21、多个transformer层22和输出层23。可选地,从教师模型100的多个transformer层12中,基于设定的层间隔,每间隔N层提取一层的教师模型100的transformer层12,并将其设置为学生模型200的transformer层22。
[0029]如图1所示,设定教师模型100中共计12个transformer层12,设定学生模型200每间隔4层的层间隔提取教师模型100的一层本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种学生模型的训练方法,包括:将训练样本分别输入学生模型和教师模型中进行训练;获取所述学生模型和所述教师模型在嵌入层上的第一误差;获取所述学生模型和所述教师模型在中间层上的第二误差;获取所述学生模型与所述教师模型在输出层上的损失函数;根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。2.根据权利要求1所述的方法,其中,所述获取所述学生模型和所述教师模型在嵌入层上的第一误差,包括:根据所述学生模型中嵌入层输出的第一特征表示,以及所述教师模型中嵌入层输出的第二特征表示,获取所述第一误差。3.根据权利要求2所述的方法,其中,所述嵌入层输出的特征表示包括所述训练样本中每个元素的特征值,其中,所述根据所述学生模型中嵌入层输出的特征表示,以及所述教师模型中嵌入层输出的特征表示,获取所述第一误差,包括:获取所述学生模型中嵌入层输出的所述元素的第一特征值,以及所述教师模型中嵌入层输出的所述元素的第二特征值;针对同一个所述元素,根据所述第一特征值和所述第二特征值,获取所述元素的第一均方误差,并将每个所述元素对应的第一均方误差求和,获取所述第一误差。4.根据权利要求1所述的方法,其中,所述获取所述学生模型和所述教师模型在中间层上的第二误差,包括:根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差。5.根据权利要求4所述的方法,其中,所述根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差,包括:针对每个所述第一中间层,根据所述第一中间层的所述第三特征表示以及所述匹配的第二中间层的所述第四特征表示,获取所述第一中间层的第二均方误差,并将每个所述第一中间层对应的第二均方误差求和,获取所述第二误差。6.根据权利要求1

5任一项所述的方法,其中,所述获取所述学生模型与所述教师模型在输出层上的损失函数,包括:获取所述学生模型中输出层输出的第一预测概率分布,以及所述教师模型中输出层输出的第二预测概率分布;根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数;对所述第一损失函数和所述第二损失函数进行加权,获取所述损失函数。7.根据权利要求6所述的方法,其中,所述根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数,包括:从所述第一预测概率分布中,获取所述训练样本所属的实际类别对应的目标预测概
率,并根据所述目标预测概率,确定所述第一损失函数;针对所有类别中的任一类别,从所述第一预测概率分布中,获取所述任一类别对应的第一预测概率,并从所述第二预测概率分布中,获取所述任一类别对应的第二预测概率;根据所述任一类别的所述第一预测概率和所述第二预测概率,确定所述任一类别的损失值,并将每个所述类别的所述损失值求和,获取所述第二损失函数。8.根据权利要求1

5任一项所述的方法,其中,所述根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,包括:对所述第一误差、所述第二误差和所述损失函数进行加权,获取所述总损失函数。9.根据权利要求1

5任一项所述的方法,其中,所述学生模型的第i个第一中间层与所述教师模型的第j个第二中间层匹配,其中,所述i与所述j之间间隔设定层数,其中,所述i、j均大于或者等于零。10.根据权利要求9所述的方法,其中,初始的所述学生模型中每层的模型参数与所述教师模型中对应层的模型参数相同。11.一种学生模型的训练装置,...

【专利技术属性】
技术研发人员:念天磊刘丽阳锋
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1