基于知识蒸馏的识别模型训练方法及装置制造方法及图纸

技术编号:37764458 阅读:11 留言:0更新日期:2023-06-06 13:22
本公开涉及一种基于知识蒸馏的识别模型训练方法及装置,其中,该方法利用不同分辨率的点云数据之间的知识蒸馏,为低线数激光雷达系统训练识别模型,使得低线数激光雷达系统通过使用训练好的识别模型基于采集到的低分辨率的点云数据能够获得准确度较高的识别结果,保证识别性能。此外,采用本公开的方法在数据采集与训练阶段部署少量的高线数激光雷达装置便可以提升应用部署端大规模的基于低线数激光雷达的识别性能,极大地降低了应用成本。极大地降低了应用成本。极大地降低了应用成本。

【技术实现步骤摘要】
基于知识蒸馏的识别模型训练方法及装置


[0001]本公开涉及计算机
,尤其涉及一种基于知识蒸馏的识别模型训练方法及装置。

技术介绍

[0002]随着人工智能浪潮的推进,自动驾驶作为最前沿、难度系数最高且未来运用领域最广的技术得到广泛关注与研究。在车辆行驶过程中,自动驾驶系统通过分析不同的传感器数据来获取障碍物位置、可通行区域等周围环境信息,并根据感知得到的环境信息规划合适的路径和运动速度,并控制车辆自动实施操作,进而保证车辆能够安全、平稳、高效地行驶至目的地点。激光雷达得益于其对环境光的不敏感性以及捕捉物体三维空间结构的能力,可以更准确地探测车辆周围环境,从而实现更安全的自动驾驶系统,因此,基于激光点云的目标识别被广泛地应用于自动驾驶领域中的环境感知任务中。
[0003]其中,通过激光雷达采集的点云数据的分辨率容易受到激光雷达硬件参数规格的影响,通常激光雷达的通道数越多,点云数据的分辨率越高,识别准确度较高,但成本也越高,无法广泛应用。因此,如何实现通过低线数激光雷达采集的较低分辨率的点云数据能够得到准确度较高的识别结果是当前亟待解决的问题。

技术实现思路

[0004]为了解决上述技术问题,本公开提供了一种基于知识蒸馏的识别模型训练方法及装置。
[0005]第一方面,本公开提供一种基于知识蒸馏的识别模型训练方法,包括:
[0006]获取教师模型、待训练的学生模型以及针对相同环境利用激光雷达采集的第一点云数据和第二点云数据;其中,所述第一点云数据对应的第一分辨率高于所述第二点云数据对应的第二分辨率;
[0007]将所述第一点云数据作为训练样本输入至所述教师模型进行目标识别得到所述教师模型输出的第一结果,以及,将所述第二点云数据作为训练样本输入至所述学生模型进行目标识别得到所述学生模型输出的第二结果;
[0008]利用所述第一结果以及所述第二结果进行知识蒸馏对所述学生模型进行训练直至满足训练结束条件得到目标学生模型。
[0009]在一些实施例中,所述第一结果包括:所述教师模型基于所述第一点云数据得到的特征图、回归结果、分类结果中的一项或多项;所述第二结果包括:所述学生模型基于所述第二点云数据得到的特征图、回归结果、分类结果中的一项或多项;所述利用所述第一结果以及所述第二结果进行知识蒸馏对所述学生模型进行训练,包括:
[0010]基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项;
[0011]对所述特征蒸馏损失函数、所述回归蒸馏损失函数以及所述分类蒸馏损失函数中
的一项或多项以及所述学生模型对应的初始损失函数进行加权求和得到目标损失函数;其中,所述学生模型对应的初始损失函数基于所述第二结果与相应训练样本的特征图标签、回归标签以及分类标签得到;
[0012]利用所述目标损失函数对所述学生模型进行训练。
[0013]在一些实施例中,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项,包括:
[0014]计算所述第一结果和所述第二结果分别包含的分类结果之间的KL散度作为所述分类蒸馏损失函数。
[0015]在一些实施例中,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项,包括:
[0016]计算所述第一结果和所述第二结果分别包含的特征图之间的L2距离作为所述特征蒸馏损失函数。
[0017]在一些实施例中,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项,包括:
[0018]基于所述第一结果和所述第二结果分别包含的回归结果确定用于计算所述回归蒸馏损失函数的目标训练样本;
[0019]计算所述学生模型针对所述目标训练样本输出的回归结果与所述教师模型针对所述目标训练样本输出的回归结果之间的距离的平均值作为所述回归蒸馏损失函数。
[0020]在一些实施例中,所述基于所述第一结果和所述第二结果分别包含的回归结果确定用于计算所述回归蒸馏损失函数的目标训练样本,包括:
[0021]针对训练样本,计算回归目标与所述第一识别结果包含的回归结果之间的第一距离,以及,计算所述回归目标与所述第二识别结果包含的回归结果之间的第二距离;
[0022]根据所述第一距离和所述第二距离之间的差值确定所述目标训练样本。
[0023]在一些实施例中,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项之后,还包括:
[0024]通过所述训练样本的焦点权重对计算得到的特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项进行修正;
[0025]所述对所述特征蒸馏损失函数、所述回归蒸馏损失函数以及所述分类蒸馏损失函数中的一项或多项以及所述学生模型对应的初始损失函数进行加权求和得到目标损失函数,包括:
[0026]基于修正后的特征蒸馏损失函数、修正后的回归蒸馏损失函数以及修正后的分类蒸馏损失函数中的一项或多项以及所述学生模型对应的初始损失函数进行加权求和得到目标损失函数。
[0027]在一些实施例中,还包括:
[0028]获取针对目标环境通过激光雷达采集的第三点云数据,所述第三点云数据对应所述第二分辨率;
[0029]将所述第三点云数据输入至所述目标学生模型得到所述目标学生模型输出的结果;
[0030]根据所述目标学生模型输出的结果确定所述第三点云数据对应的目标识别结果。
[0031]第二方面,本公开提供一种基于知识蒸馏的识别模型训练装置,包括:
[0032]获取模块,用于获取教师模型、待训练的学生模型以及针对相同环境利用激光雷达采集的第一点云数据和第二点云数据;其中,所述第一点云数据对应的第一分辨率高于所述第二点云数据对应的第二分辨率;
[0033]第一识别模块,用于将所述第一点云数据输入所述教师模型进行目标识别得到所述教师模型对应的第一结果;
[0034]第二识别模块,用于将所述第二点云数据输入所述学生模型进行目标识别得到所述学生模型对应的第二结果;
[0035]知识蒸馏模块,用于利用所述第一结果以及所述第二结果进行知识蒸馏对所述学生模型进行训练直至满足训练结束条件得到目标学生模型。
[0036]第三方面,本公开提供一种电子设备,包括:存储器和处理器;
[0037]所述存储器被配置为存储计算机程序指令;
[0038]所述处理器被配置为执行所述计算机程序指令,使得所述电子设备实现第一方面以及第一方面任一项所述的基于知识蒸馏的识别模型训练方法。
[0039]第四方面,本公开提供了一种可读存储介质,包括:计算机程序指令;电子设备的至少一个处理器执行所述计算机程序指令,使得所述电子设备实现第一方面以及第一方面本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的识别模型训练方法,其特征在于,包括:获取教师模型、待训练的学生模型以及针对相同环境利用激光雷达采集的第一点云数据和第二点云数据;其中,所述第一点云数据对应的第一分辨率高于所述第二点云数据对应的第二分辨率;将所述第一点云数据作为训练样本输入至所述教师模型进行目标识别得到所述教师模型输出的第一结果,以及,将所述第二点云数据作为训练样本输入至所述学生模型进行目标识别得到所述学生模型输出的第二结果;利用所述第一结果以及所述第二结果进行知识蒸馏对所述学生模型进行训练直至满足训练结束条件得到目标学生模型。2.根据权利要求1所述的方法,其特征在于,所述第一结果包括:所述教师模型基于所述第一点云数据得到的特征图、回归结果、分类结果中的一项或多项;所述第二结果包括:所述学生模型基于所述第二点云数据得到的特征图、回归结果、分类结果中的一项或多项;所述利用所述第一结果以及所述第二结果进行知识蒸馏对所述学生模型进行训练,包括:基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项;对所述特征蒸馏损失函数、所述回归蒸馏损失函数以及所述分类蒸馏损失函数中的一项或多项以及所述学生模型对应的初始损失函数进行加权求和得到目标损失函数;其中,所述学生模型对应的初始损失函数基于所述第二结果与相应训练样本的特征图标签、回归标签以及分类标签得到;利用所述目标损失函数对所述学生模型进行训练。3.根据权利要求2所述的方法,其特征在于,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项,包括:计算所述第一结果和所述第二结果分别包含的分类结果之间的KL散度作为所述分类蒸馏损失函数。4.根据权利要求2所述的方法,其特征在于,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项,包括:计算所述第一结果和所述第二结果分别包含的特征图之间的L2距离作为所述特征蒸馏损失函数。5.根据权利要求2所述的方法,其特征在于,所述基于所述第一结果和所述第二结果计算特征蒸馏损失函数、回归蒸馏损失函数以及分类蒸馏损失函数中的一项或多项,包括:基于所述第一结果和所述第二结果分别包含的回归结果确定用于计算所述回归蒸馏损失函数的目标训练样本;计算所述学生模型针对所述目标训练样本输出的回归结果与所述教师模型针对所述目标训练样本输出的回归结果之间的距离的平均值作为所述回归蒸馏损失函数。6.根据权利要求5所述的方法,其特征在于...

【专利技术属性】
技术研发人员:单佳炜张正杰沈罗丰
申请(专利权)人:南通探维光电科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1