一种基于知识蒸馏的目标检测方法和装置制造方法及图纸

技术编号:33919593 阅读:20 留言:0更新日期:2022-06-25 20:49
本发明专利技术公开了一种基于知识蒸馏的目标检测方法和装置,涉及计算机技术领域。该方法的一具体实施方式包括:利用样本图像集训练教师网络,得到目标教师网络,引入损失函数,使用可微分的分组搜索方法,逐组搜索目标教师网络的每个卷积组的聚合特征权重;使用聚合特征权重,从目标教师网络的每个卷积组中提取出相应聚合特征;将提取到的聚合特征作为知识,对学生网络进行聚合特征蒸馏,获得目标学生网络,进而将待检测的图像输入到目标学生网络中,以检测出待检测的图像中的目标对象和所处位置。该实施方式通过耦合来自教师网络不同层的特征,提取出教师网络中更多的知识,计算开销较少,提高了学生网络在检测目标对象时的准确度。度。度。

【技术实现步骤摘要】
一种基于知识蒸馏的目标检测方法和装置


[0001]本专利技术涉及计算机
,尤其涉及一种基于知识蒸馏的目标检测方法和装置。

技术介绍

[0002]在自动驾驶领域,激光雷达由于其稳定的性能且不易受外界环境的影响,一直是自动驾驶汽车的主要传感器之一,各种目标检测方案不断提出,但现有的方法在精度和速度方面都欠佳。
[0003]知识蒸馏指将知识从繁杂的大型神经网络转移到轻量化神经网络的过程。在知识蒸馏中,繁杂的大型神经网络被称为教师网络(teacher network),具有比小型模型更强的知识表达能力,但同时由于参数数量巨大,无法被算力有限的设备所应用;轻量化神经网络被称为学生网络(student network),网络表达能力有限,但因为其参数量和计算量小,可以被应用于广泛的计算设备中。
[0004]知识蒸馏的目标就是将教师模型中的深层知识以某种形式提取出来,并用于学生模型的训练,从而使得学生网络在具有较少的参数数量和计算量的同时,获得比肩教师模型的泛化能力。在实现本专利技术过程中,专利技术人发现现有知识蒸馏模型至少存在如下问题:
[0005]1)教师网络和学生网络的模型容量不匹配,导致学生网络无法拟合教师网络得到的知识;
[0006]2)选择教师网络的哪些信息作为知识,对于学生网络的训练非常重要,而知识的选择往往来源于专家的经验,无法自动设计。

技术实现思路

[0007]有鉴于此,本专利技术实施例提供一种基于知识蒸馏的目标检测方法和装置,至少能够解决现有技术中依赖专家经验选择教师网络中的知识、以及学生网络无法拟合教师网络知识的现象。
[0008]为实现上述目的,根据本专利技术实施例的一个方面,提供了一种基于知识蒸馏的目标检测方法,包括:
[0009]利用样本图像集训练教师网络,得到目标教师网络,引入损失函数,使用可微分的分组搜索方法,逐组搜索所述目标教师网络的每个卷积组的聚合特征权重;其中,样本图像中的对象和所处位置已标注;
[0010]使用所述聚合特征权重,从所述目标教师网络的每个卷积组中提取出相应聚合特征;
[0011]将提取到的聚合特征作为知识,对学生网络进行聚合特征蒸馏,获得目标学生网络,进而将待检测的图像输入到所述目标学生网络中,以检测出所述待检测的图像中的目标对象和所处位置。
[0012]可选的,所述引入损失函数,使用可微分的分组搜索方法,逐组搜索所述目标教师
网络的每个卷积组的聚合特征权重,包括:
[0013]将学生网络的模型参数作为底层变量,在训练集上以最小化训练损失函数为目标,利用梯度下降方式更新所述模型参数,得到新模型参数;
[0014]将所述目标教师网络的聚合特征权重的辅助参数作为上层变量,在新模型参数的基础上,在验证集上以最小化验证损失函数为目标,利用梯度下降方式更新所述辅助参数,得到目标辅助参数;
[0015]使用所述目标辅助参数计算得到聚合特征权重。
[0016]可选的,所述引入损失函数,包括:
[0017]将学生网络的原始信息分为两条路径,以计算从所述目标教师网络到学生网络的路径的第一损失函数、以及计算从学生网络到所述目标教师网络的路径的第二损失函数;
[0018]使用第一超参数和第二超参数,对所述第一损失函数和所述第二损失函数进行加权求和,得到总的桥式损失函数;其中,超参数用于调整第一或第二损失函数对桥式损失函数的影响;
[0019]将所述总的桥式损失函数,作为学生网络在训练集和验证集上的损失函数的选择。
[0020]可选的,所述计算从所述目标教师网络到学生网络的路径的第一损失函数,包括:
[0021]从所述目标教师网络的任一卷积组中提取出聚合特征,获取所述任一卷积组所处层数,以确定学生网络中位于所述层数的下一卷积组;
[0022]将聚合特征输入到所述下一卷积组并继续向前传播,直至传播到学生网络的最后一层卷积组;
[0023]利用最后一层卷积组输出的特征分布与真实标签,计算交叉熵损失函数,进而计算得到从所述目标教师网络到学生网络的路径的第一损失函数。
[0024]可选的,所述计算从学生网络到所述目标教师网络的路径的第二损失函数,包括:
[0025]从学生网络到所述目标教师网络的路径中,特征信息流以学生网络为起点,通过学生网络的前预设数量个卷积组得到输出特征图;
[0026]获取所述目标教师网络中与所述预设数量对应的卷积组的聚合特征,计算输出特征图和聚合特征之间的相似度,以将相似度作为从学生网络到所述目标教师网络的路径的第二损失函数。
[0027]可选的,还包括:在提取聚合特征、计算所述第二损失函数、对学生网络进行特征蒸馏训练时,选取学生网络和所述目标教师网络中位于激活层之前的特征图;以及保留特征图中不小于预设数值的特征,其他特征的值置为0。
[0028]可选的,还包括:对学生网络中的参数做初始化处理,以将每个卷积组中的聚合特征权重被初始化为:最后一个特征图的权重设置为1,其他特征图的权重均设置为0。
[0029]可选的,在所述引入损失函数之前,还包括:
[0030]按照分辨率的大小关系,对特征图进行分组,以将原始网络分为不同的卷积组。
[0031]为实现上述目的,根据本专利技术实施例的另一方面,提供了一种基于知识蒸馏的目标检测装置,包括:
[0032]权重搜索模块,用于利用样本图像集训练教师网络,得到目标教师网络,引入损失函数,使用可微分的分组搜索方法,逐组搜索所述目标教师网络的每个卷积组的聚合特征
权重;其中,样本图像中的对象和所处位置已标注;
[0033]特征提取模块,用于使用所述聚合特征权重,从所述目标教师网络的每个卷积组中提取出相应聚合特征;
[0034]特征蒸馏模块,用于将提取到的聚合特征作为知识,对学生网络进行聚合特征蒸馏,获得目标学生网络,进而将待检测的图像输入到所述目标学生网络中,以检测出所述待检测的图像中的目标对象和所处位置。
[0035]可选的,所述权重搜索模块,用于:
[0036]将学生网络的模型参数作为底层变量,在训练集上以最小化训练损失函数为目标,利用梯度下降方式更新所述模型参数,得到新模型参数;
[0037]将所述目标教师网络的聚合特征权重的辅助参数作为上层变量,在新模型参数的基础上,在验证集上以最小化验证损失函数为目标,利用梯度下降方式更新所述辅助参数,得到目标辅助参数;
[0038]使用所述目标辅助参数计算得到聚合特征权重。
[0039]可选的,所述权重搜索模块,用于:
[0040]将学生网络的原始信息分为两条路径,以计算从所述目标教师网络到学生网络的路径的第一损失函数、以及计算从学生网络到所述目标教师网络的路径的第二损失函数;
[0041]使用第一超参数和第二超参数,对所述第一损失函数和所述第二损失函数进行加权求和,得到总的桥式损失函数;其中,超参数用于调整第一或第二损失函数本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的目标检测方法,其特征在于,包括:利用样本图像集训练教师网络,得到目标教师网络,引入损失函数,使用可微分的分组搜索方法,逐组搜索所述目标教师网络的每个卷积组的聚合特征权重;其中,样本图像中的对象和所处位置已标注;使用所述聚合特征权重,从所述目标教师网络的每个卷积组中提取出相应聚合特征;将提取到的聚合特征作为知识,对学生网络进行聚合特征蒸馏,获得目标学生网络,进而将待检测的图像输入到所述目标学生网络中,以检测出所述待检测的图像中的目标对象和所处位置。2.根据权利要求1所述的方法,其特征在于,所述引入损失函数,使用可微分的分组搜索方法,逐组搜索所述目标教师网络的每个卷积组的聚合特征权重,包括:将学生网络的模型参数作为底层变量,在训练集上以最小化训练损失函数为目标,利用梯度下降方式更新所述模型参数,得到新模型参数;将所述目标教师网络的聚合特征权重的辅助参数作为上层变量,在新模型参数的基础上,在验证集上以最小化验证损失函数为目标,利用梯度下降方式更新所述辅助参数,得到目标辅助参数;使用所述目标辅助参数计算得到聚合特征权重。3.根据权利要求1所述的方法,其特征在于,所述引入损失函数,包括:将学生网络的原始信息分为两条路径,以计算从所述目标教师网络到学生网络的路径的第一损失函数、以及计算从学生网络到所述目标教师网络的路径的第二损失函数;使用第一超参数和第二超参数,对所述第一损失函数和所述第二损失函数进行加权求和,得到总的桥式损失函数;其中,超参数用于调整第一或第二损失函数对桥式损失函数的影响;将所述总的桥式损失函数,作为学生网络在训练集和验证集上的损失函数的选择。4.根据权利要求3所述的方法,其特征在于,所述计算从所述目标教师网络到学生网络的路径的第一损失函数,包括:从所述目标教师网络的任一卷积组中提取出聚合特征,获取所述任一卷积组所处层数,以确定学生网络中位于所述层数的下一卷积组;将聚合特征输入到所述下一卷积组并继续向前传播,直至传播到学生网络的最后一层卷积组;利用最后一层卷积组输出的特征分布与真实标签,计算交叉熵损失函数,进而计算得到从所述目标教师网络到学生网络的路径的第一损失函数。5.根据权利要求3所述的方法,其特征在于,所述计算从学生网络到...

【专利技术属性】
技术研发人员:徐鑫
申请(专利权)人:京东鲲鹏江苏科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1