分类模型的训练方法、装置、设备、存储介质及程序产品制造方法及图纸

技术编号:39261509 阅读:9 留言:0更新日期:2023-10-30 12:14
本申请提供了一种分类模型的训练方法、装置,包括:获取用于对象分类的初始分类模型,初始分类模型基于包括多个携带类别标签的第一对象样本的样本集合训练得到,多个第一对象样本的类别标签符合长尾分布;通过初始分类模型,对未携带类别标签的第二对象样本进行分类,得到分类结果;当分类结果表征第二对象样本属于长尾分布的尾部类别时,基于分类结果,对第二对象样本进行标注,得到携带初始伪标签的第二对象样本;对初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本;基于携带目标伪标签的第二对象样本,对初始分类模型进行训练,得到目标分类模型。通过本申请,能够提高分类模型在长尾分布场景下的分类准确性和鲁棒性。鲁棒性。鲁棒性。

【技术实现步骤摘要】
分类模型的训练方法、装置、设备、存储介质及程序产品


[0001]本申请涉及人工智能技术,尤其涉及一种分类模型的训练方法、装置、电子设备、计算机可读存储介质以及计算机程序产品。

技术介绍

[0002]人工智能(Artificial Intelligence,AI)是计算机科学的一个综合技术,通过研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。人工智能技术是一门综合学科,涉及领域广泛,例如自然语言处理技术以及机器学习/深度学习等几大方向,随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
[0003]相关技术中,面向开放集的长尾分布学习对于计算机视觉、自然语言处理等领域均具有极为重要的意义。传统的深度学习系统聚焦于在类别平衡的闭合集上设计学习算法,而真实世界由于样本难以获取、人工标注成本高等问题,其面对的任务往往是长尾分布,而面向长尾分布的相关分类模型未能充分结合现实场景中大量可用的无标签样本来训练分类模型,使得分类模型在长尾分布场景下的分类结果的准确性低。

技术实现思路

[0004]本申请实施例提供一种分类模型的训练方法、装置、电子设备、计算机可读存储介质以及计算机程序产品,能够提高分类模型在长尾分布场景下的分类准确性和鲁棒性。
[0005]本申请实施例的技术方案是这样实现的:
[0006]本申请实施例提供一种分类模型的训练方法,包括:
[0007]获取用于对象分类的初始分类模型,所述初始分类模型基于包括多个第一对象样本的样本集合训练得到,其中,所述第一对象样本携带类别标签,且多个第一对象样本的类别标签符合长尾分布;
[0008]通过所述初始分类模型,对未携带类别标签的第二对象样本进行分类,得到分类结果;
[0009]当所述分类结果表征所述第二对象样本属于所述长尾分布的尾部类别时,基于所述分类结果,对所述第二对象样本进行标注,得到携带初始伪标签的第二对象样本;
[0010]对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本;
[0011]基于携带所述目标伪标签的第二对象样本,对所述初始分类模型进行训练,得到目标分类模型。
[0012]本申请实施例提供一种分类模型的训练装置,包括:
[0013]获取模块,用于获取用于对象分类的初始分类模型,所述初始分类模型基于包括多个第一对象样本的样本集合训练得到,其中,所述第一对象样本携带类别标签,且多个第一对象样本的类别标签符合长尾分布;
[0014]分类模块,用于通过所述初始分类模型,对未携带类别标签的第二对象样本进行
分类,得到分类结果;
[0015]当所述分类结果表征所述第二对象样本属于所述长尾分布的尾部类别时,基于所述分类结果,对所述第二对象样本进行标注,得到携带初始伪标签的第二对象样本;
[0016]纠正模块,用于对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本;
[0017]训练模块,用于基于携带所述目标伪标签的第二对象样本,对所述初始分类模型进行训练,得到目标分类模型。
[0018]本申请实施例提供一种电子设备,包括:
[0019]存储器,用于存储可执行指令;
[0020]处理器,用于执行所述存储器中存储的可执行指令时,实现本申请实施例提供的分类模型的训练方法。
[0021]本申请实施例提供一种计算机可读存储介质,其中存储有计算机可执行指令,当计算机可执行指令被处理器执行时,将引起处理器执行本申请实施例提供的分类模型的训练方法。
[0022]本申请实施例提供了一种计算机程序产品,该计算机程序产品包括计算机程序或计算机可执行指令,该计算机程序或计算机可执行指令存储在计算机可读存储介质中,电子设备的处理器从计算机可读存储介质读取该计算机可执行指令,处理器执行该计算机可执行指令,使得该电子设备执行本申请实施例提供的分类模型的训练方法。
[0023]本申请实施例具有以下有益效果:
[0024]应用本申请实施例,通过携带类别标签且类别标签符合长尾分布的第一对象样本训练得到初始分类模型,如此,能够提高分类模型针对符合长尾分布的头部类别的样本的分类准确性;同时,通过初始分类模型对处于长尾分布尾部类别、且未携带类别标签的第二对象样本添加初始伪标签,并通过对初始伪标签的迭代纠正操作,得到携带目标伪标签的第二对象样本,如此,能够实现针对长尾分布的尾部类别的对象样本的数据增强;最后,采用携带目标伪标签的第二对象样本对初始分类模型进行再次训练,得到目标分类模型,如此,能够通过对有类别标签的第一对象样本和无类别标签的第二对象样本的分布再平衡操作,实现针对初始分类模型的联合训练,提高分类模型在长尾分布场景下的分类准确性和鲁棒性。
附图说明
[0025]图1是本申请实施例提供的分类模型的训练系统100的架构示意图;
[0026]图2是本申请实施例提供的实施分类模型的训练方法的电子设备500的结构示意图;
[0027]图3是本申请实施例提供的分类模型的训练方法的流程示意图;
[0028]图4是本申请实施例提供的初始分类模型的结构示意图;
[0029]图5是本申请实施例提供的初始分类模型的分类流程图;
[0030]图6是本申请实施例提供的标签纠正方法流程示意图;
[0031]图7是本申请实施例提供的标签纠正模型的训练方式示意图;
[0032]图8是本申请实施例提供的初始分类模型的训练方式流程图;
[0033]图9是本申请实施例提供的基于损失函数更新分类模型的方法示意图;
[0034]图10是本申请实施例提供的针对待分类对象的分类方法示意图;
[0035]图11是本申请实施例提供的类别子空间的构建方法示意图;
[0036]图12是本申请实施例提供的确定类别子空间的方式示意图;
[0037]图13是本申请实施例提供的类别子空间对应的投影空间的确定方式示意图;
[0038]图14是本申请实施例提供的投影空间确定方式示意图;
[0039]图15是本申请实施例提供的确定待分类对象相对于每个类别子空间的距离的方法示意图;
[0040]图16是本申请实施例提供的基于半监督学习的分类模型的训练算法的代码示意图;
[0041]图17是本申请实施例提供的构造类别子空间的方法示意图。
具体实施方式
[0042]为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步地详细描述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
[0043]在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解,“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
[0044]在本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种分类模型的训练方法,其特征在于,所述方法包括:获取用于对象分类的初始分类模型,所述初始分类模型基于包括多个第一对象样本的样本集合训练得到,其中,所述第一对象样本携带类别标签,且多个第一对象样本的类别标签符合长尾分布;通过所述初始分类模型,对未携带类别标签的第二对象样本进行分类,得到分类结果;当所述分类结果表征所述第二对象样本属于所述长尾分布的尾部类别时,基于所述分类结果,对所述第二对象样本进行标注,得到携带初始伪标签的第二对象样本;对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本;基于携带所述目标伪标签的第二对象样本,对所述初始分类模型进行训练,得到目标分类模型。2.如权利要求1所述的方法,其特征在于,所述对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本,包括:获取第二对象样本的初始特征、及所述初始伪标签的标签特征;对所述初始特征与所述标签特征进行特征拼接,得到所述第二对象样本的拼接特征;基于所述拼接特征,对所述第二对象样本进行标签转换,得到携带目标伪标签的第二对象样本。3.如权利要求2所述的方法,其特征在于,所述方法还包括:对所述拼接特征进行降维处理,得到降维后的拼接特征,所述降维后的拼接特征的维度与所述标签特征的维度相同;所述基于所述拼接特征,对所述第二对象样本进行标签转换,得到携带目标伪标签的第二对象样本,包括:对所述降维后的拼接特征进行非线性变换,得到所述目标伪标签;将所述目标伪标签替换所述初始伪标签,得到携带目标伪标签的第二对象样本。4.如权利要求1所述的方法,其特征在于,所述初始分类模型包括特征提取层、特征分类层、标签纠正层;所述通过所述初始分类模型,对未携带类别标签的第二对象样本进行分类,得到分类结果,包括:通过所述特征提取层,对未携带类别标签的第二对象样本进行特征提取,得到所述第二样本的初始特征;通过所述特征分类层,基于所述初始特征,对所述第二对象样本进行分类,得到所述第二对象样本的分类结果,所述分类结果用于指示所述第二对象样本归属的所述长尾分布中的类别;所述对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本,包括:通过所述标签纠正层,对所述初始伪标签进行标签纠正,得到携带目标伪标签的所述第二对象样本。5.如权利要求1所述的方法,其特征在于,所述第二对象样本的数量为多个,所述基于携带所述目标伪标签的第二对象样本,对所述初始分类模型进行训练,得到目标分类模型,包括:通过所述初始分类模型,获取所述样本集合中的第一对象样本的第一预测结果,并基于所述第一预测结果与所述第一对象样本的类别标签之间的差异,确定第一损失函数的
值;通过所述初始分类模型,获取所述第二对象样本的第二预测结果,并基于所述第二预测结果与所述第二对象样本的目标伪标签之间的差异,确定第二损失函数的值;结合所述第一损失函数的值与所述第二损失函数的值,更新所述初始分类模型的模型参数,得到目标分类模型。6.如权利要求5所述的方法,其特征在于,所述结合所述第一损失函数的值与所述第二损失函数的值,更新所述初始分类模型的模型参数,得到目标分类模型,包括:获取所述初始分类模型对应的用于更新模型参数的学习率;对所述第一损失函数的值与所述第二损失函数的值进行求和处理,并对求和结果与所述学习率进行求积处理,得到求积结果;基于所述求积结果,更新所述初始分类模型的模型参数,得到目标分类模型。7.如权利要求1所述的方法,其特征在于,所述对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本,包括:获取标签纠正模型,并通过所述标签纠正模型,对所述初始伪标签进行标签纠正,得到携带目标伪标签的第二对象样本;在获取所述标签纠正模型之前,所述方法还包括:通过所述初始分类模型,对携带类别标签的第三对象样本进行分类,得到第一分类结果,并基于所述第一分类结果与所述第三对象样本的类别标签,确定第三损失函数的值;获取初始标签纠正模型,并确定所述第三损失函数的值相对于所述初始标签纠正模型的梯度;基于所述梯度,更新所述初始标签纠正模型的模型参数,得到所述标签纠正模型。8.如权利要求1所述的方法,其特征在于,所述基于携带所述目标伪标签的第二对象样本,对所述初始分类模型进行训练,得到目标分类模型之后,所述方法还包括:获取对象库中的对象所对应的至少两个类别,并构建每个所述类别对应的类别子空间,每个所述类别子空间包括至少两个对象,所述对象库中包括携带类别标签的第一对象,以及携带目标伪标签的第二对象;获取待分类对象,并通过所述目标分类模型,获取所述待分类对象的初始特征;确定每个所述类别子空间所指示类别的平均特征;基于所述初始特征以及每个所...

【专利技术属性】
技术研发人员:李蓝青周志鹏
申请(专利权)人:腾讯科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1