模型压缩的方法、装置、电子设备及计算机存储介质制造方法及图纸

技术编号:21035942 阅读:19 留言:0更新日期:2019-05-04 06:01
本申请提供了模型压缩的方法、装置、电子设备及计算机存储介质。所述方法包括:获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。

Model Compression Method, Device, Electronic Equipment and Computer Storage Media

【技术实现步骤摘要】
模型压缩的方法、装置、电子设备及计算机存储介质
本申请涉及人工智能领域,尤其涉及模型压缩的方法、装置、电子设备及计算机存储介质。
技术介绍
近年来,深度学习网络在计算机视觉领域的目标检测应用中取得了巨大的成功。但由于深度学习网络模型往往包含大量的模型参数,计算量大、处理速度慢,其应用也多在云端,在终端落地仍面临巨大的挑战。为了减少网络模型的冗余,国内外研究人员提出了蒸馏学习算法,在蒸馏学习中,通过将结构复杂的老师网络的知识提炼或者蒸馏到结构简单的学生网络模型,指导学生网络模型的训练,从而实现了对老师网络的压缩。但蒸馏后的学生网络性能不够理想,与老师网络的各方面检测性能仍存在一定差距。并且,当前的蒸馏学习都是基于两阶段(Two-stage)目标检测的网络,对单阶段(One-stage)目标检测中的应用也尚未得到探索。
技术实现思路
本申请提供了模型压缩的方法、装置、电子设备及计算机存储介质,能够使得模型压缩后得到的学生网络检测性能超越老师网络。第一方面,提供了一种模型压缩的方法,所述方法包括以下步骤:获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。可选地,在所述获取训练样本数据之前,所述方法还包括:利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。可选地,在所述获取训练样本数据之前,所述方法还包括:获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。可选地,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异从而确定的损失函数。可选地,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中,所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。可选地,所述自适应蒸馏损失函数的公式为,ADL=ADW·KLADW=(1-e-KL+βT(q))γ其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。可选地,所述方法还包括:所述训练后的学生网络进行自学习的过程。第二方面,提供了一种模型压缩的装置,包括获取单元、训练单元以及反向传播单元,其中,所述获取单元用于获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;所述训练单元用于利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;所述反向传播单元用于根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。可选地,所述训练单元还用于在所述获取单元获取训练样本数据之前,利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。可选地,所述装置还包括标注单元,所述标注单元用于在所述获取训练样本数据之前,获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;所述标注单元还用于将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。可选地,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异所确定的损失函数。可选地,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。可选地,所述自适应蒸馏损失函数的公式为,ADL=ADW·KLADW=(1-e-KL+βT(q))γ其中,ADL为所述自适应蒸馏损失函数,ADW为所述自适应蒸馏损失系数,KL表示所述学生网络模型难模仿所述老师网络的样本的权重,T(q)表示所述老师网络难学习的样本的权重,γ、β表示权值。可选地,所述方法还包括:所述训练后的学生网络进行自学习的过程。第三方面,提供了一种电子设备,包括处理器、输入设备、输出设备和存储器,所述处理器、输入设备、输出设备和存储器相互连接,其中,所述存储器用于存储计算机程序,所述计算机程序包括程序指令,所述处理器被配置用于调用所述程序指令,执行上述第一方面所述的方法第四方面,提供了一种计算机可读存储介质,所述计算机存储介质存储有计算机程序,所述计算机程序包括程序指令,所述程序指令当被处理器执行时使所述处理器执行上述第一方面的方法。基于本申请提供的模型压缩的方法、装置、电子设备以及计算机可读存储介质,通过获取训练样本数据,利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数,根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,从而获得训练后的学生网络。由于自适应蒸馏损失函数中包括控制老师网络难学习的样本和学生网络模型难模仿所述老师网络的样本权重的系数,使得老师网络从训练样本数据中提取的数据结构特征能有针对性的传递到学生网络中,从而使得学生网络的目标检测性能得到大大提升。附图说明为了更清楚地说明本申请实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。图1是本申请提供的一种模型压缩的方法的流程示意图;图2是本申请提供的一种模型压缩的方法中第一预测函数输出的概率分布与蒸馏温度参数T之间关系的示意图;图3是本申请提供的一种模型压缩的方法中正响应样本数量a和负响应样本数量b与学生网络训练结果之间的关系示意图;图4是本申请提供的一种模型压缩的方法中获得学生网络模型的自适应蒸馏损失函数ADL的流程示意图;图5是本申请提供的一种模型压缩的装置结构示意图;图6是本申请提供的一种电子设备结构示意框图。具体实施方式下面通过具体实施方式结合附图对本申请作进一步详细说明。在以下的实施方式中,很多细节描述是为了使得本申请能被更好的理解。然而,本领域技术人员可以毫不费力的认识到,其中部分特征在不同情况下是可以省略的,或者可以由其他方法所替代。在某些情况下,本申请相关的一些操作并没有在说明书中显示或描述,这是为了避免本申请的核心部分被过多的描述所淹没。对于本领域技术人员而言,详细描述这些相关操作并不是必要的,他们根据说明书中的描述以及本领域的一般技术知识即可完整了解相关操作。应当理解,当在本说明书和所附权利要求书中使用术语时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。需要说明的是,在本申请实施例中使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本申请。在本申请实本文档来自技高网...

【技术保护点】
1.一种模型压缩的方法,其特征在于,包括:获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。

【技术特征摘要】
1.一种模型压缩的方法,其特征在于,包括:获取训练样本数据,其中,所述训练样本数据包括有标签样本数据;利用所述训练样本数据对所述老师网络和学生网络模型分别进行训练,得到自适应蒸馏损失函数和焦点损失函数;根据所述自适应蒸馏损失函数和所述焦点损失函数对所述学生网络模型进行反向传播,获得训练后的学生网络。2.根据权利要求1所述的方法,其特征在于,在所述获取训练样本数据之前,所述方法还包括:利用所述有标签样本数据对老师网络模型进行训练,得到所述老师网络。3.根据权利要求1或2所述的方法,其特征在于,在所述获取训练样本数据之前,所述方法还包括:获取无标签样本数据,并利用所述老师网络对所述无标签样本数据进行标注,得到标注后的样本数据;将所述有标签样本数据和所述标注后的样本数据组成所述训练样本数据。4.根据权利要求1所述的方法,其特征在于,所述自适应蒸馏损失函数是根据所述老师网络和所述学生网络模型对同一样本数据的学习结果的差异从而确定的损失函数。5.根据权利要求1或4所述的方法,其特征在于,所述自适应蒸馏损失函数包括自适应蒸馏损失系数,所述自适应蒸馏损失系数用于调整所述训练样本数据中预定样本数据的权重,其中,所述预定样本数据包括所述老师网络难学习的样本和所述学生网络模型难模仿所述老师网络的样本。6.根据权利要求5所述的方法,其特征在于,所述自适应蒸...

【专利技术属性】
技术研发人员:唐诗涛冯俐铜旷章辉张伟陈益民
申请(专利权)人:北京市商汤科技开发有限公司
类型:发明
国别省市:北京,11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1