基于任务的焦点损失提升多语言元学习语音识别方法技术

技术编号:35104063 阅读:21 留言:0更新日期:2022-10-01 17:13
本发明专利技术提供一种基于任务的焦点损失提升多语言元学习语音识别方法。该方法基于任务的焦点损失改进多语言元学习对任务不平衡的忽略,基于每个任务的查询损失引入了难任务调节器,引导模型更加关注难任务,并且为了充分利用难任务的数据,同时使用支持集梯度与查询集梯度来更新元参数。此外,本发明专利技术还在样本层面解释了难任务调节器的意义,经过公式推导,发现它与任务内样本的预测概率乘积成反相关。通过使用本发明专利技术方法,可以使模型学习到的初始化更加均衡,更加充分地利用了所有源语言的知识,从而能够有效的对目标语言进行泛化。从而能够有效的对目标语言进行泛化。从而能够有效的对目标语言进行泛化。

【技术实现步骤摘要】
基于任务的焦点损失提升多语言元学习语音识别方法


[0001]本专利技术涉及语音识别
,尤其涉及一种基于任务的焦点损失提升多语言元学习语音识别方法。

技术介绍

[0002]多语言预训练中,当从不同的源语言中学习时,不同种类的语言有不同的训练数据规模和不同的语音系统,造成任务难度的不平衡。这样一来,模型的初始化就倾向于接近大尺度和容易的语言,而偏离了小尺度和困难的语言,使得模型学习到的初始化不是最佳的,从而降低了对目标语言的泛化性。

技术实现思路

[0003]针对多语言语音识别中存在因任务难度不平衡造成模型泛化性较低的问题,本专利技术提供一种基于任务的焦点损失提升多语言元学习语音识别方法,引导模型专注于那些困难的任务,从而使模型的初始化在各种语言之间更加平衡。
[0004]本专利技术提供的基于任务的焦点损失提升多语言元学习语音识别方法,采用端到端的语音识别网络架构,具体包括:
[0005]步骤1:初始化语音识别模型f
θ
,输入原始语音特征序列(x1,x2,...,x
T
);
[0006]步骤2:针对从多语言数据集中抽取的任务T
i
,将所述任务T
i
分为支持集和查询集表示第i种语言数据;
[0007]步骤3:计算任务T
i
的ASR损失,使用梯度下降得到在支持集上更新后的参数θ
i

[0008]步骤4:使用在支持集上更新后的参数θ
i
在查询集上计算查询损失
[0009]步骤5:根据任务T
i
的查询损失计算得到任务T
i
的难任务调节器M
TFL
(θ),所述难任务调节器M
TFL
(θ)用于表示任务T
i
的学习难度等级;其中,查询损失越大,则对应的难任务调节器M
TFL
(θ)越大;
[0010]步骤6:重复N次步骤2至步骤5,计算得到N个任务对应的查询损失和难任务调节器;
[0011]步骤7:基于所有N个任务对应的查询损失和难任务调节器计算得到基于任务的焦点损失L
TFL

[0012]步骤8:使用所述焦点损失L
TFL
更新语音识别模型f
θ
的元参数θ;
[0013]步骤9:重复步骤2至步骤8,直至更新后的语音识别模型f
θ
满足给定要求。进一步地,步骤5中,难任务调节器M
TFL
(θ)的计算公式为:
[0014][0015]其中,k≥0和γ≥0为可调超参数。
[0016]进一步地,步骤7中,基于任务的焦点损失L
TFL
的计算公式为:
[0017][0018]其中,是基础学习器的损失函数。
[0019]进一步地,步骤8中,元参数θ的更新公式为:
[0020][0021]其中,β表示学习率。
[0022]进一步地,所述端到端的语音识别网络架构具体采用CTC

注意力联合架构;
[0023]对应的,步骤3中,任务T
i
的ASR损失的计算公式为:
[0024]L=λL
ctc
+(1

λ)L
att
[0025]其中,L
ctc
为CTC损失,L
att
为解码损失,超参数λ表示L
ctc
的权重。
[0026]本专利技术的有益效果:
[0027]目前的多语言元学习(MML

ASR)方法中忽略了任务之间的不平衡,导致模型初始化倾向于简单且数据量大的语言而远离困难且数据量小的语言,从而对新任务的泛化能力下降,如图4(a)所示。本专利技术提出的基于任务的焦点损失改进多语言元学习对任务不平衡的忽略,基于每个任务的查询损失引入了难任务调节器,引导模型更加关注难任务,并且为了充分利用难任务的数据,同时使用支持集梯度与查询集梯度来更新元参数。此外,本专利技术还在样本层面解释了难任务调节器的意义,经过公式推导,发现它与任务内样本的预测概率乘积成反相关。通过使用本专利技术方法,可以使模型学习到的初始化更加均衡,如图4(b)所示,更加充分地利用了所有源语言的知识,从而能够有效的对目标语言进行泛化。
附图说明
[0028]图1为本专利技术实施例提供的基于任务的焦点损失提升多语言元学习语音识别方法的流程示意图;
[0029]图2为本专利技术实施例提供的MML

ASR和TFL

MML

ASR的示意图;
[0030]图3为本专利技术实施例提供的基于CTC

注意力联合架构(joint CTC

Attention)多语言学习模型结构示意图;
[0031]图4为本专利技术实施例提供的使用基于任务的焦点损失后的损失与原损失的关系图;
[0032]图5为本专利技术实施例提供的在IPARA BABEL数据集和OpenSLR数据集上的多语言预训练验证集字符错误率(%);
[0033]图6为本专利技术实施例提供的在25%越南语下不同方法微调曲线;
[0034]图7为本专利技术实施例提供的不同参数下TFL

MML

ASR的验证集字符错误率;
[0035]图8为本专利技术实施例提供的在IPARA BABEL数据集和OpenSLR数据集上的多语言预训练在有无支持集梯度下的验证集字符错误率(%)。
具体实施方式
[0036]为使本专利技术的目的、技术方案和优点更加清楚,下面将结合本专利技术实施例中的附
图,对本专利技术实施例中的技术方案进行清楚地描述,显然,所描述的实施例是本专利技术一部分实施例,而不是全部的实施例。基于本专利技术中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本专利技术保护的范围。
[0037]在像MAML这样的元学习器中,当一个任务的查询损失很大时,意味着模型不能很好地学习这个任务,本专利技术认为这样的任务是一个难任务。而对于这些难任务,本专利技术认为最好同时使用该任务的支持集梯度和查询集梯度来学习它,这样做的原因在于:本专利技术认为有两个原因导致任务的查询损失较大,一个原因是模型在支持集上学习得不好,另一个原因是尽管在支持集上训练得比较好,但模型在查询集上表现得很差。对于第一个原因,引入支持集梯度显然是有用的。对于第二个原因,任务的查询损失较大意味着任务的难度;此时,最好使用更多的数据来再次训练,特别是在很少的实例的情况下。因此,引入支持集梯度来更新难任务的元参数是很有好处的。特别是,我们希望使模型更专注于难任务,这就意味着需要提高有大查询损失的任务的权重。
[0038]基于上述内容,本专利技术构思主要包括两个方面:一方面在外循环的查询损失之本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.基于任务的焦点损失提升多语言元学习语音识别方法,其特征在于,所述方法采用端到端的语音识别网络架构,具体包括:步骤1:初始化语音识别模型f
θ
,输入原始语音特征序列(x1,x2,...,x
T
);步骤2:针对从多语言数据集中抽取的任务T
i
,将所述任务T
i
分为支持集和查询集和查询集表示第i种语言数据;步骤3:计算任务T
i
的ASR损失,使用梯度下降得到在支持集上更新后的参数θ
i
;步骤4:使用在支持集上更新后的参数θ
i
在查询集上计算查询损失步骤5:根据任务T
i
的查询损失计算得到任务T
i
的难任务调节器M
TFL
(θ),所述难任务调节器M
TFL
(θ)用于表示任务T
i
的学习难度等级;其中,查询损失越大,则对应的难任务调节器M
TFL
(θ)越大;步骤6:重复N次步骤2至步骤5,计算得到N个任务对应的查询损失和难任务调节器;步骤7:基于所有N个任务对应的查询损失和难任务调节器计算得到基于任务的焦点损失L
TFL
;步骤8:使用所述焦点损失L
...

【专利技术属性】
技术研发人员:屈丹陈雅淇杨绪魁张文林张昊陈琦李静涛
申请(专利权)人:郑州信大先进技术研究院
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1