当前位置: 首页 > 专利查询>中山大学专利>正文

一种基于元学习的对抗采样训练方法及装置制造方法及图纸

技术编号:28424849 阅读:27 留言:0更新日期:2021-05-11 18:33
本发明专利技术公开了一种基于元学习的对抗采样训练方法及装置,所述方法:根据策略网络从K个语种构成的大任务集T中输出K维概率向量

【技术实现步骤摘要】
一种基于元学习的对抗采样训练方法及装置
本专利技术涉及语音识别
,尤其涉及一种基于元学习的对抗采样训练方法及装置。
技术介绍
随着深度学习理论和相关技术的蓬勃发展,语音识别领域取得了巨大的进展。然而构造一个端到端的深层语音识别模型经常需要大量的有标注的数据,而这些数据对于许多低资源语种是非常难以获取的。为了解决上述问题,有许多工作利用无监督预训练和半监督学习的方法去利用大量无标注数据来帮助低资源目标语种提升识别效果,但是这些方法依然需要大量目标语种的无标注数据,对于部分小语种来说,无标注数据也是很少量的。因此,迁移学习被引入解决低资源语种识别问题,迁移学习通过其他语种的数据来帮助目标低资源语种来提升识别效果。同时还有多语种迁移学习方法用多个其他源语种预训练模型初始化参数,于是只需要少量低资源目标语种数据在预训练模型基础上训练就可以得到较好的目标模型。但迁移学习的方法学习到的模型参数比较容易倾向于源语种而无法很好地进行迁移。除此之外,元学习的方法也被引入到低资源语音识别问题中。元学习的方法通过一系列训练任务来元学习得到模型初始化参数,以便能够快速地适应到只有少量数据的新任务上,这种方法十分适用于低资源的场景。然而,现有应用低资源语音识别和低资源语码转换的语音识别中,都忽略了在真实场景中任务不均衡的问题,现有技术平等地利用每个语种的元信息,从而导致了效果的损失。
技术实现思路
本专利技术目的在于,提供一种基于元学习的对抗采样训练方法及装置,以多语种元学习语音识别框架为基础,引入策略网络形成对抗训练,提升多语种低资源语音识别训练的效果。为实现上述目的,本专利技术实施例提供一种基于元学习的对抗采样训练方法,包括:根据策略网络从K个语种构成的大任务集T中输出K维概率向量其中,为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数。优选地,所述根据策略网络从K个语种构成的大任务集T中输出K维概率向量其中,为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集,包括:所述策略网络包括前馈注意力层和LSTM层;所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。优选地,所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数,包括:每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1。优选地,所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数,还包括:将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量并用于对所述初始化参数θ寻优,获得最优的模型参数。本专利技术实施例还提供了一种基于元学习的对抗采样训练装置,包括:训练模块,根据策略网络从K个语种构成的大任务集T中输出K维概率向量其中,为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;第一更新模块,所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数第二更新模块,所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数。优选地,所述训练模块,包括:所述策略网络包括前馈注意力层和LSTM层;所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。优选地,所述第二更新模块,包括:每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1。优选地,所述第二更新模块,还包括:将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量并用于对所述初始化参数θ寻优,获得最优的模型参数。本专利技术实施例还提供一种计算机终端设备,包括一个或多个处理器和存储器。存储器与所述处理器耦接,用于存储一个或多个程序;当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现如上述任一实施例所述的基于元学习的对抗采样训练方法。本专利技术实施例还提供一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如上述任一实施例所述的基于元学习的对抗采样训练方法。本专利技术实施例在元学习的基础上引入策略网络,策略网络由注意力机制与LSTM层融合而成,通过策略网络输出概率向量与元学习生成的查询损失向量进行更新迭代,寻找最优的模型参数,每一次训练任务中,语音识别网络朝着查询损失值尽可能小的方向优化,而策略网络则是朝着尽可能大的方向去优化,形成对抗训练,促进了语音识别网络的有效训练。附图说明为了更清楚地说明本专利技术的技术方案,下面将对实施方式中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本专利技术的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。图1是本专利技术某一实施例提供的基于元学习的对抗采样训练方法的流程示意图;图2是本专利技术另一实施例提供的基于元学习的对抗采样训练方法的流程示意图;图3是本专利技术某一实施例提供的基于元学习的本文档来自技高网...

【技术保护点】
1.一种基于元学习的对抗采样训练方法,其特征在于,包括:/n根据策略网络从K个语种构成的大任务集T中输出K维概率向量

【技术特征摘要】
1.一种基于元学习的对抗采样训练方法,其特征在于,包括:
根据策略网络从K个语种构成的大任务集T中输出K维概率向量其中,为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集;
所述支持集对语音识别模型初始化参数θ进行梯度下降得到更新参数
所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数。


2.根据权利要求1所述的基于元学习的对抗采样训练方法,其特征在于,所述根据策略网络从K个语种构成的大任务集T中输出K维概率向量其中,为第i个语种任务集对应的采样概率,根据所述采样概率选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,将所述训练任务集划分为支持集和查询集,包括:
所述策略网络包括前馈注意力层和LSTM层;
所述策略网络通过所述LSTM层中存储的长短期记忆信息和当前查询损失向量获取采样任务,其中,所述采样任务根据所述采样概率获得所述训练任务集。


3.根据权利要求1所述的基于元学习的对抗采样训练方法,其特征在于,所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数,包括:
每一次训练获取当前训练步的查询损失向量与概率向量,将所述查询损失向量与所述概率向量输入下一次训练的策略网络,将所述查询损失向量与所述概率向量合并计算前馈注意力,所述前馈注意力通过全连接层输出cs+1。


4.根据权利要求3所述的基于元学习的对抗采样训练方法,其特征在于,所述查询集根据查询所述更新参数的效果获得查询损失向量所述查询损失向量用于对所述初始化参数θ寻优,获得最优的模型参数,还包括:
将所述全连接层输出的cs+1与上一次训练中LSTM层的隐藏状态hs作为输入当前LSTM层的值,生成所述当前LSTM层的输出ys+1与当前隐藏状态hs+1,基于所述当前LSTM层的输出值通过全连接层与Softmax函数预测当前的概率向量根据所述概率向量,选取前M个概率最大的语种,根据所述M个概率最大语种中每个语种采样一个任务构成训练任务集,获取所述查询损失向量并用于对所述初始化参数θ寻优,获得最优的模型参数。


5.一种基于元学习的对抗采样训...

【专利技术属性】
技术研发人员:肖雨蓓郑国林聂琳梁小丹王青林倞
申请(专利权)人:中山大学
类型:发明
国别省市:广东;44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1