模型训练方法、装置、设备、存储介质及程序产品制造方法及图纸

技术编号:39328410 阅读:7 留言:0更新日期:2023-11-12 16:05
本申请公开了一种模型训练方法、装置、设备、存储介质及程序产品,涉及人工智能领域。该方法包括:获取第一样本数据集和第二样本数据集;通过第一候选预测模型对第一数据进行数据预测,基于得到的第一预测结果和样本标签之间的差异确定第一损失值;通过第二候选预测模型对第二数据进行数据预测,基于得到的第二预测结果和伪标签之间的差异确定第二损失值;第一候选预测模型和第二候选预测模型之间共享特征提取网络;基于第一损失值和第二损失值对第一候选预测模型进行训练,得到数据预测模型。一定程度上提升数据预测模型对有偏差的伪标签的容忍度,从而提升数据预测模型的鲁棒性。可应用于云技术、人工智能、智慧交通、辅助驾驶等各种场景。等各种场景。等各种场景。

【技术实现步骤摘要】
模型训练方法、装置、设备、存储介质及程序产品


[0001]本申请实施例涉及人工智能领域,特别涉及一种模型训练方法、装置、设备、存储介质及程序产品。

技术介绍

[0002]在对数据识别的过程中,往往会为各类型的数据进行标签标注,该标签可以用来指示数据对应的内容、场景、类型等等,例如,为图像进行场景标注,示例性的,为图像a标注“风景”标签。可见,数据标签是理解各类型数据的一项重要技术。
[0003]相关技术中,通常是采用人工批注的方式为大批量的数据标注标签,而后采用深度学习模型进行训练来获取数据标签。
[0004]然而,此方案中,人工标注的方式标注成本昂贵,标注周期较长,且数据标签数量过多容易造成错标或者漏标,导致最终训练得到的深度学习模型鲁棒性差,一定程度上降低数据预测的准确率。

技术实现思路

[0005]本申请实施例提供了一种模型训练方法、装置、设备、存储介质及程序产品,提高对数据预测的准确率。技术方案如下。
[0006]一方面,提供了一种模型训练方法,该方法包括:
[0007]获取第一样本数据集和第二样本数据集,所述第一样本数据集中包括标注有样本标签的第一数据,所述第二样本数据集中包括标注有伪标签的第二数据,所述样本标签用于指示所述第一数据的数据内容,所述伪标签是通过预先训练得到的伪标签预测模型对所述第二数据进行内容预测得到的标签;
[0008]通过第一候选预测模型对所述第一数据进行数据预测,基于得到的第一预测结果和所述样本标签之间的差异确定第一损失值;/>[0009]通过第二候选预测模型对所述第二数据进行数据预测,基于得到的第二预测结果和所述伪标签之间的差异确定第二损失值;其中,所述第一候选预测模型和所述第二候选预测模型之间共享特征提取网络;
[0010]基于所述第一损失值和所述第二损失值对所述第一候选预测模型进行训练,得到数据预测模型。
[0011]另一方面,提供了一种模型训练装置,该装置包括:
[0012]获取模块,用于获取第一样本数据集和第二样本数据集,所述第一样本数据集中包括标注有样本标签的第一数据,所述第二样本数据集中包括标注有伪标签的第二数据,所述样本标签用于指示所述第一数据的数据内容,所述伪标签是通过预先训练得到的伪标签预测模型对所述第二数据进行内容预测得到的标签;
[0013]确定模块,用于通过第一候选预测模型对所述第一数据进行数据预测,基于得到的第一预测结果和所述样本标签之间的差异确定第一损失值;
[0014]所述确定模块,还用于通过第二候选预测模型对所述第二数据进行数据预测,基于得到的第二预测结果和所述伪标签之间的差异确定第二损失值;其中,所述第一候选预测模型和所述第二候选预测模型之间共享特征提取网络;
[0015]训练模块,用于基于所述第一损失值和所述第二损失值对所述第一候选预测模型进行训练,得到数据预测模型。
[0016]另一方面,提供了一种计算机设备,计算机设备包括处理器和存储器,存储器中存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行以实现如上述本申请实施例中任一所述模型训练方法。
[0017]另一方面,提供了一种计算机可读存储介质,存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行以实现如上述本申请实施例中任一所述的模型训练方法。
[0018]另一方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中,计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中任一所述的模型训练方法。
[0019]本申请实施例提供的技术方案带来的有益效果至少包括:
[0020]将对标注有样本标签的第一数据的处理过程,和对标注有伪标签的第二数据的处理过程进行解耦,两个处理过程共享一个特征提取网络,再分别利用不同的候选预测模型对各自输入的数据进行数据预测,得到第一数据对应的第一预测结果,和第二数据对应的第二预测结果。而后根据第一预测结果和样本标签之间的差异确定出的第一损失值,以及根据第二预测结果和伪标签之间的差异确定出的第二损失值,共同对第一候选预测模型进行训练,使第一候选预测模型在训练过程中能够学习到第二数据中伪标签的影响,一定程度上提升数据预测模型对有偏差的伪标签的容忍度,从而提升数据预测模型的鲁棒性。
附图说明
[0021]为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0022]图1是相关技术中对数据进行预测的流程示意图一个示例性实施例提供的实施环境示意图;
[0023]图2是本申请一个示例性实施例提供的模型训练方法的整体框架图;
[0024]图3是本申请一个示例性实施例提供的实施环境示意图;
[0025]图4是本申请一个示例性实施例提供的模型训练方法的流程图;
[0026]图5是本申请一个示例性实施例提供的数据处理的流程框图;
[0027]图6是基于图4提供的一个用于确定数据预测模型的流程图;
[0028]图7是基于图4提供的另一个用于确定数据预测模型的流程图;
[0029]图8是本申请另一个示例性实施例提供的模型训练方法的流程图;
[0030]图9是本申请一个示例性实施例提供的数据预测的结果示意图;
[0031]图10是本申请一个示例性实施例提供的模型训练装置的结构框图;
[0032]图11是本申请另一个示例性实施例提供的模型训练装置的结构框图;
[0033]图12是本申请一个示例性实施例提供的服务器的结构框图。
具体实施方式
[0034]为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
[0035]首先,针对本申请实施例中涉及的名词进行简单介绍。
[0036]样本数据集:是指用于对模型进行训练的数据,该样本数据包括但不限于图像数据、音频数据、视频数据、文本数据等。本申请实施例中,样本数据集包括第一样本数据集和第二样本数据集,其中,第一样本数据集中包括标注有样本标签的第一数据,第二样本数据中包括标注有伪标签的第二数据,样本标签和伪标签均用于表征样本数据的数据内容、场景信息、数据类型等。示例性的,当样本数据为图像数据时,样本标签和伪标签可以是用于表征图像数据对应的场景信息和/或图像内容。伪标签用于指示将未标注样本标签的数据输入伪标签预测模型后,伪标签预测模型为未标注样本标签的数据标注标识数据内容的标签信息。在本申请实施例中,第二数据的标签标注准确率低于第本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:获取第一样本数据集和第二样本数据集,所述第一样本数据集中包括标注有样本标签的第一数据,所述第二样本数据集中包括标注有伪标签的第二数据,所述样本标签用于指示所述第一数据的数据内容,所述伪标签是通过预先训练得到的伪标签预测模型对所述第二数据进行内容预测得到的标签;通过第一候选预测模型对所述第一数据进行数据预测,基于得到的第一预测结果和所述样本标签之间的差异确定第一损失值;通过第二候选预测模型对所述第二数据进行数据预测,基于得到的第二预测结果和所述伪标签之间的差异确定第二损失值;其中,所述第一候选预测模型和所述第二候选预测模型之间共享特征提取网络;基于所述第一损失值和所述第二损失值对所述第一候选预测模型进行训练,得到数据预测模型。2.根据权利要求1所述的方法,其特征在于,所述样本标签用于表示所述第一数据对应标签集的取值分布,所述伪标签用于表示所述第二数据对应所述标签集的取值分布,其中,所述取值分布中,标注有第一取值的标签表示与数据内容相关,标注有第二取值的标签表示与数据内容无关。3.根据权利要求2所述的方法,其特征在于,所述通过第一候选预测模型对所述第一数据进行数据预测,基于得到的第一预测结果和所述样本标签之间的差异确定第一损失值,包括:通过所述特征提取网络对所述第一数据进行特征提取,得到所述第一数据对应的第一特征表示;通过所述第一候选预测模型中的第一预测网络将所述第一特征表示与所述标签集进行匹配,得到所述第一特征表示对应所述标签集的第一取值分布作为所述第一预测结果;基于所述第一取值分布与所述样本标签之间的差异确定所述第一损失值;所述通过第二候选预测模型对所述第二数据进行数据预测,基于得到的第二预测结果和所述伪标签之间的差异确定第二损失值,包括:通过所述特征提取网络对所述第二数据进行特征提取,得到所述第二数据对应的第二特征表示;通过所述第二候选预测模型中的第二预测网络将所述第二特征表示与所述标签集进行匹配,得到所述第二特征表示对应所述标签集的第二取值分布作为所述第二预测结果;基于所述第二取值分布与所述伪标签之间的差异确定所述第二损失值。4.根据权利要求1至3任一所述的方法,其特征在于,所述基于所述第一损失值和所述第二损失值对所述第一候选预测模型进行训练,得到数据预测模型,包括:确定所述第一样本数据集中所述第一数据的第一数量;确定所述第二样本数据集中所述第二数据的第二数量,所述第二数量大于所述第一数量;基于所述第一数量对应的第一权重对所述第一损失值的加权处理结果,以及所述第二数量对应的第二权重对所述第二损失值的加权处理结果,得到预测损失值;基于所述预测损失值对所述第一候选预测模型进行训练,得到所述数据预测模型。
5.根据权利要求4所述的方法,其特征在于,所述第一损失值中包括第一数量的第一数据对应的损失和,所述第二损失值中包括第二数量的第二数据对应的损失和;所述基于所述第一数量对应的第一权重对所述第一损失值的加权处理结果,以及所述第二数量对应的第二权重对所述第二损失值的加权处理结果,得到预测损失值,包括...

【专利技术属性】
技术研发人员:高信凯谯睿智
申请(专利权)人:腾讯科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1