一种多任务模型的训练方法及装置制造方法及图纸

技术编号:34693523 阅读:10 留言:0更新日期:2022-08-27 16:28
本公开提供一种多任务模型的训练方法及装置,基于包括类别标签的第一训练数据集,对待训练的多任务模型进行训练,基于包括特征标签的第二训练数据集,对使用第一训练数据集训练后的多任务模型进行训练,基于包括解码标签的第三训练数据集,对使用第二训练数据集训练后的多任务模型进行训练,以得到包括分类网络、解码网络以及特征提取网络的多任务模型,其中,特征提取网络是分类网络的子网络,解码网络是特征提取网络的子网络,本公开提供的训练方法,通过对待训练的多任务模型进行分层训练,以实现提升多任务模型的扩展性的同时提高对多任务模型的训练效率。对多任务模型的训练效率。对多任务模型的训练效率。

【技术实现步骤摘要】
一种多任务模型的训练方法及装置


[0001]本公开涉及智能驾驶
,尤其涉及一种多任务模型的训练方法及装置。

技术介绍

[0002]随着人工智能领域的技术突破,高级驾驶辅助系统和无人驾驶技术有了飞速发展,感知算法作为高级驾驶辅助系统和无人驾驶技术的重要部分,是车辆实现安全驾驶的先决条件。车辆可通过如车载摄像头等车载传感器获取车辆当前环境信息,之后车载处理系统通过感知算法对环境信息进行处理提炼,供车辆控制系统做出相应决策。
[0003]感知算法大多依赖于多任务模型实现计算,多任务模型能并行处理的子任务越多,说明该多任务模型的扩展性越好,但是,多任务模型的扩展性越好,对该多任务模型进行训练的时间越长,因此,如何在提升多任务模型的扩展性的同时、提高对多任务模型的训练效率是亟需解决的问题。

技术实现思路

[0004]本公开的实施例提供了一种多任务模型的训练方法及装置,以解决现有技术中对多任务模型的训练效率低的问题。具体地,本公开实施例提供如下技术方案:
[0005]根据本公开的第一个方面,提供了一种多任务模型的训练方法,包括:
[0006]基于包括类别标签的第一训练数据集,对待训练的多任务模型进行训练,得到包括分类网络的多任务模型,所述分类网络用于对预设任务进行分类处理;
[0007]基于包括特征标签的第二训练数据集,对使用第一训练数据集训练后的所述多任务模型进行训练,得到包括特征提取网络的多任务模型,所述特征提取网络用于识别训练数据在预设任务类别中的特征标签,所述特征提取网络是所述分类网络的子网络;
[0008]基于包括解码标签的第三训练数据集,对使用第二训练数据集训练后的所述多任务模型进行训练,得到包括解码网络的多任务模型,所述解码网络用于对训练数据对应的特征标签进行解码,所述解码网络是所述特征提取网络的子网络。
[0009]根据本公开的第二个方面,提供了一种多任务模型的训练装置,包括:
[0010]第一训练模块:基于包括类别标签的第一训练数据集,对待训练的多任务模型进行训练,得到包括分类网络的多任务模型,所述分类网络用于对预设任务进行分类处理;
[0011]第二训练模块:基于包括特征标签的第二训练数据集,对第一训练模块训练后的所述多任务模型进行训练,得到包括特征提取网络的多任务模型,所述特征提取网络用于识别训练数据在预设任务类别中的特征标签,所述特征提取网络是所述分类网络的子网络;
[0012]第三训练模块:基于包括解码标签的第三训练数据集,对所述第二训练模块训练后的所述多任务模型进行训练,得到包括解码网络的多任务模型,所述解码网络用于对训练数据对应的特征标签进行解码,所述解码网络是所述特征提取网络的子网络。
[0013]根据本公开的第三方面,提供了一种计算机可读存储介质,所述存储介质存储有
计算机程序,所述计算机程序用于执行上述的多任务模型的训练方法。
[0014]根据本公开的第四方面,提供了一种电子设备,所述电子设备包括:
[0015]处理器;
[0016]用于存储所述处理器可执行指令的存储器;
[0017]所述处理器,用于从所述存储器中读取所述可执行指令,并执行所述指令以实现上述的多任务模型的训练方法。
[0018]本公开的多任务模型的训练方法,基于包括类别标签的第一训练数据集,对待训练的多任务模型进行训练,基于包括特征标签的第二训练数据集,对使用第一训练数据集训练后的多任务模型进行训练,基于包括解码标签的第三训练数据集,对使用第二训练数据集训练后的多任务模型进行训练,以得到包括分类网络、特征提取网络以及解码网络的多任务模型,其中,特征提取网络是分类网络的子网络,解码网络是特征提取网络的子网络,本公开提供的训练方法,通过对待训练的多任务模型进行分层训练,以实现提升多任务模型的扩展性的同时、提高对多任务模型的训练效率。
附图说明
[0019]通过结合附图对本公开实施例进行更详细的描述,本公开的上述以及其他目的、特征和优势将变得更加明显。附图用来提供对本公开实施例的进一步理解,并且构成说明书的一部分,与本公开实施例一起用于解释本公开,并不构成对本公开的限制。在附图中,相同的参考标号通常代表相同部件或步骤。
[0020]图1是本公开一示例性实施例提供的待训练的多任务模型的结构示意图。
[0021]图2是本公开一示例性实施例提供的多任务模型训练系统的结构示意图。
[0022]图3是本公开一示例性实施例提供的多任务模型的训练方法的流程示意图。
[0023]图4是本公开一示例性实施例提供的多任务模型的训练方法的流程示意图。
[0024]图5是本公开一示例性实施例提供的多任务模型的训练方法的流程示意图。
[0025]图6是本公开一示例性实施例提供的多任务模型的训练方法的流程示意图。
[0026]图7是本公开一示例性实施例提供的多任务模型训练装置的结构示意图。
[0027]图8是本公开一示例性实施例提供的第一训练模块的结构示意图。
[0028]图9是本公开一示例性实施例提供的第二训练模块的结构示意图。
[0029]图10是本公开一示例性实施例提供的第三训练模块的结构示意图。
[0030]图11是本公开一示例性实施例提供的多任务模型训练装置的信息交互图。
[0031]图12是本公开一示例性实施例提供的电子设备的结构图。
具体实施方式
[0032]下面,将参考附图详细地描述根据本公开的示例实施例。显然,所描述的实施例仅仅是本公开的一部分实施例,而不是本公开的全部实施例,应理解,本公开不受这里描述的示例实施例的限制。
[0033]应注意到:除非另外具体说明,否则在这些实施例中阐述的部件和步骤的相对布置、数字表达式和数值不限制本公开的范围。
[0034]本领域技术人员可以理解,本公开实施例中的“第一”、“第二”等术语仅用于区别
不同步骤、设备或模块等,既不代表任何特定技术含义,也不表示它们之间的必然逻辑顺序。
[0035]还应理解,在本公开实施例中,“多个”可以指两个或两个以上,“至少一个”可以指一个、两个或两个以上。
[0036]还应理解,对于本公开实施例中提及的任一部件、数据或结构,在没有明确限定或者在前后文给出相反启示的情况下,一般可以理解为一个或多个。
[0037]另外,本公开中术语“和/或”,仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。另外,本公开中字符“/”,一般表示前后关联对象是一种“或”的关系。
[0038]还应理解,本公开对各个实施例的描述着重强调各个实施例之间的不同之处,其相同或相似之处可以相互参考,为了简洁,不再一一赘述。
[0039]同时,应当明白,为了便于描述,附图中所示出的各个部分的尺寸并不是按照实际的比例关系绘制的。本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种多任务模型的训练方法,包括:基于包括类别标签的第一训练数据集,对待训练的多任务模型进行训练,得到包括分类网络的多任务模型,所述分类网络用于对预设任务进行分类处理;基于包括特征标签的第二训练数据集,对使用第一训练数据集训练后的所述多任务模型进行训练,得到包括特征提取网络的多任务模型,所述特征提取网络用于识别训练数据在预设任务类别中的特征标签,所述特征提取网络是所述分类网络的子网络;基于包括解码标签的第三训练数据集,对使用第二训练数据集训练后的所述多任务模型进行训练,得到包括解码网络的多任务模型,所述解码网络用于对训练数据对应的特征标签进行解码,所述解码网络是所述特征提取网络的子网络。2.根据权利要求1所述的方法,其中,所述第一训练数据集包括第一训练数据和所述第一训练数据对应的所述类别标签;基于包括类别标签的第一训练数据集,对待训练的多任务模型进行训练,包括:利用所述待训练的多任务模型对所述第一训练数据进行预测,得到所述第一训练数据对应的预测任务类别;根据所述预测任务类别和所述类别标签,确定第一损失值;根据所述第一损失值,调整所述分类网络对应的第一组参数。3.根据权利要求2所述的方法,其中,所述第二训练数据集包括第二训练数据和所述第二训练数据对应的特征标签;基于包括特征标签的第二训练数据集,对使用第一训练数据集训练后的所述多任务模型进行训练,包括:利用所述使用第一训练数据集训练后的所述多任务模型对所述第二训练数据进行预测,得到所述第二训练数据对应的预测特征信息;根据所述预测特征信息和所述特征标签,确定第二损失值;根据所述第二损失值,调整所述特征提取网络对应的第二组参数。4.根据权利要求3所述的方法,其中,所述第三训练数据集包括第三训练数据和所述第三训练数据对应的解码标签;基于包括解码标签的第三训练数据集,对使用第二训练数据集训练后的所述多任务模型进行训练,包括:利用所述使用第二训练数据集训练后的所述多任务模型对所述第三训练数据进行预测,得到所述第三训练数据对应的预测解码信息;根据所述预测解码信息和所述解码标签,确定第三损失值;根据所述第三损失值,调整所述解码网络对应的第三组参数。5.根据权利要求4所述的方法,还包括:确定所述第一损失值大于预设的第一阈值,基于所述第一损失值调整所述分类网络对应的第一组参数;确定所述第二损失值大于预设的第二阈值,基于所述第二损失值调整所述特征提取网络对应的第二组参数;确定所述第三损失值大于预设的第三阈值,基于所述第三损失值调整所述解码网络对应的第三组参数。
6.一种多任务模型的训练装置,包括:第一训练模块:基于包括类别标签的第一训练数据集,对待训...

【专利技术属性】
技术研发人员:杜敏
申请(专利权)人:北京地平线机器人技术研发有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1