基于半监督知识蒸馏的行人重识别模型训练方法及装置制造方法及图纸

技术编号:36156337 阅读:58 留言:0更新日期:2022-12-31 20:02
本申请提供一种基于半监督知识蒸馏的行人重识别模型训练方法及装置。该方法包括:分别对教师模型和学生模型进行预训练,将有标签数据输入到训练后的教师模型和学生模型中,分别利用教师模型的预测值以及真实标签去监督学生模型的预测值,得到第一损失函数和第二损失函数;利用教师模型的预测值去监督学生模型中的无标签数据分类器输出的预测值,得到第三损失函数;利用对齐后的学生特征提取网络的特征图与教师特征提取网络的特征图计算得到第四损失函数;依据第一损失函数、第二损失函数、第三损失函数及第四损失函数,计算总损失函数;利用总损失函数对行人重识别模型进行更新和训练。本申请提高了行人重识别模型的泛化能力、精度以及识别效果。精度以及识别效果。精度以及识别效果。

【技术实现步骤摘要】
基于半监督知识蒸馏的行人重识别模型训练方法及装置


[0001]本申请涉及计算机
,尤其涉及一种基于半监督知识蒸馏的行人重识别模型训练方法及装置。

技术介绍

[0002]行人检测是利用计算机视觉技术识别图像或者视频流中是否存在行人并给予精确定位。该技术应用领域广泛,可与行人跟踪、行人重识别等技术结合,能够很好地应用于人工智能系统、车辆辅助驾驶系统、智能视频监控、人体行为分析、智能交通等现实场景领域。
[0003]目前行人重识别领域,在特定训练场景下已经取得非常高的性能,但由于不同场景之间的数据分布差异,导致模型在新的测试场景下性能明显下降。目前常用的无监督行人重识方法多数是基于伪标签的方法,主要思想是为无标记的数据产生高质量的伪标记来训练和更新神经网络,该方法思路简单清晰且性能良好。但是,现有的基于伪标记训练行人重识别模型的方法,仍存在模型泛化能力差,模型精度低,训练效果差的问题。

技术实现思路

[0004]有鉴于此,本申请实施例提供了一种基于半监督知识蒸馏的行人重识别模型训练方法及装置,以解决现有技术存在的模型泛化能力差,模型精度低,训练效果差的问题。
[0005]本申请实施例的第一方面,提供了一种基于半监督知识蒸馏的行人重识别模型训练方法,包括:获取预设场景下的有标签数据和无标签数据,利用有标签数据分别对教师模型和学生模型进行预训练,得到训练后的教师模型和学生模型;将有标签数据输入到训练后的教师模型和学生模型中,分别利用教师模型的预测值以及真实标签去监督学生模型的预测值,得到第一损失函数和第二损失函数;将无标签数据输入到训练后的教师模型和学生模型中,利用教师模型的预测值去监督学生模型中的无标签数据分类器输出的预测值,得到第三损失函数;将学生特征提取网络的特征图与教师特征提取网络的特征图进行对齐,并利用对齐后的学生特征提取网络的特征图与教师特征提取网络的特征图计算得到第四损失函数;依据第一损失函数、第二损失函数、第三损失函数以及第四损失函数,计算得到总损失函数;利用总损失函数对行人重识别模型的原损失函数进行更新,并利用训练集对更新损失函数后的行人重识别模型进行训练。
[0006]本申请实施例的第二方面,提供了一种基于半监督知识蒸馏的行人重识别模型训练装置,包括:预训练模块,被配置为获取预设场景下的有标签数据和无标签数据,利用有标签数据分别对教师模型和学生模型进行预训练,得到训练后的教师模型和学生模型;第一监督模块,被配置为将有标签数据输入到训练后的教师模型和学生模型中,分别利用教师模型的预测值以及真实标签去监督学生模型的预测值,得到第一损失函数和第二损失函数;第二监督模块,被配置为将无标签数据输入到训练后的教师模型和学生模型中,利用教师模型的预测值去监督学生模型中的无标签数据分类器输出的预测值,得到第三损失函
数;第三监督模块,被配置为将学生特征提取网络的特征图与教师特征提取网络的特征图进行对齐,并利用对齐后的学生特征提取网络的特征图与教师特征提取网络的特征图计算得到第四损失函数;计算模块,被配置为依据第一损失函数、第二损失函数、第三损失函数以及第四损失函数,计算得到总损失函数;训练模块,被配置为利用总损失函数对行人重识别模型的原损失函数进行更新,并利用训练集对更新损失函数后的行人重识别模型进行训练。
[0007]本申请实施例的第三方面,提供了一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行程序时实现上述方法的步骤。
[0008]本申请实施例采用的上述至少一个技术方案能够达到以下有益效果:
[0009]通过获取预设场景下的有标签数据和无标签数据,利用有标签数据分别对教师模型和学生模型进行预训练,得到训练后的教师模型和学生模型;将有标签数据输入到训练后的教师模型和学生模型中,分别利用教师模型的预测值以及真实标签去监督学生模型的预测值,得到第一损失函数和第二损失函数;将无标签数据输入到训练后的教师模型和学生模型中,利用教师模型的预测值去监督学生模型中的无标签数据分类器输出的预测值,得到第三损失函数;将学生特征提取网络的特征图与教师特征提取网络的特征图进行对齐,并利用对齐后的学生特征提取网络的特征图与教师特征提取网络的特征图计算得到第四损失函数;依据第一损失函数、第二损失函数、第三损失函数以及第四损失函数,计算得到总损失函数;利用总损失函数对行人重识别模型的原损失函数进行更新,并利用训练集对更新损失函数后的行人重识别模型进行训练。本申请提高了行人重识别模型对未知域的泛化能力,提高行人重识别模型的精度,提高模型的蒸馏效果。
附图说明
[0010]为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
[0011]图1是本申请实施例提供的基于半监督知识蒸馏的损失函数生成示意图;
[0012]图2是本申请实施例提供的基于半监督知识蒸馏的行人重识别模型训练方法的流程示意图;
[0013]图3是本申请实施例提供的基于半监督知识蒸馏的行人重识别模型训练装置的结构示意图;
[0014]图4是本申请实施例提供的电子设备的结构示意图。
具体实施方式
[0015]以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
[0016]如
技术介绍
所述,目前行人重识别领域,在特定训练场景下已经取得非常高的性
能,但由于不同场景之间的数据分布差异,导致模型在新的测试场景下性能明显下降。目前常用的无监督行人重识方法多数是基于伪标签的方法,主要思想是为无标记的数据产生高质量的伪标记来训练和更新神经网络,该方法思路简单清晰且性能良好。但是,现有的基于伪标记训练行人重识别模型的方法,仍存在模型泛化能力差,模型精度低,训练效果差的问题。因此,如何有效地利用新场景中的无标注数据,训练适用于新场景的行人重识别模型,是目前行人重识别领域亟需解决的问题之一。
[0017]有鉴于此,本申请实施例为解决上述问题,提供了一种基于半监督知识蒸馏的行人重识别模型训练方法,通过利用有标签数据和无标签数据训练泛化性能更好的ReID模型(行人重识别模型),为进一步提高模型的泛化能力,本申请采用基于蒸馏的方法生成伪标签,并对蒸馏方法做了优化,从而进一步提高模型精度。下面结合附图以及具体实施例对本申请技术方案的内容进行详细描述。
[0018]图1是本申请实施例提供的基于半监督知识蒸馏的损失函数生成示意图,如图1所示,该基于半监督知识蒸馏的损失函数生成过程具体可以包括:
[0019]首本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于半监督知识蒸馏的行人重识别模型训练方法,其特征在于,包括:获取预设场景下的有标签数据和无标签数据,利用所述有标签数据分别对教师模型和学生模型进行预训练,得到训练后的教师模型和学生模型;将所述有标签数据输入到训练后的所述教师模型和所述学生模型中,分别利用所述教师模型的预测值以及真实标签去监督所述学生模型的预测值,得到第一损失函数和第二损失函数;将所述无标签数据输入到训练后的所述教师模型和所述学生模型中,利用所述教师模型的预测值去监督所述学生模型中的无标签数据分类器输出的预测值,得到第三损失函数;将学生特征提取网络的特征图与教师特征提取网络的特征图进行对齐,并利用对齐后的学生特征提取网络的特征图与所述教师特征提取网络的特征图计算得到第四损失函数;依据所述第一损失函数、第二损失函数、第三损失函数以及第四损失函数,计算得到总损失函数;利用所述总损失函数对行人重识别模型的原损失函数进行更新,并利用训练集对更新损失函数后的行人重识别模型进行训练。2.根据权利要求1所述的方法,其特征在于,在所述得到训练后的教师模型和学生模型之后,所述方法还包括:将所述有标签数据和所述无标签数据分别输入到训练后的教师模型和学生模型中,利用训练后的教师模型中的教师特征提取网络提取特征图,并利用训练后的学生模型中的学生特征提取网络提取特征图。3.根据权利要求1所述的方法,其特征在于,利用所述教师模型的预测值去监督所述学生模型的预测值,得到第一损失函数,包括:将所述教师模型输出的预测值作为标签,利用所述标签与所述学生模型输出的预测值计算第一损失函数,其中,所述第一损失函数采用JS散度损失函数。4.根据权利要求1所述的方法,其特征在于,利用所述有标签数据对应的真实标签去监督所述学生模型的预测值,得到第二损失函数,包括:将所述有标签数据的真实标签与所述学生模型输出的预测值计算第二损失函数,其中,所述第二损失函数采用交叉熵损失函数。5.根据权利要求1所述的方法,其特征在于,所述利用所述教师模型的预测值去监督所述学生模型中的无标签数据分类器输出的预测值,得到第三损失函数,包括:在所述学生模型中生成一个与有标签数据分类器具有相同维度的无标签数据分类器,将所述教师模型输出的预测值作为标签,利用所述标签与所述无标签数据分类器输出的预测值计算第三损失函数,其中,所述第三损失函数采用JS散度损失函数。6.根据权利要求1所述的方法,其特征在于,所述将学生特征提取网络的特征图与教师特征提取网络的...

【专利技术属性】
技术研发人员:ꢀ七四专利代理机构
申请(专利权)人:深圳须弥云图空间科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1