【技术实现步骤摘要】
一种神经网络训练方法及相关装置
[0001]本申请涉及计算机
,尤其涉及一种神经网络训练方法及相关装置。
技术介绍
[0002]知识蒸馏是通过一个大模型教师网络(teacher)中的“知识”,去指导小模型学生网络(student)的训练,提高小模型的效果,从而达到变相压缩模型的目的。如图1所示,从知识蒸馏的位置来说,蒸馏技术总体可以分为两大类,输出层(Logits层)知识蒸馏和网络层知识蒸馏,下面分别进行介绍。
[0003]Logits层知识蒸馏通常是针对分类、识别这一类任务,如图2所示,Logits是指网络中最后一层的输出,教师网络和学生网络都会输出各自的Logits,而Logits直接影响最终的分类结果。Logits层知识蒸馏方法通常是利用教师网络的logits去指导学生网络进行训练,其实现形式通常是利用相对熵损失函数(Kullback
‑
Leibler Divergence Loss,KL)来度量教师网络的Logits和学生网络的Logits之间的差异。除了Logits层知识蒸馏及其变种 ...
【技术保护点】
【技术特征摘要】
1.一种神经网络训练方法,其特征在于,包括:获取第一特征,所述第一特征由输入数据经过学生网络的第一网络层处理后得到,所述输入数据包括以下数据中的任意一种或多种:图像数据、音频数据和文本数据;获取第二特征,所述第二特征由输入数据经过教师网络的第二网络层处理后得到,所述第二网络层是所述教师网络中与所述学生网络的第一网络层对应的网络层;根据第一损失函数训练中间变换网络,所述中间变换网络包括扩张模块和收缩模块,所述扩张模块用于将所述第一特征转换为第三特征,所述第三特征与所述第二特征对齐,所述第一损失函数用于衡量所述第三特征与所述第二特征的差异;所述收缩模块用于将所述第三特征转换为第四特征,所述第四特征与所述第一特征对齐。2.根据权利要求1所述的方法,其特征在于,所述根据第一损失函数训练中间变换网络包括:迭代地对所述中间变换网络进行多次同态变换,直至所述第一损失函数不再减小。3.根据权利要求2所述的方法,其特征在于,所述进行多次同态变换中的一次同态变换包括:基于预设的网络结构搜索空间进行同态变换搜索,得到与所述中间变换网络中的扩张模块等价的第一目标网络以更新所述扩张模块;和/或基于预设的网络结构搜索空间进行同态变换搜索,得到与所述中间变换网络中的收缩模块等价的第二目标网络以更新所述收缩模块。4.根据权利要求3所述的方法,其特征在于,所述网络结构搜索空间包括多个网络结构,所述多个网络结构均为卷积核大小为1x1的卷积层,且均能与所述学生网络融合。5.根据权利要求1
‑
4任一项所述的方法,其特征在于,还包括:将所述中间变换网络融入所述学生网络。6.根据权利要求5所述的方法,其特征在于,所述将所述中间变换网络融入所述学生网络包括:根据所述中间变换网络更新所述学生网络的第一网络层或所述第一网络层的下一个网络层的权重。7.根据权利要求5所述的方法,其特征在于,所述将所述中间变换网络融入所述学生网络包括:将所述中间变换网络融合为目标网络层,并将所述目标网络层插入所述学生网络的第一网络层与所述第一网络层的下一个网络层之间。8.根据权利要求3
‑
7中任一项所述的方法,其特征在于,所述目标网络的分支数量大于所述中间变换网络当前的网络的分支数量,和/或,所述目标网络的宽度大于所述中间变换网络当前的网络的宽度,和/或,所述目标网络的深度大于所述中间变换网络当前的网络的深度。9.根据权利要求4
‑
8任一项所述的方法,其特征在于,还包括:根据第二损失函数训练融入所述中间变换网络之后的所述学生网络。10.一种神经网络训练装置,其特征在于,包括:第一获取单元,用于获取第一特征,所述第一特征由输入数据经过学生网络的第一网
络层处理后得到,所述输入数据包括以下数据中的任意一种或多种:图像数据、音频数据和文本数据;第二获取单元,用于获取第二特征,所述第二特征由输入数据经过教师网络的第二网络层处理后得到,所述第二网络层是所述教师网络中与所述学生网络的第一网络层对应的网络层;...
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。