基于剪枝卷积神经网络的分类方法及相关设备技术

技术编号:34474401 阅读:11 留言:0更新日期:2022-08-10 08:49
本发明专利技术提供一种基于剪枝卷积神经网络的分类方法及相关设备,包括:基于训练数据以及训练数据对应的标签训练得到预先训练好的分类模型;在预设的模型目标性能基础上通过对预先训练好的分类模型分别进行结构化剪枝与非结构化剪枝后得到剪枝后的分类模型;获取待分类的图片;将待分类的图片输入剪枝后的分类模型,得到对应的分类结果。本发明专利技术能够在预设的模型目标性能基础上对分类模型实现混合剪枝,从而最大程度上压缩简化模型。从而最大程度上压缩简化模型。从而最大程度上压缩简化模型。

【技术实现步骤摘要】
基于剪枝卷积神经网络的分类方法及相关设备


[0001]本专利技术涉及模型剪枝
,尤其涉及一种基于剪枝卷积神经网络的分类方法及相关设备。

技术介绍

[0002]模型剪枝是神经网络模型压缩领域的常用方法,用来压缩模型大小以及加速模型计算,通常情况下是通过裁剪掉神经网络权重中不重要的张量来达到降低整个神经网络的计算量的目的。按照剪枝粒度由小到大来分,具体的剪枝粒度包括:细粒度剪枝(fine

grained)、向量剪枝(vector

level)、核剪枝(kernel

level)、滤波器剪枝(Filter

level)以及层剪枝(layer

level)。
[0003]其中,细粒度剪枝就是对连接或者神经元进行剪枝,它是粒度最小的剪枝。向量剪枝相对于细粒度剪枝粒度更大,属于对卷积核内部(intra

kernel)的剪枝。核剪枝则是去除某个卷积核,它将丢弃对输入通道中对应计算通道的响应。滤波器剪枝是对整个卷积核组进行剪枝,会造成推理过程中输出特征通道数的改变。
[0004]上述细粒度剪枝、向量剪枝、核剪枝方法在参数量与模型性能之间取得了一定的平衡,但是网络的拓扑结构本身发生了变化,需要专门的算法设计来支持这种稀疏的运算,被称之为非结构化剪枝。
[0005]而滤波器剪枝只改变了网络中的滤波器组和特征通道数目,所获得的模型不需要专门的算法设计就能够运行,被称为结构化剪枝。除此之外还有对整个网络层的剪枝,它可以被看作是滤波器剪枝的变种,即所有的滤波器都丢弃。
[0006]混合剪枝就是从不同剪枝粒度入手对整个模型进行剪枝,而现有混合剪枝,常以压缩率为目标来剪枝模型,但剪枝之后模型性能下降多少无法控制,且在混合剪枝时,仅混合1

2种剪枝粒度,不能充分挖掘既定目标下的模型剪枝极限。
[0007]因此,如何在既定剪枝目标下通过混合多种剪枝粒度来最大程度压缩简化模型是亟需解决的问题。

技术实现思路

[0008]本专利技术提供一种基于剪枝卷积神经网络的分类方法及相关设备,用以解决上述问题。
[0009]本专利技术提供一种基于剪枝卷积神经网络的分类方法,包括:
[0010]基于训练数据以及训练数据对应的标签训练得到预先训练好的分类模型;
[0011]在预设的模型目标性能基础上通过对所述预先训练好的分类模型分别进行结构化剪枝与非结构化剪枝后得到剪枝后的分类模型;
[0012]获取待分类的图片;
[0013]将所述待分类的图片输入所述剪枝后的分类模型,得到对应的分类结果。
[0014]根据本专利技术提供的一种基于剪枝卷积神经网络的分类方法,所述结构化剪枝包括
卷积层剪枝与过滤器剪枝;
[0015]相应地,所述剪枝后的分类模型是在预设的模型目标性能基础上通过对预先训练好的分类模型分别进行结构化剪枝与非结构化剪枝后得到,包括:
[0016]S1、对预先训练好的分类模型中的每个卷积层的各个过滤器进行过滤器剪枝敏感度分析,得到满足预设的模型目标性能的最大卷积层剪枝率;
[0017]S2、判断所述最大卷积层剪枝率是否大于预设的卷积层剪枝率阈值;
[0018]在所述最大卷积层剪枝率大于预设的卷积层剪枝率阈值的情况下,对所述最大卷积层剪枝率所对应的卷积层进行卷积层剪枝;
[0019]在所述最大卷积层剪枝率不大于预设的卷积层剪枝率阈值的情况下,对所述最大卷积层剪枝率所对应的过滤器进行过滤器剪枝;
[0020]S3、重复所述S2直到穷尽所有最大卷积层剪枝率,从而得到第一剪枝后模型;
[0021]S4、对所述第一剪枝后模型进行非结构化剪枝,得到第二剪枝后模型作为剪枝后的分类模型。
[0022]根据本专利技术提供的一种基于剪枝卷积神经网络的分类方法,所述对预先训练好的分类模型中的每个卷积层的各个过滤器进行过滤器剪枝敏感度分析,得到满足预设的模型目标性能的最大卷积层剪枝率,包括:
[0023]为每个卷积层配置待分析剪枝率集合;其中,所述待分析剪枝率集合包括多个待分析剪枝率;
[0024]依次计算每个卷积层在各个待分析剪枝率下的模型性能,并判断计算得到的模型性能是否满足所述预设的模型目标性能;
[0025]在计算得到的模型性能满足所述预设的模型目标性能的情况下,将每个卷积层中数值最大的待分析剪枝率作为最大卷积层剪枝率。
[0026]根据本专利技术提供的一种基于剪枝卷积神经网络的分类方法,所述非结构剪枝为卷积核剪枝;
[0027]相应地,所述S4、对所述第一剪枝后模型进行非结构化剪枝,得到第二剪枝后模型作为剪枝后的分类模型,包括:
[0028]S41、对所述第一剪枝后模型中所有权重进行规范化处理,得到规范后参数;
[0029]S42、按照数值大小对所述规范后参数进行排序,得到规范后参数序列;
[0030]S43、利用预定的权重剪枝定位方法,从所述规范后参数序列中确定待剪枝权重;
[0031]S44、判断所述待剪枝权重对应的核剪枝率是否大于预设的核剪枝率阈值,
[0032]在所述待剪枝权重对应的核剪枝率大于预设的核剪枝率阈值的情况下,对所述核剪枝率对应的卷积核进行卷积核剪枝;
[0033]在所述待剪枝权重对应的核剪枝率不大于预设的核剪枝率阈值的情况下,对所述待剪枝权重进行权重剪枝;
[0034]S45、重复所述S44直到穷尽所有待剪枝权重,从而得到第二剪枝后模型,将第二剪枝后模型作为剪枝后的分类模型。
[0035]根据本专利技术提供的一种基于剪枝卷积神经网络的分类方法,所述按照数值大小对所述规范后参数进行排序是从小到大对所述规范后参数进行排序,从而得到规范后参数序列;
[0036]相应地,所述S43、利用预定的权重剪枝定位方法,从所述规范后参数序列中确定待剪枝权重,包括:
[0037]S431、对所述规范后参数序列中前i个规范后参数进行剪枝,得到初始第二剪枝后模型,其中,i初始值为1且小于所述规范后参数序列中所有规范后参数的总数量;
[0038]S432、判断所述初始第二剪枝后模型是否满足所述预设的模型目标性能;
[0039]S433、在所述初始第二剪枝后模型满足所述预设的模型目标性能的情况下,将i+1作为新的i,并重复所述S431至所述S432,直到所述初始第二剪枝后模型不满足所述预设的模型目标性能,根据所述新的i将所述规范后参数序列中前i个规范后参数对应的权重确定为待剪枝权重。
[0040]根据本专利技术提供的一种基于剪枝卷积神经网络的分类方法,所述预设的模型目标性能为剪枝后的分类模型性能下降不超过性能阈值。
[0041]本专利技术还提供一种基于剪枝卷积神经网络的分类装置,包括:
[0042]模型训练模块,用于基于训练数据以及训练数据对应的标签训练得到预先训练好的分类模本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于剪枝卷积神经网络的分类方法,其特征在于,包括:基于训练数据以及训练数据对应的标签训练得到预先训练好的分类模型;在预设的模型目标性能基础上通过对所述预先训练好的分类模型分别进行结构化剪枝与非结构化剪枝后得到剪枝后的分类模型;获取待分类的图片;将所述待分类的图片输入剪枝后的分类模型,得到对应的分类结果。2.根据权利要求1所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述结构化剪枝包括卷积层剪枝与过滤器剪枝;相应地,在预设的模型目标性能基础上通过对所述预先训练好的分类模型分别进行结构化剪枝与非结构化剪枝后得到剪枝后的分类模型,包括:S1、对预先训练好的分类模型中的每个卷积层的各个过滤器进行过滤器剪枝敏感度分析,得到满足预设的模型目标性能的最大卷积层剪枝率;S2、判断所述最大卷积层剪枝率是否大于预设的卷积层剪枝率阈值;在所述最大卷积层剪枝率大于预设的卷积层剪枝率阈值的情况下,对所述最大卷积层剪枝率所对应的卷积层进行卷积层剪枝;在所述最大卷积层剪枝率不大于预设的卷积层剪枝率阈值的情况下,对所述最大卷积层剪枝率所对应的过滤器进行过滤器剪枝;S3、重复所述S2直到穷尽所有最大卷积层剪枝率,从而得到第一剪枝后模型;S4、对所述第一剪枝后模型进行非结构化剪枝,得到第二剪枝后模型作为剪枝后的分类模型。3.根据权利要求2所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述对预先训练好的分类模型中的每个卷积层的各个过滤器进行过滤器剪枝敏感度分析,得到满足预设的模型目标性能的最大卷积层剪枝率,包括:为每个卷积层配置待分析剪枝率集合;其中,所述待分析剪枝率集合包括多个待分析剪枝率;依次计算每个卷积层在各个待分析剪枝率下的模型性能,并判断计算得到的模型性能是否满足所述预设的模型目标性能;在计算得到的模型性能满足所述预设的模型目标性能的情况下,将每个卷积层中数值最大的待分析剪枝率作为最大卷积层剪枝率。4.根据权利要求2所述的基于剪枝卷积神经网络的分类方法,其特征在于,所述非结构剪枝为卷积核剪枝;相应地,所述S4、对所述第一剪枝后模型进行非结构化剪枝,得到第二剪枝后模型作为剪枝后的分类模型,包括:S41、对所述第一剪枝后模型中所有权重进行规范化处理,得到规范后参数;S42、按照数值大小对所述规范后参数进行排序,得到规范后参数序列;S43、利用预定的权重剪枝定位方法,从所述规范后参数序列中确定待剪枝权重;S44、判断所述待剪枝权重对应的核剪枝率是否大于预设的核剪枝率阈值,在所述待剪...

【专利技术属性】
技术研发人员:陆强
申请(专利权)人:际络科技上海有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1