当前位置: 首页 > 专利查询>浙江大学专利>正文

一种基于辅助实例集成的分类网络知识蒸馏方法技术

技术编号:39155608 阅读:10 留言:0更新日期:2023-10-23 15:00
本发明专利技术公开了一种基于辅助实例集成的分类网络知识蒸馏方法,包括以下步骤:利用深度学习框架搭建部署实例;利用多阶段分支法构建多阶段实例群;基于多阶段实例群构建多对一的网络蒸馏损失函数;对多阶段实例群进行模型训练和优化,在训练结束后,固定并保存为多阶段实例群网络模型及网络模型参数;对优化后的多阶段实例群网络模型进行部署,优化裁剪网络模型参数,得到图像分类模型,图像分类模型的输出即为图像类别预测结果。本发明专利技术方法能够有效提高蒸馏后部署实例的性能,减少训练部署的时间及硬件需求,适用于多种时空受限或对精度要求更高的分类任务应用场景。求更高的分类任务应用场景。求更高的分类任务应用场景。

【技术实现步骤摘要】
一种基于辅助实例集成的分类网络知识蒸馏方法


[0001]本专利技术属于图像处理
,具体涉及一种基于辅助实例集成的分类网络知识蒸馏方法。

技术介绍

[0002]图像分类任务是计算机视觉中的基本任务,通过将图像送入神经网络中进行前向推理得到基于类别的概率预测向量,从而对输入图像进行定界分类。在数据集确定的情况下,传统的训练优化流程能够达到比较好的分类性能。而当应用场景需求更可靠的分类性能时,可行的办法是使用更大更复杂的网络结构,但是这样的大型模型需要占用更多的时空资源,比如训练更耗时,显存、内存占用更大,将会进一步导致训练部署成本高等问题。因此大量研究投入到了模型优化领域中。
[0003]知识蒸馏作为一种模型压缩技术,近年来在深度学习的应用和研究中得到了人们的广泛关注。与单模型训练相比,知识蒸馏指的是一种训练过程,即预先准备一个高性能的大模型,并基于该大模型将所学知识提取到小模型中;训练过程中,大模型仅作为小模型的教师,提供优化的指导,大模型本身不再进行训练优化;部署时,仅使用小模型。经过知识蒸馏获得的小模型,在能耗比上往往优于一般训练得到小模型,即在等模型容量的情况下取得更好的性能。同时,知识蒸馏也可以看作是大模型对自身的一种压缩,在维持自身分类性能基本不变的情况下,裁剪自身参数量至更小的规模。
[0004]传统的知识蒸馏任务往往是离线的,即需要先进行大模型的训练,再利用大模型进行蒸馏得到小模型。目前先进的知识蒸馏方法为在线模式的互学习知识蒸馏算法,弱化了师生的概念,在训练阶段同时引入多个学生模型,进行相互的学习,并在训练结束后挑选其中测试集性能最佳的模型作为部署模型。在线式知识蒸馏方法使得单阶段端到端地实现蒸馏任务成为可能,但这会加大训练阶段的负担,因为需要同时承载多个独立网络的训练,并构造各个网络之间的交互,计算额外的交互损失函数,反向传播会各个网络中进行梯度更新,相比于单网络传统训练,其训练时间及空间占用是倍增的。
[0005]同时,在线式知识蒸馏方法在训练过程中,信息流的倾向性不强,因为最终优化目标是同时提高多个网络的表现,即所有网络在结构和公式上对称的,而实际上蒸馏任务本身是非对称,仅有一个模型会出现在部署阶段,优化目标更应该放在尽可能提高该部署网络的性能上,这一矛盾在一定程度上限制了在线蒸馏的最终性能。

技术实现思路

[0006]鉴于上述,本专利技术的目的是提供一种基于辅助实例集成的分类网络知识蒸馏方法,该方法以采用多阶段分支法构建多阶段实例群,并使用了浅宽辅助分支作为辅助实例补足结构,能够极大提高辅助实例的分类及教学性能,从而有效提高蒸馏后的部署实例的性能,同时有效减少训练部署的时间及硬件需求,适用于多种时空受限或对精度要求更高的分类任务应用场景。
[0007]为实现上述专利技术目的,本专利技术提供的技术方案如下:
[0008]本专利技术实施例提供的一种基于辅助实例集成的分类网络知识蒸馏方法,包括以下步骤:
[0009]利用深度学习框架搭建部署实例,所述部署实例为含有多阶段串行网络结构的残差神经网络;
[0010]利用多阶段分支法构建多阶段实例群,所述多阶段实例群包括基于部署实例串行结构解构得到的多个级联子实例,以及基于各子实例之间的分支点扩展得到的浅宽辅助分支,将子实例与浅宽辅助分支组合形成仅训练期出现的辅助实例;
[0011]基于多阶段实例群构建多对一的网络蒸馏损失函数,所述多对一的网络蒸馏损失函数包括部署实例图像类别预测值与标签真值的交叉熵损失、辅助实例图像类别预测值与标签真值的交叉熵损失和辅助实例集成的图像类别预测值与部署实例图像类别预测值的KL散度;
[0012]对多阶段实例群进行模型训练和优化,在训练数据集上,基于多对一的网络蒸馏损失函数利用梯度下降算法优化参数,在训练结束后,固定并保存为多阶段实例群网络模型及网络模型参数;
[0013]对优化后的多阶段实例群网络模型进行部署,优化裁剪网络模型参数,得到图像分类模型,图像分类模型的输出即为图像类别预测结果。
[0014]优选地,所述基于部署实例串行结构解构得到的多个级联子实例,包括:
[0015]基于部署实例F的串行结构,选择N个分支点,将其解构为N+1个独立的级联子实例S
i
,从而导出其中间层特征f
i
,即对于输入图像x,有:
[0016][0017]其中,f
i
为第i个分支点输出的中间层特征,S
i
为第i个独立的子实例,S
i
接受输入f
i
‑1并将其转化为f
i
,则f
N+1
即为部署实例F关于x的输出F(x),F(x)为部署实例图像类别预测值。
[0018]优选地,所述基于各子实例之间的分支点扩展得到的浅宽辅助分支,将子实例与浅宽辅助分支组合形成仅训练期出现的辅助实例,包括:
[0019]基于N个分支点扩展设计对应的浅宽辅助分支B
i
,将中间层特征f
i
通过浅宽辅助分支B
i
转化为辅助实例图像类别预测值p
i
,其表达式为:
[0020]p
i
=A
i
(x)=B
i
(f
i
)
[0021]其中,p
i
为辅助实例A
i
基于输入给出的图像类别预测输出,B
i
为f
i
对应的浅宽辅助分支,A
i
为抽象形成的辅助实例,由前i个独立子实例与B
i
组成,即:
[0022]A
i
=S1*S2*

*S
i
*B
i
[0023]其中,*为网络结构上非线性函数的级联。
[0024]优选地,所述浅宽辅助分支的结构包括依次连接的局部平均池化层AP1、瓶颈块BNB1、局部平均池化层AP2、瓶颈块BNB2、局部平均池化层AP2、瓶颈块BNB3、全局平均池化层GAP以及一个全连接层FC。
[0025]优选地,每个瓶颈块包含两条支路,其中,一条支路包括卷积层CONV0、批归一化层BN0;另一条支路依次包括卷积层CONV1、批归一化层BN1、激活层ReLU1、卷积层CONV2、批归
一化层BN2,激活层ReLU2、卷积层CONV3、批归一化层BN3;最后由一个融合层ADD合并两条支路的信息,并后接激活层ReLU3。
[0026]优选地,所述多对一的网络蒸馏损失函数表达式为:
[0027][0028]其中,x为输入图像,y为x对应的独热码形式的标签真值,Φ为实例集成函数,L
CE
(a,b)为两个分布a和b的交叉熵损失,L
KLD
(a,b)为两个分布a和b的KL散度,α和β分别为控制损失比例的超参数。
[0029]优选地,所述实本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于辅助实例集成的分类网络知识蒸馏方法,其特征在于,包括以下步骤:利用深度学习框架搭建部署实例,所述部署实例为含有多阶段串行网络结构的残差神经网络;利用多阶段分支法构建多阶段实例群,所述多阶段实例群包括基于部署实例串行结构解构得到的多个级联子实例,以及基于各子实例之间的分支点扩展得到的浅宽辅助分支,将子实例与浅宽辅助分支组合形成仅训练期出现的辅助实例;基于多阶段实例群构建多对一的网络蒸馏损失函数,所述多对一的网络蒸馏损失函数包括部署实例图像类别预测值与标签真值的交叉熵损失、辅助实例图像类别预测值与标签真值的交叉熵损失和辅助实例集成的图像类别预测值与部署实例图像类别预测值的KL散度;对多阶段实例群进行模型训练和优化,在训练数据集上,基于多对一的网络蒸馏损失函数利用梯度下降算法优化参数,在训练结束后,固定并保存为多阶段实例群网络模型及网络模型参数;对优化后的多阶段实例群网络模型进行部署,优化裁剪网络模型参数,得到图像分类模型,图像分类模型的输出即为图像类别预测结果。2.根据权利要求1所述的基于辅助实例集成的分类网络知识蒸馏方法,其特征在于,所述基于部署实例串行结构解构得到的多个级联子实例,包括:基于部署实例F的串行结构,选择N个分支点,将其解构为N+1个独立的级联子实例S
i
,从而导出其中间层特征f
i
,即对于输入图像x,有:其中,f
i
为第i个分支点输出的中间层特征,S
i
为第i个独立的子实例,S
i
接受输入f
i
‑1并将其转化为f
i
,则f
N+1
即为部署实例F关于x的输出F(x),F(x)为部署实例图像类别预测值。3.根据权利要求2所述的基于辅助实例集成的分类网络知识蒸馏方法,其特征在于,所述基于各子实例之间的分支点扩展得到的浅宽辅助分支,将子实例与浅宽辅助分支组合形成仅训练期出现的辅助实例,包括:基于N个分支点扩展设计对应的浅宽辅助分支B
i
,将中间层特征f
i
通过浅宽辅助分支B
i
转化为辅助实例图像类别预测值p
i
,其表达式为:p
i
=A
i
(x)=B
i
(f
i
)其中,p
i
为辅助实例A
i
基于输入给出的图像类别预测输出,B
i
为f<...

【专利技术属性】
技术研发人员:田翔叶欣蒋荣欣陈耀武
申请(专利权)人:浙江大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1