模型训练方法、装置、电子设备及介质制造方法及图纸

技术编号:38550975 阅读:11 留言:0更新日期:2023-08-22 20:57
本公开提供的一种模型训练方法、装置、电子设备及介质,涉及车辆技术领域,方法包括:获取第一浮点模型,对所述第一浮点模型进行后量化处理,得到第一量化模型,将第一浮点模型作为知识蒸馏的教师模型,将第一量化模型作为知识蒸馏的学生模型,基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型。采用该方法可以提高训练得到的量化模型的精度,使得训练的模型适宜部署于车辆。于车辆。于车辆。

【技术实现步骤摘要】
模型训练方法、装置、电子设备及介质


[0001]本公开涉及车辆
,尤其涉及一种模型训练方法、装置、电子设备及存储介质。

技术介绍

[0002]随着深度学习的发展,神经网络被广泛应用于各种领域,例如,可以应用于车辆领域,以辅助车辆驾驶。在车辆驾驶时,通过模型帮助车辆迅速做出判断以及决策。然而,目前的模型大多都很复杂,导致计算速度慢,所需内存大,不适宜部署到车辆上。
[0003]相关技术中,可以通过模型量化来解决计算速度慢,所需内存大的问题,以提高将模型部署到车辆上的适应性。然而,模型量化是一种将浮点计算转成低比特定点计算的技术,可以有效的降低模型计算强度、参数大小和内存消耗,但模型量化往往带来巨大的精度损失,这给部署于车辆中的模型带来了另外的挑战。

技术实现思路

[0004]为克服相关技术中存在的问题,本公开提供一种模型训练方法、装置、电子设备及介质。
[0005]根据本公开实施例的第一方面,提供一种模型训练方法,所述模型训练方法包括:获取第一浮点模型,所述第一浮点模型是利用第一样本数据集对待训练的神经网络模型训练至收敛得到的;对所述第一浮点模型进行后量化处理,得到第一量化模型,所述第一量化模型包括第一量化参数;基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型,其中,在量化感知训练过程中,所述第一浮点模型作为知识蒸馏的教师模型,所述第一量化模型作为知识蒸馏的学生模型,所述第一量化参数保持不变,所述量化感知训练后的模型用于对车辆采集的以下任一种数据进行处理:图像数据、音频数据、点云数据以及文本数据。
[0006]在一些实施方式中,所述基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型,包括:构建知识蒸馏损失函数,所述蒸馏损失函数表征所述第一浮点模型的中间结果数据以及所述第一量化模型对应的中间结果数据的分布差异;基于所述知识蒸馏损失函数以及所述第一量化模型对应的损失函数,确定预设损失函数,所述第一量化模型对应的损失函数表征第一量化模型的输出结果与样本真实标签之间的差异;利用所述预设损失函数替换所述第一量化模型的损失函数,并对替换后的量化模型进行量化感知训练,得到量化感知训练后的模型。
[0007]在一些实施方式中,所述中间结果数据包括对应模型的输出层之前的至少一层的
计算结果。
[0008]在一些实施方式中,所述方法还包括:获取包括所述第一量化参数的第二量化模型;将所述第二量化模型与所述第一浮点模型进行拼接,得到拼接后的拼接模型;利用第二样本数据集对所述拼接模型进行训练,得到训练后的拼接模型;从所述训练后的拼接模型中分离出第二浮点模型;将分离出的第二浮点模型作为所述第一浮点模型,并返回执行对所述第一浮点模型进行后量化处理,得到第一量化模型的步骤,直到满足预设训练停止条件。
[0009]在一些实施方式中,所述拼接后的拼接模型包括特征提取网络以及检测头,所述拼接模型的检测头与所述第二量化模型的检测头的网络结构相同,所述拼接模型的特征提取网络包括所述第一浮点模型的第一特征提取网络、所述第二量化模型的第二特征提取网络以及均值计算节点,所述第二特征提取网络对输入数据提取的特征以及所述第一特征提取网络对所述输入数据提取的特征经过均值计算节点处理之后作为所述拼接模型的检测头的输入。
[0010]在一些实施方式中,所述对所述第一浮点模型进行后量化处理,得到第一量化模型,包括:利用所述第一浮点模型对第三样本数据集中的待预测数据进行推理,并获取对各个待预测数据分别进行推理时,所述第一浮点模型中预设层的待量化参数,所述待量化参数包括激活值以及权重值;基于所述第一浮点模型中预设层的待量化参数确定初始量化参数;基于所述初始量化参数,构建搜索空间;基于从所述搜索空间中采样的候选量化参数对应的模型性能评估指标,从所述搜索空间中确定所述第一量化参数,所述候选量化参数对应的模型性能评估指标是将所述候选量化参数赋值给所述第一浮点模型得到的候选量化模型对应的模型性能评估指标;将所述第一量化参数赋值给所述第一浮点模型得到所述第一量化模型。
[0011]在一些实施方式中,所述基于从所述搜索空间中采样的候选量化参数对应的模型性能评估指标,从所述搜索空间中确定所述第一量化参数,包括:基于从所述搜索空间中已采样的各个候选量化参数各自对应的模型性能评估指标,预测模型性能评估指标的概率分布;基于所述模型性能评估指标的概率分布,从所述搜索空间中确定对应模型性能评估指标最优的下一个候选量化参数;返回执行基于从所述搜索空间中已采样的各个候选量化参数各自对应的模型性能评估指标,预测模型性能评估指标的概率分布的步骤,直到满足预设搜索停止条件,将最近确定的候选量化参数确定为所述第一量化参数根据本公开实施例的第二方面,提供一种模型训练装置,所述模型训练装置包括:第一获取模块,被配置为获取第一浮点模型,所述第一浮点模型是利用第一样本数据集对待训练的神经网络模型训练至收敛得到的;后量化模块,被配置为对所述第一浮点模型进行后量化处理,得到第一量化模型,所述第一量化模型包括第一量化参数;
量化感知训练模块,被配置为基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型,其中,在量化感知训练过程中,所述第一浮点模型作为知识蒸馏的教师模型,所述第一量化模型作为知识蒸馏的学生模型,所述第一量化参数保持不变,所述量化感知训练后的模型用于对车辆采集的以下任一种数据进行处理:图像数据、音频数据、点云数据以及文本数据。
[0012]根据本公开实施例的第三方面,提供一种电子设备,所述电子设备包括:处理器;用于存储处理器可执行指令的存储器;其中,所述处理器被配置为实现第一方面所述方法的步骤。
[0013]根据本公开实施例的第四方面,提供一种计算机可读存储介质,其上存储有计算机程序指令,该程序指令被处理器执行时实现本公开第一方面所提供的方法的步骤。
[0014]本公开提供的一种模型训练方法、装置、电子设备及介质,获取第一浮点模型,对所述第一浮点模型进行后量化处理,得到第一量化模型,将第一浮点模型作为知识蒸馏的教师模型,将第一量化模型作为知识蒸馏的学生模型,基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型。由于在对第一量化模型进行量化感知训练的时候,加入了精度较高的浮点模型作为教师模型来对精度较低的第一量化模型进行知识蒸馏处理,可以辅助第一量化模型尽量收敛到第一浮点模型的精度水平,从而提高最终训练得到的模型的精度,由于量化感知训练后的模型兼具计算速度以及计算精度的优势,适宜部署于车辆,以对车辆采集的相关数据类型的数据进行处理。
[0015]应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。
附图说明
[0016]此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。
[0017]本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:获取第一浮点模型,所述第一浮点模型是利用第一样本数据集对待训练的神经网络模型训练至收敛得到的;对所述第一浮点模型进行后量化处理,得到第一量化模型,所述第一量化模型包括第一量化参数;基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型,其中,在量化感知训练过程中,所述第一浮点模型作为知识蒸馏的教师模型,所述第一量化模型作为知识蒸馏的学生模型,所述第一量化参数保持不变,所述量化感知训练后的模型用于对车辆采集的以下任一种数据进行处理:图像数据、音频数据、点云数据以及文本数据。2.根据权利要求1所述的方法,其特征在于,所述基于所述第一浮点模型以及所述第一量化模型进行量化感知训练,得到量化感知训练后的模型,包括:构建知识蒸馏损失函数,所述蒸馏损失函数表征所述第一浮点模型的中间结果数据以及所述第一量化模型对应的中间结果数据的分布差异;基于所述知识蒸馏损失函数以及所述第一量化模型对应的损失函数,确定预设损失函数,所述第一量化模型对应的损失函数表征第一量化模型的输出结果与样本真实标签之间的差异;利用所述预设损失函数替换所述第一量化模型的损失函数,并对替换后的量化模型进行量化感知训练,得到量化感知训练后的模型。3.根据权利要求2所述的方法,其特征在于,所述中间结果数据包括对应模型的输出层之前的至少一层的计算结果。4.根据权利要求1所述的方法,其特征在于,所述方法还包括:获取包括所述第一量化参数的第二量化模型;将所述第二量化模型与所述第一浮点模型进行拼接,得到拼接后的拼接模型;利用第二样本数据集对所述拼接模型进行训练,得到训练后的拼接模型;从所述训练后的拼接模型中分离出第二浮点模型;将分离出的第二浮点模型作为所述第一浮点模型,并返回执行对所述第一浮点模型进行后量化处理,得到第一量化模型的步骤,直到满足预设训练停止条件。5.根据权利要求4所述的方法,其特征在于,所述拼接后的拼接模型包括特征提取网络以及检测头,所述拼接模型的检测头与所述第二量化模型的检测头的网络结构相同,所述拼接模型的特征提取网络包括所述第一浮点模型的第一特征提取网络、所述第二量化模型的第二特征提取网络以及均值计算节点,所述第二特征提取网络对输入数据提取的特征以及所述第一特征提取网络对所述输入数据提取的特征经过均值计算节点处理之后作为所述拼接模型的检测头的输入。6.根据权利要求1

5任一项所述的方法,其特征在于,所述对所述第一浮点模型进行...

【专利技术属性】
技术研发人员:刘安华
申请(专利权)人:小米汽车科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1