【技术实现步骤摘要】
神经网络的训练方法、训练装置和电子设备
本申请涉及深度学习
,且更具体地,涉及一种神经网络的训练方法、神经网络的训练装置和电子设备。
技术介绍
性能优良的深度神经网络通常具有较深的层数,导致网络的参数量巨大。如果要在移动端应用的话,通常会选择模型参数较小的轻量型网络,但轻量型网络的性能相对没有那么优良。提升轻量型网络的模型性能的技术中,知识蒸馏作为一种有效的手段,被广泛应用。其工作原理是将大模型的输出作为辅助标注去进一步有效的监督轻量型网络的训练,实现知识迁移。但是,传统的知识蒸馏并没有充分地将大网络的知识迁移到轻量型网络中,轻量型网络的精度尚存在提高空间。因此,期望提供改进的轻量型网络的生成方案。
技术实现思路
为了解决上述技术问题,提出了本申请。本申请的实施例提供了一种神经网络的训练方法、神经网络的训练装置和电子设备,其能够结合已训练和未训练的神经网络在相同预设层的特征图获得损失函数,并进一步结合未训练的神经网络本身的损失函数来更新未训练的神经网络的参数,从而提高训练后的神经网络的精度。根据本申请的一个方面,提供了一种神经网络的训练方法,包括:将训练数据输入已训练的第一神经网络和待训练的第二神经网络;确定所述第一神经网络的预设层输出的第一特征图与所述第二神经网络在所述预设层输出的第二特征图;基于所述第一特征图和所述第二特征图确定所述第二神经网络的第一损失函数值;基于所述第一损失函数值和所述第二神经网络的第二损失函数值,更新所述第二神经网络的参数;以及,将更 ...
【技术保护点】
1.一种神经网络的训练方法,包括:/n将训练数据输入已训练的第一神经网络和待训练的第二神经网络;/n确定所述第一神经网络的预设层输出的第一特征图与所述第二神经网络在所述预设层输出的第二特征图;/n基于所述第一特征图和所述第二特征图确定所述第二神经网络的第一损失函数值;以及/n基于所述第一损失函数值和所述第二神经网络的第二损失函数值,更新所述第二神经网络的参数;以及/n将更新后的所述第二神经网络的参数作为待训练的第二神经网络的初始参数,以迭代方式重复上述所述将训练数据输入已训练的第一神经网络和待训练的第二神经网络的步骤~所述基于所述第一损失函数值和所述第二神经网络的第二损失函数值,更新所述第二神经网络的参数的步骤,在更新得到的所述第二神经网络符合预设条件时,得到最终已训练的所述第二神经网络。/n
【技术特征摘要】
1.一种神经网络的训练方法,包括:
将训练数据输入已训练的第一神经网络和待训练的第二神经网络;
确定所述第一神经网络的预设层输出的第一特征图与所述第二神经网络在所述预设层输出的第二特征图;
基于所述第一特征图和所述第二特征图确定所述第二神经网络的第一损失函数值;以及
基于所述第一损失函数值和所述第二神经网络的第二损失函数值,更新所述第二神经网络的参数;以及
将更新后的所述第二神经网络的参数作为待训练的第二神经网络的初始参数,以迭代方式重复上述所述将训练数据输入已训练的第一神经网络和待训练的第二神经网络的步骤~所述基于所述第一损失函数值和所述第二神经网络的第二损失函数值,更新所述第二神经网络的参数的步骤,在更新得到的所述第二神经网络符合预设条件时,得到最终已训练的所述第二神经网络。
2.如权利要求1所述的神经网络的训练方法,其中,
确定所述第一神经网络的预设层输出的第一特征图与所述第二神经网络在所述预设层输出的第二特征图包括:
将所述第一神经网络的卷积层中的最后一层输出的特征图确定为第一特征图;以及,
将所述第二神经网络的卷积层中的最后一层输出的特征图确定为第二特征图;
基于所述第一特征图和所述第二特征图确定所述第二神经网络的第一损失函数值包括:
基于所述第一特征图和所述第二特征图确定所述第二神经网络的L2损失函数值;以及
基于所述L2损失函数值确定所述第二神经网络的第一损失函数值。
3.如权利要求1所述的神经网络的训练方法,其中,
确定所述第一神经网络的预设层输出的第一特征图与所述第二神经网络在所述预设层输出的第二特征图包括:
将所述第一神经网络的softmax层输出的特征图确定为第一特征图;以及,
将所述第二神经网络的softmax层输出的特征图确定为第二特征图;
基于所述第一特征图和所述第二特征图确定所述第二神经网络的第一损失函数值包括:
基于所述第一特征图和所述第二特征图确定所述第二神经网络的交叉熵损失函数值;以及
基于所述交叉熵损失函数值确定所述第二神经网络的第一损失函数值。
4.如权利要求1所述的神经网络的训练方法,其中,基于所述第一损失函数值和所述第二神经网络的第二损失函数值,更新所述第二神经网络的参数包括:
计算所述第二神经网络的交叉熵损失函数值作为所述第二损失函数值;
计算所述第一损失函数值和所述第二损失函数值的加权和作为总损失函数值;以及
以所述总损失函数值通过反向传播的方式更新所述第二神经网络的参数。
5.如权利要求1所述的神经网络的训练方法,其中,在将训练数据输入已训练的第一神经网络和待训莲的第二神经网络之前进一步包括:
训练第一神经网络直到所述第一神经网络收敛;以及
对所述第一神经网络对应的所述第二神经网络进行高斯初始化。
6.一种神经网络的训练装置,包括:
神经网络输入单元,用于将训练数据输入已训练的第一神经网络和待训练的第...
【专利技术属性】
技术研发人员:周贺龙,张骞,黄畅,
申请(专利权)人:南京人工智能高等研究院有限公司,
类型:发明
国别省市:江苏;32
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。