一种缓解多任务学习中任务冲突方法、装置及存储介质制造方法及图纸

技术编号:30137144 阅读:16 留言:0更新日期:2021-09-23 14:49
本发明专利技术公开了一种缓解多任务学习中任务冲突的方法、装置及存储介质,所述方法获取待缓解多任务学习模型中各个学习任务的梯度值;判断各学习任务中选定学习任务与其余各学习任务之间是否存在任务冲突,在判定存在任务冲突时对选定学习任务的梯度值进行修剪,并将选定学习任务的梯度值更新为修剪后的梯度值,在选定学习任务的梯度值更新执行完毕后重新选定一学习任务作为选定学习任务重复执行梯度值更新,直至模型中所有学习任务的梯度值更新完毕,计算所有学习任务完成梯度值更新后的梯度值的平均值,获得平均梯度值,根据平均梯度值对模型的网络参数进行更新。通过实施本发明专利技术能在实现缓解任务冲突同时保持各个任务训练均衡。均衡。均衡。

【技术实现步骤摘要】
一种缓解多任务学习中任务冲突方法、装置及存储介质


[0001]本专利技术涉及计算机
,尤其涉及一种缓解多任务学习中任务冲突方法、装置及存储介质。

技术介绍

[0002]深度学习在各个领域已经取得了不错的成绩,但是目前的人工智能依赖于海量数据的训练,模型泛化能力不佳,在有限数据领域下的效果和快速拓展到新任务的能力都不尽人意。针对这个问题,一些研究者提出多任务学习方法(MTL,Multi

Task Learning)来解决这个问题。多任务学习方法能够联合多个任务一起学习,一些数据有限的任务能够利用其他任务共享的信息进行训练,从而提高任务的表现。
[0003]基于优化的多任务学习方法是现有的多任务学习方法中的一种;而现有基于优化的多任务学习方法中,当任务梯度发生冲突或者任务被其他较大梯度的任务支配时,通常是通过设计不同的策略来调整各个任务loss的权重,对于训练比较快的任务,降低其权重,减少模型对其的关注程度,让模型多关注那些没有训练充分的任务,从而实现缓解任务训练的不平衡。但是通过调整各任务loss权重来平衡任务训练,会导致某些任务被其他任务所支配,从而得不到充分训练,无法实现各个任务训练均衡,进而降低了模型的整体性能。

技术实现思路

[0004]本专利技术实施例提供一种缓解多任务学习中任务冲突的方法及装置,能在实现多任务模型中各个任务训练均衡的同时缓解任务冲突。
[0005]本专利技术一实施例提供一种缓解多任务学习中任务冲突的方法,包括:
[0006]获取待缓解多任务学习模型中各个学习任务的梯度值;
[0007]任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;
[0008]在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕;
[0009]计算所有所述学习任务完成梯度值更新后的梯度值的平均值,获得平均梯度,根据所述平均梯度对所述待缓解多任务学习模型的网络参数进行更新。
[0010]进一步的,所述获取待缓解多任务学习模型中各个学习任务的梯度值,具体包括:计算每一所述学习任务的损失值,继而根据每一所述学习任务的损失值计算每一所述学习任务对所述待缓解多任务学习模型中网络参数的偏导数,获得每一所述学习任务的梯度值。
[0011]进一步的,所述对所述选定学习任务的梯度值进行梯度修剪,具体包括:
[0012]根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面;其中,所述冲突学习任务为与所述选定学习任务存在任务冲突的学习任务;
[0013]分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值;
[0014]根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值。
[0015]进一步的,所述根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面,具体包括:
[0016]通过以下公式确定所述选定学习任务与所述冲突学习任务的冲突平面:
[0017]P
γ
=g
i

g
j

[0018]其中,g
i
为所述选定学习任务的梯度值,g
j
为所述冲突学习任务的梯度值。
[0019]进一步的,所述分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值,具体包括:
[0020]通过以下公式计算所述选定学习任务以及所述冲突学习任务与所述冲突平面的夹角的余弦值;
[0021][0022]通过以下公式计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量:
[0023]Δg
i
=g
i
·
cosφ
i
,Δg
j
=g
j
·
cosφ
j

[0024]通过以下公式计算所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值:
[0025]Δg
{i,j}
=||g
i
·
cosφ
i

g
j
·
cosφ
j
||;
[0026]其中,cosφ
i
为所述选定学习任务与所述冲突平面的夹角的余弦值,cosφ
j
为所述冲突学习任务与所述冲突平面的夹角的余弦值,Δg
i
为所述选定学习任务在所述冲突平面上的梯度分量,Δg
j
为所述冲突学习任务在所述冲突平面上的梯度分量。
[0027]进一步的,所述根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值,具体包括:根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值
[0028]在上述方法项实施例的基础上,本专利技术对应提供了装置项实施例;
[0029]本专利技术一实施例提供了一种缓解多任务学习中任务冲突的装置,包括梯度值获取模块、学习任务梯度值更新模块以及模型参数值更新模块;
[0030]所述梯度值获取模块,用于获取待缓解多任务学习模型中各个学习任务的梯度值;
[0031]所述学习任务梯度值更新模块,用于任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习
任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;以及,
[0032]在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕;
[0033]所述模型参数值更新模块,用于计算所有所述学习任务完成梯度值更新后的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种缓解多任务学习中任务冲突的方法,其特征在于,包括:获取待缓解多任务学习模型中各个学习任务的梯度值;任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕;计算所有所述学习任务完成梯度值更新后的梯度值的平均值,获得平均梯度,根据所述平均梯度对所述待缓解多任务学习模型的网络参数进行更新。2.如权利要求1所述的缓解多任务学习中任务冲突的方法,其特征在于,所述获取待缓解多任务学习模型中各个学习任务的梯度值,具体包括:计算每一所述学习任务的损失值,继而根据每一所述学习任务的损失值计算每一所述学习任务对所述待缓解多任务学习模型中网络参数的偏导数,获得每一所述学习任务的梯度值。3.如权利要求2所述的缓解多任务学习中任务冲突的方法,其特征在于,所述对所述选定学习任务的梯度值进行梯度修剪,具体包括:根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面;其中,所述冲突学习任务为与所述选定学习任务存在任务冲突的学习任务;分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值;根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值。4.如权利要求3所述的缓解多任务学习中任务冲突的方法,其特征在于,所述根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面,具体包括:通过以下公式确定所述选定学习任务与所述冲突学习任务的冲突平面:P
γ
=g
i

g
j
;其中,g
i
为所述选定学习任务的梯度值,g
j
为所述冲突学习任务的梯度值。5.如权利要求4所述的缓解多任务学习中任务冲突的方法,其特征在于,所述分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值,具体包括:通过以下公式计算所述选定学习任务以及所述冲突学习任务与所述冲突平面的夹角...

【专利技术属性】
技术研发人员:廖清柴合言王晔漆舒汉王轩
申请(专利权)人:哈尔滨工业大学深圳
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1