分类模型的训练方法和装置制造方法及图纸

技术编号:39296733 阅读:11 留言:0更新日期:2023-11-07 11:04
本申请涉及人工智能技术领域,提供了一种分类模型的训练方法和装置。该方法包括:获取待训练数据的头部特征和尾部特征;根据头部特征中的类别信息,提取头部特征的同类特征对;根据头部特征的同类特征对,提取与类别无关的特征;利用非全局注意力机制,将与类别无关的特征与尾部特征融合,得到自适应增广特征;融合自适应增广特征和尾部特征,得到增广尾部特征;对头部特征、尾部特征和增广尾部特征进行特征提取后并进行分类预测,得到分类结果;根据分类损失和预设尾部增广处理损失,更新分类模型的参数。该方法能够提升分类模型的分类性能。本发明专利技术实施例可应用于云技术、人工智能、智慧交通、辅助驾驶等各种场景。辅助驾驶等各种场景。辅助驾驶等各种场景。

【技术实现步骤摘要】
分类模型的训练方法和装置


[0001]本申请涉及计算机视觉技术和人工智能
,特别是涉及一种分类模型的训练方法、装置、计算机设备、存储介质和计算机程序产品。

技术介绍

[0002]随着人工智能技术的快速发展,人工智能被广泛应用在各行各业。以人工智能在图像处理上的应用为例,利用人工智能进行机器学习训练分类模型,能够提高图像分类的效率和精度。
[0003]其中,分类模型的精度还受限于训练样本,训练样本越多,标注结果越精确,训练得到的分类模型的预测精度也越精确。然而,在实际应用中,训练样本通常表现为长尾类分布。位于头部的一小部分类别含有较多数量的样本,剩下的类别含有较少数量的样本。样本数据据的长尾分布会,会使得模型在数量较多的样本上学习效果更好,而在数量较小的样本上学习效果更差,从而降低了模型的泛化性能,对模型性能造成影响。

技术实现思路

[0004]基于此,有必要针对上述技术问题,提供一种能够提高模型性能的分类模型的训练方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
[0005]第一方面,本申请提供了一种分类模型的训练方法。所述方法包括:
[0006]获取待训练数据的头部特征和尾部特征;
[0007]根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
[0008]根据所述头部特征的同类特征对,提取与类别无关的特征;
[0009]利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
[0010]融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
[0011]对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
[0012]根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
[0013]第二方面,本申请还提供了一种分类模型的训练装置。所述装置包括:
[0014]特征获取模块,用于获取待训练数据的头部特征和尾部特征;
[0015]同类特征对获取模块,用于根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
[0016]特征提取模块,用于根据所述头部特征的同类特征对,提取与类别无关的特征;
[0017]自适应增广模块,用于利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
[0018]尾部增广模块,用于融合所述自适应增广特征和所述尾部特征,得到增广尾部特
征;
[0019]分类预测模块,用于对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
[0020]调整模块,用于根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
[0021]第三方面,本申请还提供了一种计算机设备。所述计算机设备包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
[0022]获取待训练数据的头部特征和尾部特征;
[0023]根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
[0024]根据所述头部特征的同类特征对,提取与类别无关的特征;
[0025]利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
[0026]融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
[0027]对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
[0028]根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
[0029]第四方面,本申请还提供了一种计算机可读存储介质。所述计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
[0030]获取待训练数据的头部特征和尾部特征;
[0031]根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
[0032]根据所述头部特征的同类特征对,提取与类别无关的特征;
[0033]利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
[0034]融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
[0035]对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
[0036]根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。
[0037]第五方面,本申请还提供了一种计算机程序产品。所述计算机程序产品,包括计
[0038]获取待训练数据的头部特征和尾部特征;
[0039]根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;
[0040]根据所述头部特征的同类特征对,提取与类别无关的特征;
[0041]利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;
[0042]融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;
[0043]对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;
[0044]根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分
类模型收敛时,得到训练好的分类模型。
[0045]上述分类模型的训练方法、装置、计算机设备、存储介质和计算机程序产品,通过从头部特征的同类特征中,提取与类别无关的特征,引入非全局注意力机制,将与类别无关的特征与尾部特征融合,得到自适应增广特征,能够使得与类别无关特征和在不同空间位置上尾部特征,实现更为精细地自适应融合,提升了与类别无关特征与尾部特征的适配性,使得增广尾部特征与真实尾部数据相符合,进而将增广尾部特征与尾部特征融合,得到增广尾部特征。该方法使得在分类模型的训练阶段有效扩增了尾部特征空间,从而能够提升分类模型的分类性能。
附图说明
[0046]图1为一个实施例中分类模型的训练方法的应用环境图;
[0047]图2为一个实施例中分类模型的训练方法的流程示意图;
[0048]图3为一个实施例中分类模型的结构示意图;
[0049]图4为一个实施例中分类模型的工作说明示意图;
[0050]图5为一个实施例中增广处理层的结构示意图;
[0051]图6为一个实施例中循环重构损失的说明示意图;
[0052]图7为一个实施例中分类模型的训练装置的结构框图;
[0053]图8为一个实施例中计算机设备的内部结构图。
具体实施方式
[0054]为了使本申请的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种分类模型的训练方法,其特征在于,所述方法包括:获取待训练数据的头部特征和尾部特征;根据所述头部特征中的类别信息,提取所述头部特征的同类特征对;根据所述头部特征的同类特征对,提取与类别无关的特征;利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征;融合所述自适应增广特征和所述尾部特征,得到增广尾部特征;对所述头部特征、所述尾部特征和所述增广尾部特征进行特征提取后并进行分类预测,得到分类结果;根据分类损失和预设尾部增广处理损失,更新所述分类模型的参数,直到所述分类模型收敛时,得到训练好的分类模型。2.根据权利要求1所述的方法,其特征在于,所述利用非全局注意力机制,将所述与类别无关的特征与所述尾部特征融合,得到自适应增广特征,包括:根据所述与类别无关的特征与所述尾部特征的相似度,得到全局相似度特征;利用所述全局相似度特征对所述尾部特征进行注意,得到所述增广尾部特征。3.根据权利要求1所述的方法,其特征在于,所述预设尾部增广处理损失包括不同类别特征之间的对比损失;所述方法还包括:计算不同类别特征之间的相似度,得到不同类别特征的对比损失,所述对比损失的约束目标为拉近同类特征之间的距离,推远不同类特征之间的距离。4.根据权利要求1所述的方法,其特征在于,所述预设尾部增广处理损失包括增广尾部特征的类别损失;所述方法还包括:根据所述增广尾部特征与同源的尾部特征的类别的差异,计算增广尾部特征的类别损失,所述增广尾部特征的类别损失的约束目标为使增广尾部特征与同源的尾部特征的类别相同。5.根据权利要求1、3或4所述的方法,其特征在于,所述预设尾部增广处理损失包括增广处理的循环重构损失;所述方法还包括:对所述尾部特征中重构得到重构的自适应增广特征;对所述重构的自适应增广特征重构得到重构的尾部特征;根据所述自适应增广特征和所述重构的自适应增广特征的差异,以及根据所述尾部特征和所述重构的尾部特征之间的差异,计算得到所述重构损失;所述重构损失的约束目标为能够从所述尾部特征中重构中所述自适应增广特征,从重构得到的自适应增广特征重构出初始的尾部特征。6.根据权利要求5所述的方法,其特征在于,所述预设尾部增广处理损失包括增广处理的模式寻找损失;所述方法还包括:根据不同类别尾部特征间的差异与不同类别增广尾部特征间的差异的比值,以及不同类别头部特征之间的差异与不同类别增广尾部特征间的差异的比值,计算增广处理的模式寻找损失;
所述增广处理的模式寻找损失的约束目...

【专利技术属性】
技术研发人员:洪燕孙众毅鄢科
申请(专利权)人:腾讯科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1