一种应用于BERT模型的蒸馏方法、装置、设备及存储介质制造方法及图纸

技术编号:27507274 阅读:22 留言:0更新日期:2021-03-02 18:35
本申请实施例属于深度学习技术领域,涉及一种应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质。本申请提供的应用于BERT模型的蒸馏方法,由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。使得精简BERT模型收敛得更加稳定。使得精简BERT模型收敛得更加稳定。

【技术实现步骤摘要】
一种应用于BERT模型的蒸馏方法、装置、设备及存储介质


[0001]本申请涉及深度学习
,尤其涉及一种应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质。

技术介绍

[0002]近年在计算机视觉、语音识别等诸多领域,在利用深度网络解决问题的时候人们常常倾向于设计更为复杂的网络收集更多的数据以期获得更好的结果。但是,随之而来的是模型的复杂度急剧提升,直观的表现是模参数越来越多、规模越来越大,需要的硬件资源(内存、GPU)越来越高。不利于模型的部署和应用向移动端的推广。
[0003]现有一种深度模型蒸馏方法,采用蒸馏模型的优势在进行模型蒸馏时匹配各个中间层之间的数据,已实现压缩模型的目的。
[0004]然而,传统的深度模型蒸馏方法普遍不智能,在蒸馏的过程中匹配中间层输出时,往往需要平衡较多损失(loss)参数,例如:下游任务loss、中间层输出loss、相关矩阵loss、注意力矩阵(Attention)loss、等等,从而导致传统的深度模型蒸馏方法存在平衡loss参数较为困难的问题。

技术实现思路

[0005]本申请实施例的目的在于提出一种应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质,以解决传统的深度模型蒸馏方法存在平衡loss参数较为困难的问题。
[0006]为了解决上述技术问题,本申请实施例提供一种应用于BERT模型的蒸馏方法,采用了如下所述的技术方案:
[0007]接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
[0008]读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;
[0009]构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;
[0010]基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;
[0011]在所述本地数据库中获取所述中间精简模型的训练数据;
[0012]基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。
[0013]为了解决上述技术问题,本申请实施例还提供一种应用于BERT模型的蒸馏装置,采用了如下所述的技术方案:
[0014]请求接收模块,用于接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
[0015]原始模型获取模块,用于读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;
[0016]默认模型构建模块,用于构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;
[0017]蒸馏操作模块,用于基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;
[0018]训练数据获取模块,用于在所述本地数据库中获取所述中间精简模型的训练数据;
[0019]模型训练模块,用于基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。
[0020]为了解决上述技术问题,本申请实施例还提供一种计算机设备,采用了如下所述的技术方案:
[0021]包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如上所述的应用于BERT模型的蒸馏方法的步骤。
[0022]为了解决上述技术问题,本申请实施例还提供一种计算机可读存储介质,采用了如下所述的技术方案:
[0023]所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如上所述的应用于BERT模型的蒸馏方法的步骤。
[0024]与现有技术相比,本申请实施例提供的应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质主要有以下有益效果:
[0025]本申请实施例提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;在所述本地数据库中获取所述中间精简模型的训练数据;基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。
附图说明
[0026]为了更清楚地说明本申请中的方案,下面将对本申请实施例描述中所需要使用的附图作一个简单介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0027]图1是本申请实施例一提供的应用于BERT模型的蒸馏方法的实现流程图;
[0028]图2是图1中步骤S104的实现流程图;
[0029]图3是图1中步骤S105的实现流程图;
[0030]图4是本申请实施例一提供的参数优化操作的实现流程图;
[0031]图5是图4中步骤S403的实现流程图;
[0032]图6是本申请实施例二提供的应用于BERT模型的蒸馏装置的结构示意图;
[0033]图7是根据本申请的计算机设备的一个实施例的结构示意图。
具体实施方式
[0034]除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的
的技术人员通常理解的含义相同;本文中在申请的说明书中所使用的术语只是为了描述具体的实施例的目的,不是旨在于限制本申请;本申请的说明书和权利要求书及上述附图说明中的术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含。本申请的说明书和权利要求书或上述附图中的术语“第一”、“第二”等是用于区别不同对象,而不是用于描述特定顺序。
[0035]在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
[0036]为了使本
的人员更好地理解本申请方案,下面将结合附图,对本申请实施例中的技术方案本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种应用于BERT模型的蒸馏方法,其特征在于,包括下述步骤:接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;在所述本地数据库中获取所述中间精简模型的训练数据;基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。2.根据权利要求1所述的应用于BERT模型的蒸馏方法,其特征在于,所述基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型的步骤,具体包括:基于所述蒸馏系数对所述原始BERT模型的transformer层进行分组操作,得到分组transformer层;基于伯努利分布分别在所述分组transformer层中进行提取操作,得到待替换transformer层;将所述待替换transformer层分别替换至所述默认精简模型,得到所述中间精简模型。3.根据权利要求1所述的应用于BERT模型的蒸馏方法,其特征在于,所述在所述本地数据库中获取所述中间精简模型的训练数据的步骤,具体包括:获取所述原始BERT模型训练后的原始训练数据;调高所述原始BERT模型softmax层的温度参数,得到调高BERT模型;将所述原始训练数据输入至所述调高BERT模型进行预测操作,得到均值结果标签;基于标签信息在所述原始训练数据进行筛选操作,得到带标签的筛选结果标签;基于所述放大训练数据以及所述筛选训练数据选取所述精简模型训练数据。4.根据权利要求1所述的应用于BERT模型的蒸馏方法,其特征在于,在所述基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型的步骤之后还包括:在所述本地数据库中获取优化训练数据;将所述优化训练数据分别输入至所述训练好的原始BERT模型以及所述目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;基于搬土距离计算所述原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据;根据所述蒸馏损失数据对所述目标精简模型进行参数优化操作,得到优化精简模型。5.根据权利要求4所述的应用于BERT模型的蒸馏方法,其特征在于,所述基于搬土距离计算所述原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据的步骤,具体包括:获取所述原始transformer层输出的原始注意力矩阵以及所述目标transformer层输出的目标注意力矩阵;根据所述原始注意力矩阵以及所述目标注意力矩阵计算注意力EMD距离;获取所述原始transformer层输出的原始FFN隐层矩阵以及所述目标transformer层输
出的目标FFN隐层矩阵;根据所述原始FFN隐层矩阵以及所述目标FFN隐层矩阵计算FFN隐层EMD距...

【专利技术属性】
技术研发人员:朱桂良
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1