基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备技术

技术编号:28422447 阅读:15 留言:0更新日期:2021-05-11 18:30
本发明专利技术提供了一种基于注意力机制的mobilenet‑v1知识蒸馏方法、存储器及终端设备,其中,包括:分别选择复杂模型WRN‑50‑8以及简单模型mobilenet‑v对应的特定中间层,用以进行注意图的知识转移;处理得到复杂模型和简单模型的中间层所对应的注意力图之间的损失,记为损失值一;处理获得复杂模型和简单模型的Logit层之间的KL散度;处理获得简单模型的交叉熵损失,记为损失值二;根据损失值一、KL散度及损失值二处理得到总损失;损失值一、RL散度、损失值二以及总损失用以简单模型的参数的计算。其技术方案的有益效果在于,与现有其他蒸馏方式相比,大幅提高mobilenet‑v1学生网络的识别精度和准确率,并可以将其部署在算力有限的设备。

【技术实现步骤摘要】
基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备
本专利技术涉及深度学习模型压缩
,尤其涉及基于注意力机制的mobilenet-v1知识蒸馏方法、存储器及终端设备。
技术介绍
知识蒸馏是将复杂模型(教师网络)中的暗知识(darkknowledge)迁移到简单模型(学生网络)中去,一般来说,复杂模型具有强大的能力和表现,而简单模型则更为紧凑。通过知识蒸馏,希望简单模型能尽可能逼近亦或是超过复杂模型,从而用更少的复杂度来获得类似的预测效果。(GeoffreyHinton,OriolVinyals,JeffDean.“DistillingtheKnowledgeinaNeuralNetwork”InNIPS,2014)首次提出了知识蒸馏的概念,通过引入教师网络的软目标(softtargets)以诱导学生网络的训练。近些年来出现了许多知识蒸馏的方法,而不同的方法对于网络中需要转移的暗知识定义也各不相同。(SergeyZagoruyko,NikosKomodakis.“PAYINGMOREATTENTIONTOATTENTION:IMPROVINGTHEPERFORMANCEOFCONVOLUTIONALNEURALNETWORKSVIAATTENTIONTRANSFER”InICLR,2017)首次提出利用注意力机制对WRN(WideResNet)网络进行蒸馏。由于WRN网络结构依然很大,不适合部署在计算能力有限的设备(比如移动终端)。
技术实现思路
针对现有的在无法在计算能力有限的设备上部署WRN网络结存在的问题。现提供一种方便对对简单模型进行蒸馏以适应有限算力的端侧设备的基于注意力机制的mobilenet-v1知识蒸馏方法。具体包括以下:一种基于注意力机制的mobilenet-v1知识蒸馏方法,其中,包括:分别选择复杂模型WRN-50-8以及简单模型mobilenet-v(MobileNets基于一种流线型结构使用深度可分离卷积来构造轻型权重深度神经网络。)对应的特定中间层,用以进行注意图的知识转移;处理得到所述复杂模型和所述简单模型的中间层所对应的注意力图之间的损失,记为损失值一;处理获得所述复杂模型和所述简单模型的Logit层之间的KL散度;处理获得所述简单模型的交叉熵损失,记为损失值二;根据所述损失值一、所述KL散度及所述损失值二处理得到总损失;所述损失值一、所述RL散度、所述损失值二以及所述总损失用以所述简单模型的参数的计算。优选的所述进行注意图的知识转移方法包括:从所述复杂模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图一;从所述简单模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图二;将所述中间特征图一的知识转移给所述中间特征图二。上述技术方案中,注意力图知识转移将教师网络中间层特征图经过计算获得教师网络中间层注意力图,再将与其对应的学生网络中间层特征图经过同样的计算过程得到学生网络中间层注意力图。优选的,处理得到所述简单模型或所述复杂模型的中间层对应的注意力图的方法如下式所示:设张量A∈RC*H*W为所述的简单模型或复杂模型的某个中间层特征图,即特征图A有C个通道,每个通道为H*W的二维矩阵,则注意力图按照如下公式计算:其中,注意力图计算结果Q∈RH*W,A(i,:,:)表示第i个通道的H*W二维矩阵。优选的,处理得到所述复杂模型和所述简单模型的中间层所对应的注意力图之间的损失的方法如下式所示:其中,表示复杂模型WRN-50-8的第j个注意力图,表示对应的简单模型mobilenet-v1的第j个注意力图,||X||表示计算矩阵X的L2正则。优选的,计算所述KL散度的方法包括:所述复杂模型的logit层,是WRN-50-8网络的fc层的输出lT∈R1*1*10;所述的简单模型的logit层,是mobilenet-v1网络的fc层的输出lS∈R1*1*10;计算所述简单模型和所述复杂模型logit层之间的KL散度,如下式所示:其中,lT[i]表示复杂模型fc层的输出lT的第i个值;lS[i]表示简单模型fc层的输出lS的第i个值;T表示温度参数,这里取值为4。优选的,处理获得所述简单模型的交叉熵损失的方法包括,将简单模型softmax层的输出与训练数据的真值标签计算交叉熵损失Lce。优选的,计算所述总损失的方法如下式所示:ltotal=α*Lkl+(1-α)*Lce+β*LAT其中,参数α取值为0.9,参数β取值为1000,Lkl表示KL散度,Lce表示交叉熵损失,LAT表示注意力图之间的损失。还包括一种非易失性存储器,其中存储有软件,其中,所述软件用以实现权利要上述的基于注意力机制的mobilenet-v1知识蒸馏方法。还包括一种终端设备,包括一个或多个处理器和与其耦合的一个或多个存储器,其中,所述一个或多个存储器用于存储计算机程序代码,所述计算机程序代码包括计算机指令;所述一个或多个处理器用于执行所述计算机指令并实现上述的基于注意力机制的mobilenet-v1知识蒸馏方法。上述技术方案具有如下优点或有益效果:与现有其他蒸馏方式相比,大幅提高mobilenet-v1学生网络的识别精度和准确率,并可以将其部署在算力有限的设备。附图说明图1是本专利技术中的一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例的流程示意图;图2是本专利技术中的一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例中,关于复杂模型即教师网络WRN-50-8的结构示意图;图3是本专利技术中的一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例中,关于简单模型即学生网络mobilenet-v的结构示意图。具体实施方式下面将结合本专利技术实施例中的附图,对本专利技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本专利技术一部分实施例,而不是全部的实施例。基于本专利技术中的实施例,本领域普通技术人员在没有作出创造性劳动的前提下所获得的所有其他实施例,都属于本专利技术保护的范围。需要说明的是,在不冲突的情况下,本专利技术中的实施例及实施例中的特征可以相互组合。下面结合附图和具体实施例对本专利技术作进一步说明,但不作为本专利技术的限定。具体包括以下内容:一种基于注意力机制的mobilenet-v1知识蒸馏方法的实施例,其中,包括:分别选择复杂模型即教师网络WRN-50-8以及简单模型即学生网络mobilenet-v对应的特定中间层,用以进行注意图的知识转移;处理得到复杂模型和简单模型的中间层所对应的注意力图之间的损失,记为损失值一;处理获得复杂模型和简单模型本文档来自技高网...

【技术保护点】
1.一种基于注意力机制的mobilenet-v1知识蒸馏方法,其特征在于,包括:/n分别选择复杂模型以及简单模型对应的特定中间层,用以进行注意图的知识转移;/n分别处理得到所述复杂模型和所述简单模型的所述特定中间层所对应的注意力图之间的损失并记为第一损失值,根据所述第一损失值对所述简单模型中的所述特定中间层进行更新;/n分别处理获得所述复杂模型和所述简单模型的Logit层的KL散度;/n处理获得所述简单模型的交叉熵损失,记为损失值二;/n根据所述损失值一、所述KL散度及所述损失值二处理得到总损失;/n所述损失值一、所述RL散度、所述损失值二以及所述总损失用以所述简单模型的参数的计算。/n

【技术特征摘要】
1.一种基于注意力机制的mobilenet-v1知识蒸馏方法,其特征在于,包括:
分别选择复杂模型以及简单模型对应的特定中间层,用以进行注意图的知识转移;
分别处理得到所述复杂模型和所述简单模型的所述特定中间层所对应的注意力图之间的损失并记为第一损失值,根据所述第一损失值对所述简单模型中的所述特定中间层进行更新;
分别处理获得所述复杂模型和所述简单模型的Logit层的KL散度;
处理获得所述简单模型的交叉熵损失,记为损失值二;
根据所述损失值一、所述KL散度及所述损失值二处理得到总损失;
所述损失值一、所述RL散度、所述损失值二以及所述总损失用以所述简单模型的参数的计算。


2.根据权利1所述的方法,其特征在于,所述进行注意图的知识转移方法包括:
从所述复杂模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图一;
从所述简单模型的结构中选择预定数量的中间层输出作为计算注意力图的中间层特征图,记为中间特征图二;
将所述中间特征图一的知识转移给所述中间特征图二。


3.根据权利1所述的方法,其特征在于,处理得到所述简单模型或所述复杂模型的中间层对应的注意力图的方法如下式所示:
设张量A∈RC*H*W为所述的简单模型或复杂模型的某个中间层特征图,即特征图A有C个通道,每个通道为H*W的二维矩阵,则注意力图按照如下公式计算:



其中,注意力图计算结果Q∈RH*W,A(i,:,:)表示第i个通道的H*W二维矩阵。


4.根据权利1所述的方法,其特征在于,处理得到所述复杂模型和所述简单模型的中间层所对应的注意力图之间的损失的方法如下式所示:



其中,表示复杂模型WRN-50-8的第...

【专利技术属性】
技术研发人员:黄明飞姚宏贵梁维斌王昊
申请(专利权)人:开放智能机器上海有限公司
类型:发明
国别省市:上海;31

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1