一种目标检测模型训练方法、目标检测方法及系统技术方案

技术编号:39249813 阅读:10 留言:0更新日期:2023-10-30 12:01
本发明专利技术公开了一种目标检测模型训练方法、目标检测方法及系统,包括:将有标签数据集中的数据输入学生网络模型,获得损失SL

【技术实现步骤摘要】
一种目标检测模型训练方法、目标检测方法及系统


[0001]本专利技术涉及机器学习
,特别涉及一种目标检测模型训练方法、目标检测方法及系统。

技术介绍

[0002]现有半监督的目标检测算法中,主要有伪标签、蒸馏等方法。对于伪标签算法,主要靠对无标签数据生成伪标签,通过学习标签信息以提升目标检测效果。但是,生成伪标签的原始模型无法实时更新。而对于目前的蒸馏算法,虽然可以实时更新模型,但是,由于监督中的正负样本选择问题,大部分研究都集中在二阶段的目标检测算法,比如基于fastercnn的蒸馏算法。现有阶段基本很少关注一阶段的YOLOX模型,没有研究如何去融合无标签数据以进行蒸馏,进而提升对无标签数据中的目标检测效果。

技术实现思路

[0003]为解决现有技术中无标签数据目标检测效果差的不足,本专利技术提供了一种目标检测模型训练方法、目标检测方法及系统,能够大大提升无标签数据的目标检测效果。
[0004]为解决上述技术问题,本专利技术所采用的技术方案是:
[0005]一种目标检测模型训练方法,其特点是包括以下步骤:
[0006]步骤1,获取有标签数据集A和无标签数据集B;
[0007]步骤2,将有标签数据集A中的数据A
a
输入学生网络模型,所述学生网络模型包括学生网络特征提取模块和学生网络预测模块,学生网络特征提取模块对应数据A
a
的损失为SL
A

[0008]步骤3,将无标签数据集B中的数据B
b<br/>分别输入学生网络模型和教师网络模型,得到学生网络模型的特征Feature
stdent
、教师网络模型的特征Feature
teacher
、学生网络模型的目标预测结果O
stdent
、教师网络模型的目标预测结果O
teacher
;其中,所述教师网络模型包括教师网络特征提取模块和教师网络预测模块;
[0009]步骤4,基于Feature
stdent
和Feature
teacher
获得损失函数L
feature
,所述教师网络特征提取模块基于L
feature
指导训练学生网络特征提取模块;
[0010]步骤5,基于O
stdent
和O
teacher
获得对应数据B
b
的损失SL
B

[0011]步骤6,基于SL
A
和SL
B
,获得总损失L
total
;所述教师网络预测模块基于L
total
指导训练学生网络预测模块。
[0012]进一步地,目标检测模型训练方法还包括:
[0013]步骤7,基于学生网络模型的参数,更新教师网络模型的参数。
[0014]作为一种优选方式,所述学生网络模型和教师网络模型均为YOLOX模型。
[0015]作为一种优选方式,所述步骤2中,所述SL
A
基于有标签数据对应的前后背景损失SL
obj
、类别预测损失SL
cls
和位置预测损失SL
loc
获得;
[0016]所述步骤3中,所述O
stdent
包括学生网络模型对应数据B
b
的位置预测结果Loc
sttdent

类别预测结果CLS
stdent
、前后背景信息OBJ
stdent
;所述O
teacher
包括教师网络模型对应数据B
b
的位置预测结果Loc
teacher
、类别预测结果CLS
teacher
、前后背景信息OBJ
teacher

[0017]所述步骤5中,通过双阈值对比过滤挖掘方法,获取有效的监督信息,所述基于O
stdent
和O
teacher
获得对应数据B
b
的损失SL
B
包括:
[0018]设定第一阈值α和第二阈值β,其中α>β;
[0019]基于OBJ
teacher
和CLS
teacher
获得样本置信度CONF
teacher

[0020]将CONF
teacher
<β对应的无标签数据设置为背景,将CONF
teacher
>α对应的无标签数据设置为前景;对于教师网络预测结果好于学生网络的部分,才会给学生网络监督。
[0021]基于预设值r、蒸馏温度T、α、β、O
stdent
、O
teacher
,确定损失SL
B
;其中,SL
B
包括无标签数据对应的前后背景损失TL
obj
、类别预测损失TL
cls
和位置预测损失TL
loc

[0022]作为一种优选方式,总损失L
total
=SL
A
+θ(TL
loc
+TL
boj
+TL
cls
),其中,θ为预设的第一比例调节系数。
[0023]作为一种优选方式,所述步骤7中,根据公式T
params
=T
params
γ+S
params
(1

γ)更新教师网络模型的参数,其中,T
params
为更新获得的教师网络模型的参数,S
params
为学生网络模型的参数,γ为预设的第二比例调节系数。
[0024]作为一种优选方式,所述无标签数据集B中的无标签数据B
b
包括原始采集获得的无标签数据和/或由原始采集获得的无标签数据通过图片增广扩展处理获得的无标签数据。
[0025]基于同一个专利技术构思,本专利技术还提供了一种目标检测方法,其特点是利用经由所述的目标检测模型训练方法训练获得的学生网络模型对待检测数据进行目标检测。
[0026]基于同一个专利技术构思,本专利技术还提供了一种目标检测系统,其特点是包括待检测数据获取单元和经由所述的目标检测模型训练方法训练获得的学生网络模型,其中:
[0027]待检测数据获取单元:用于获得待检测数据;
[0028]学生网络模型:用于对待检测数据识别以输出目标检测结果。
[0029]与现有技术相比,本专利技术通过组合特征蒸馏,能够提高蒸馏算法一阶段检测的特征提取能力,从而大大提升无标签数据的目标检测效果,大大降低人工标注成本;通过在蒸馏算法一阶段检测模型的密集样本中挖掘难样本做监督,提升模型的半监督学习能力本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种目标检测模型训练方法,其特征在于,包括以下步骤:步骤1,获取有标签数据集A和无标签数据集B;步骤2,将有标签数据集A中的数据A
a
输入学生网络模型,所述学生网络模型包括学生网络特征提取模块和学生网络预测模块,学生网络特征提取模块对应数据A
a
的损失为SL
A
;步骤3,将无标签数据集B中的数据B
b
分别输入学生网络模型和教师网络模型,得到学生网络模型的特征Feature
stdent
、教师网络模型的特征Feature
teacher
、学生网络模型的目标预测结果O
stdent
、教师网络模型的目标预测结果O
teacher
;其中,所述教师网络模型包括教师网络特征提取模块和教师网络预测模块;步骤4,基于Feature
stdent
和Feature
teacher
获得损失函数L
feature
,所述教师网络特征提取模块基于L
feature
指导训练学生网络特征提取模块;步骤5,基于O
stdent
和O
teacher
获得对应数据B
b
的损失SL
B
;步骤6,基于SL
A
和SL
B
,获得总损失L
total
;所述教师网络预测模块基于L
total
指导训练学生网络预测模块。2.根据权利要求1所述的目标检测模型训练方法,其特征在于,还包括:步骤7,基于学生网络模型的参数,更新教师网络模型的参数。3.根据权利要求1或2所述的目标检测模型训练方法,其特征在于,所述学生网络模型和教师网络模型均为YOLOX模型。4.根据权利要求3所述的目标检测模型训练方法,其特征在于,所述步骤2中,所述SL
A
基于有标签数据对应的前后背景损失SL
obj
、类别预测损失SL
cls
和位置预测损失SL
loc
获得;所述步骤3中,所述O
stdent
包括学生网络模型对应数据B
b
的位置预测结果Loc
sttdent
、类别预测结果CLS
stdent
、前后背景信息OBJ
stdent
;所述O
teacher
包括教师网络模型对应数据B
b
的位置预测结果Loc
teacher
、类别预测...

【专利技术属性】
技术研发人员:夏威冀春锟黄金
申请(专利权)人:湖南视比特机器人有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1