一种基于级联温控蒸馏的目标关键点检测方法技术

技术编号:34542825 阅读:15 留言:0更新日期:2022-08-13 21:39
本发明专利技术公开了一种基于级联温控蒸馏的目标关键点检测方法,图像处理技术领域,其包括以下步骤:1)采集多张训练集图像,给所有训练集图像标注关键点信息;2)构建教师网络模型和学生网络模型;3)将训练集图像输入教师网络模型,将标注关键点信息的训练集图像输入学生网络模型,联合教师网络模型,采用级联温控蒸馏技术训练用于目标关键点检测的学生网络模型;4)基于训练好的学生网络模型对待检测图像进行目标关键点检测。本发明专利技术利用了高维空间特征向量和低维空间特征向量的两种蒸馏损失,通过温度T的变化,互相约束和促进,引导学生模型训练寻找最优解集,其中高维特征蒸馏能降低学生模型的过拟合,并增强鲁棒性。并增强鲁棒性。并增强鲁棒性。

【技术实现步骤摘要】
一种基于级联温控蒸馏的目标关键点检测方法


[0001]本专利技术涉及图像处理
,尤其涉及一种基于级联温控蒸馏的目标关键点检测方法。

技术介绍

[0002]关键点检测是一种在任意图像中自动搜索定义好的特征点的位置的技术。该项技术需要对目标物体建立模型系统,然后将待检测图像输入到建立的模型系统中,输出目标物体在待检测图像中的位置。
[0003]传统的图像关键点检测方法如申请公开号为CN110674714A的中国专利申请公开的一种快速人脸和人脸关键点联合检测方法,包括如下步骤:步骤1,构建教师网络和学生网络;步骤2,输入一批训练图像,进行数据增强;步骤3,根据自适应尺度匹配策略,划分正负锚点框样本;步骤4,挖掘正负样本,计算多任务损失函数,更新网络参数;步骤5,转至步骤2,直至训练收敛,得到教师网络模型;步骤6,重复步骤2到步骤5,利用教师网络模型,加入迁移学习损失函数,训练得到学生网络模型;步骤7,在测试阶段,输入测试图像到学生网络模型,得到检测结果;该方法获得的学生模型复杂度高、精度低。

技术实现思路

[0004]本专利技术提供了一种基于级联温控蒸馏的目标关键点检测方法,以解决传统图像关键点检测方法存在的训练模型复杂度高、精度低的问题。
[0005]为了解决上述技术问题,本专利技术提供的技术方案为:本专利技术一种基于级联温控蒸馏的目标关键点检测方法,其包括以下步骤:S1.采集多张训练集图像,给所有训练集图像标注关键点信息;S2.构建教师网络模型和学生网络模型;S3.将训练集图像输入教师网络模型,将标注关键点信息的训练集图像输入学生网络模型中,联合教师网络模型,采用级联温控蒸馏技术训练用于目标关键点检测的学生网络模型,其具体步骤为:S3.1.计算教师网络模型全连接层1和学生网络模型全连接层1的余弦蒸馏损失;S3.2.计算教师网络模型全连接层2和学生网络模型全连接层2之间的KL蒸馏损失;S3.3.基于学生网络模型的回归损失、余弦蒸馏损失和KL蒸馏损失计算学生网络模型的总损失函数;S3.4.以学生网络模型的总损失函数为基础,引导学生网络模型进行训练;S4.基于训练好的学生网络模型对待检测图像进行目标关键点检测。
[0006]优选地,所述步骤S2构建教师网络模型包括以下步骤:基于原始的mobilefacenet网络,在GDC层后设置3层全连接层,其中全连接层1的输出维度为256,全连接层2的输出维度设置为64,全连接层3的输出维度为14,选择smooth L1 loss作为损失函数;
所述步骤S2构建学生网络模型包括以下步骤:基于原始的Onet网络,在GDC层后设置3层全连接层,其中全连接层1的输出维度为256,全连接层2的输出维度为64,全连接层3的输出维度为14,选择smooth L1 loss作为回归损失函数。
[0007]构建教师网络模型和学生网络模型时,对教师网络模型和学生网路模型进行了修改,使卷积层特征向量从512维,分2梯度降到14维,2梯度分别为:256,64,进而使教师网络模型和学生网络模型可以分解出高维特征和低维特征。
[0008]优选地,所述步骤S3.1中教师网络模型全连接层1和学生网络模型全连接层1的余弦蒸馏损失的计算公式为:公式中,学生网络模型全连接层1和教师网络模型全连接层1之间的余弦蒸馏损失,分别表示学生网络模型和教师网络模型的输出向量, 表示学生网络模型和教师网络模型输出的向量点成,表示学生网络模型和教师网络模型的向量的模的乘积。
[0009]优选地,所述步骤S3.2中教师网络模型全连接层2和学生网络模型全连接层2之间的KL蒸馏损失的计算公式为:公式中,表示学生网络模型全连接层2和教师网络模型全连接层2之间的KL蒸馏损失,表示教师网络模型全连接层2和学生网络模型全连接层2的KL散度,表示归一化指数函数,表示教师网络模型的第i维特征向量值,表示学生网络模型的第i维特征向量值, j表示训练时的迭代次数,表示第j次迭代的蒸馏温度,n表示输出向量的维度,i表示向量索引。
[0010]优选地,所述教师网络模型全连接层2和学生网络模型全连接层2的KL散度的计算公式为:公式中,表示KL散度,表示教师网络模型的向量,表示学生网络模型的向量,和分别表示教师网络模型和学生网络模型输出的第i维特征向量值,表示教师网络模型输出向量的正则化项,表示学生网络模型输出向量的正则化项。
[0011]优选地,所述的教师网络模型输出向量的正则化项和学生网络模型输出向量的
正则化项均为。
[0012]优选地,所述的第j次迭代的蒸馏温度的表达式为:公式中,第j

1次训练迭代时的蒸馏温度;、分别表示第j次迭代时的余弦蒸馏损失和KL蒸馏损失。
[0013]优选地,所述的学生网络模型的总损失函数的计算公式为:公式中,表示学生网络模型总损失;a表示回归损失的比例,b表示KL蒸馏损失的比例,c表示余弦蒸馏损的比例;表示回归损失。
[0014]优选地,所述的回归损失的比例为0.5, KL蒸馏损失和余弦蒸馏损失的比例为0.25。
[0015]采用本专利技术提供的技术方案,与现有技术相比,具有如下有益效果:1、本专利技术涉及的基于级联温控蒸馏的目标关键点检测方法采用级联温控蒸馏技术训练学生网络模型,该方法分别利用了高维空间特征向量和低维空间特征向量的两种蒸馏损失,且处于相连的全连接层,通过温度T的变化,互相约束和促进,引导学生模型训练寻找最优解集。
[0016]2、本专利技术涉及的基于级联温控蒸馏的目标关键点检测方法通过高维特征蒸馏降低学生模型的过拟合,并增强鲁棒性;低维特征蒸馏能促进学生模型向教师模型分布学习,达到知识迁移的目的。
[0017]3、本专利技术涉及的基于级联温控蒸馏的目标关键点检测方法通过级联结构能够使得高维特征和低维特征相互约束,同频学习,能加快学生模型收敛。
[0018]4、本专利技术涉及的基于级联温控蒸馏的目标关键点检测方法通过温度控制使得第1级蒸馏损失的变化方向会指导第2级蒸馏温度T,从而影响蒸馏的剧烈程度,使的学生模型更快找到全局最优解集。
附图说明
[0019]图1为教师网络模型的测试集关键点误差损失;图2为级联温控蒸馏指导训练的学生模型的测试集关键点误差损失;图3为常见的知识蒸馏指导训练的学生模型的测试集关键点误差损失。
具体实施方式
[0020]为进一步了解本专利技术的内容,结合实施例对本专利技术作详细描述,以下实施例用于
说明本专利技术,但不用来限制本专利技术的范围。
[0021]实施例:本专利技术涉及的一种基于级联温控蒸馏的目标关键点检测方法包括以下步骤:S1.采集多张训练集图像,给所有训练集图像标注关键点信息。
[0022]S2.构建教师网络模型和学生网络模型:构建教师网络模型包括以下步骤:基于原始的mobilefacenet网络,在GDC层后设置3层全连接层,其中全连接层1的输出维度为256,全连接层2的输出维度设置为64,全连接层3的输出维度为14,选择smooth本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于级联温控蒸馏的目标关键点检测方法,其特征在于:其包括以下步骤:S1.采集多张训练集图像,给所有训练集图像标注关键点信息;S2.构建教师网络模型和学生网络模型;S3.将训练集图像输入教师网络模型,将标注关键点信息的训练集图像输入学生网络模型,联合教师网络模型,采用级联温控蒸馏技术训练用于目标关键点检测的学生网络模型,其具体步骤为:S3.1.计算教师网络模型全连接层1和学生网络模型全连接层1的余弦蒸馏损失;S3.2.计算教师网络模型全连接层2和学生网络模型全连接层2之间的KL蒸馏损失;S3.3.基于学生网络模型的回归损失、余弦蒸馏损失和KL蒸馏损失计算学生网络模型的总损失函数;S3.4.以学生网络模型的总损失函数为基础,引导学生网络模型进行训练;S4.基于训练好的学生网络模型对待检测图像进行目标关键点检测。2.根据权利要求1所述的基于级联温控蒸馏的目标关键点检测方法,其特征在于:所述步骤S2构建教师网络模型包括以下步骤:基于原始的mobilefacenet网络,在GDC层后设置3层全连接层,其中全连接层1的输出维度为256,全连接层2的输出维度设置为64,全连接层3的输出维度为14,选择smooth L1 loss作为损失函数;所述步骤S2构建学生网络模型包括以下步骤:基于原始的Onet网络,在GDC层后设置3层全连接层,其中全连接层1的输出维度为256,全连接层2的输出维度为64,全连接层3的输出维度为14,选择smooth L1 loss作为回归损失函数。3.根据权利要求1所述的基于级联温控蒸馏的目标关键点检测方法,其特征在于:所述步骤S3.1中教师网络模型全连接层1和学生网络模型全连接层1的余弦蒸馏损失的计算公式为:公式中,表示学生网络模型全连接层1和教师网络模型全连接层1之间的余弦蒸馏损失,分别表示学生网络模型和教师网络模型的输出向量,表示学生网络模型和教师网络模型输出的向量点成,表示学生网络模型和教师网络模型...

【专利技术属性】
技术研发人员:朱晓芳李学双赵国栋
申请(专利权)人:山东圣点世纪科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1