一种基于结构重参数化的模型推理速度提升方法及装置制造方法及图纸

技术编号:33766155 阅读:29 留言:0更新日期:2022-06-12 14:16
本申请提供了一种基于结构重参数化的模型推理速度提升方法及装置,涉及卷积神经网络模型技术领域,包括:获取样本数据并按照多分支结构进行训练,得到训练模型,所述训练模型包括残差连接大层,且所述残差连接大层包括1

【技术实现步骤摘要】
一种基于结构重参数化的模型推理速度提升方法及装置


[0001]本专利技术属于卷积神经网络模型
,尤其涉及一种基于结构重参数化的模型推理速度提升方法及装置。

技术介绍

[0002]卷积神经网络已成为解决许多问题的主流方法。VGG在图像识别方面取得了巨大的成功,但是它仅使用了一个由conv、ReLU和pooling组成的简单体系结构。随着Inception、ResNet和DenseNet的出现,许多研究兴趣转移到了设计良好的体系结构上,使得模型变得越来越复杂。
[0003]许多复杂的卷积网络比简单的卷积网络提供更高的精度,但缺点也非常显著:(1)复杂的多分支设计虽然使得模型的精度大大的提高,但是复杂的结构使得模型的推理速度相应的大大降低;(2)虽然在现有技术中,通常采用Conv层和BN层融合的技术,从而减少网络层,提升推理速度,但是无法将一个像ResNet一样的残差连接大层进行融合,因为通常一个Conv层都由Conv+BN+Relu组成,即根据Conv层和BN层融合的技术无法将线性层与非线性层融合,不能提高模型的推理速度。

技术实现思路

[0004]本专利技术提供了一种基于结构重参数化的模型推理速度提升方法及装置,旨在解决上述中复杂的结构使得模型推理速度较低,现有融合技术无法融合像ResNet一样的残差连接大层的问题。
[0005]为了实现上述目的,本申请采用以下技术方案,包括:
[0006]获取样本数据并按照多分支结构进行训练,得到训练模型,所述训练模型包括残差连接大层,且所述残差连接大层包括1
×
1卷积层和3
×
3卷积层;
[0007]将所述残差连接大层的非线性层放在所述残差连接大层的最后一层,再根据Conv层与BN层融合的技术对所述1
×
1卷积层和所述3
×
3卷积层进行融合,得到卷积融合层并完成所述残差连接大层的融合,以得到结构重参数化的检测模型。
[0008]作为优选,所述Conv层与BN层融合的技术,包括:
[0009]分别将Conv层的表达式Conv=Wc*x+Bc和BN层的表达式带入公式y=BN(Conv(x))中进行计算,得到第一公式其中x为所述Conv层的输入,y为所述Conv层的输出,Wc为所述Conv层的权重,Bc为所述Conv层的偏置项,γ为所述BN层的缩放系数,β为所述BN层的偏移系数,E为所述BN层的均值,var为所述BN层的方差;
[0010]对所述第一公式进行变形,得到第二公式所述第
二公式为一元一次方程,完成所述Conv层和所述BN层的融合。
[0011]作为优选,所述将所述残差连接大层的非线性层放在所述残差连接大层的最后一层,再根据Conv层与BN层融合的技术对所述1
×
1卷积层和所述3
×
3卷积层进行融合,得到卷积融合层并完成所述残差连接大层的融合,以得到结构重参数化的检测模型,包括:
[0012]步骤一、将所述残差连接大层的非线性层放在所述残差连接大层的最后一层;
[0013]步骤二、预设所述残差连接大层的输入为输出为
[0014]步骤三、当C1=C2,H1=H2,W1=W2时,可得所述输入M1与所述输出M2的关系式为:且所述残差连接大层的权重为所述残差连接大层的偏置为其中表示从1开始,i表示通道,C1为所述残差连接大层的输入通道,C2为所述残差连接大层的输出通道,N是batch_size,H1、W1是输入特征的高和宽,H2,W2是输出特征的高和宽,var
(3)
,E
(3)

(3)

(3)
分别表示3x3 Conv之后的BN的方差、均值、偏置系数、缩放系数,var
(1)
,E
(1)

(1)

(1)
分别表示1x1 Conv之后的BN的方差、均值、偏置系数、缩放系数,var
(0)
,E
(0)

(0)

(0)
分别表示identity之后的BN的方差、均值、偏置系数、缩放系数,identity表示一个层;
[0015]步骤四、根据所述步骤一、所述步骤二、所述步骤三可得,所述残差连接大层按照所述Conv层与BN层融合的技术可以完成融合。
[0016]作为优选,所述当C1=C2,H1=H2,W1=W2时,可得所述输入M1与所述输出M2的关系式为:还包括:
[0017]当所述C1=C2,H1=H2,W1=W2不成立时,所述输入M1与所述输出M2的关系式为M2=BN(Conv(M1),var
(3)
,E
(3)

(3)

(3)
)+BN(Conv(M1*W1),var
(1)
,E
(1)

(1)

(1)
)。
[0018]作为优选,所述非线性层包括激活函数Relu。
[0019]一种基于结构重参数化的模型推理速度提升装置,包括:
[0020]模型训练模块:用于获取样本数据并按照多分支结构进行训练,得到训练模型,所述训练模型包括残差连接大层,且所述残差连接大层包括1
×
1卷积层和3
×
3卷积层;
[0021]结构重参数化的模型生成模块:用于将所述残差连接大层的非线性层放在所述残差连接大层的最后一层,再根据Conv层与BN层融合的技术对所述1
×
1卷积层和所述3
×
3卷积层进行融合,得到卷积融合层并完成所述残差连接大层的融合,以得到结构重参数化的检测模型。
[0022]作为优选,所述结构重参数化的模型生成模块,包括:
[0023]Conv层与BN层第一融合模块:用于分别将Conv层的表达式Conv=Wc*x+Bc和BN层的表达式带入公式y=BN(Conv(x))中进行计算,得到第一公式
其中x为所述Conv层的输入,y为所述Conv层的输出,Wc为所述Conv层的权重,Bc为所述Conv层的偏置项,γ为所述BN层的缩放系数,β为所述BN层的偏移系数,E为所述BN层的均值,var为所述BN层的方差;
[0024]Conv层与BN层第二融合模块:用于对所述第一公式进行变形,得到第二公式所述第二公式为一元一次方程,完成所述Conv层和所述BN层的融合。
[0025]作为优选,所述结构重参数化的模型生成模块,还包括:
[0026]第一结构重参数化模块:用于步骤一、将所述残差连接大层的非线性层放在所述残差连接大层的最后一层;
[0027]第二结构重参数化模块:用于步骤二、预设所述残差连接大层的输入为输出为
[0028]第三结构重参数化模块:用于步骤本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于结构重参数化的模型推理速度提升方法,其特征在于,包括:获取样本数据并按照多分支结构进行训练,得到训练模型,所述训练模型包括残差连接大层,且所述残差连接大层包括1
×
1卷积层和3
×
3卷积层;将所述残差连接大层的非线性层放在所述残差连接大层的最后一层,再根据Conv层与BN层融合的技术对所述1
×
1卷积层和所述3
×
3卷积层进行融合,得到卷积融合层并完成所述残差连接大层的融合,以得到结构重参数化的检测模型。2.根据权利要求1所述的一种基于结构重参数化的模型推理速度提升方法,其特征在于,所述Conv层与BN层融合的技术,包括:分别将Conv层的表达式Conv=Wc*x+Bc和BN层的表达式带入公式y=BN(Conv(x))中进行计算,得到第一公式其中x为所述Conv层的输入,y为所述Conv层的输出,Wc为所述Conv层的权重,Bc为所述Conv层的偏置项,γ为所述BN层的缩放系数,β为所述BN层的偏移系数,E为所述BN层的均值,var为所述BN层的方差;对所述第一公式进行变形,得到第二公式所述第二公式为一元一次方程,完成所述Conv层和所述BN层的融合。3.根据权利要求2所述的一种基于结构重参数化的模型推理速度提升方法,其特征在于,所述将所述残差连接大层的非线性层放在所述残差连接大层的最后一层,再根据Conv层与BN层融合的技术对所述1
×
1卷积层和所述3
×
3卷积层进行融合,得到卷积融合层并完成所述残差连接大层的融合,以得到结构重参数化的检测模型,包括:步骤一、将所述残差连接大层的非线性层放在所述残差连接大层的最后一层;步骤二、预设所述残差连接大层的输入为输出为步骤三、当C1=C2,H1=H2,W1=W2时,可得所述输入M1与所述输出M2的关系式为:且所述残差连接大层的权重为所述残差连接大层的偏置为其中表示从1开始,i表示通道,C1为所述残差连接大层的输入通道,C2为所述残差连接大层的输出通道,N是batch_size,H1、W1是输入特征的高和宽,H2,W2是输出特征的高和宽,var
(3)
,E
(3)

(3)

(3)
分别表示3x3 Conv之后的BN的方差、均值、偏置系数、缩放系数,var
(1)
,E
(1)

(1)

(1)
分别表示1x1 Conv之后的BN的方差、均值、偏置系数、缩放系数,var
(0)
,E
(0)

(0)

(0)
分别表示identity之后的BN的方差、均值、偏置系数、缩放系数,identity表示一个层;步骤四、根据所述步骤一、所述步骤二、所述步骤三可得,所述残差连接大层按照所述Conv层与BN层融合的技术可以完成融合。4.根据权利要求3所述的一种基于结构重参数化的模型推理速度提升方法,其特征在
于,所述当C1=C2,H1=H2,W1=W2时,可得所述输入M1与所述输出M2的关系式为:还包括:当所述C1=C2,H1=H2,W1=W2不成立时,所述输入M1与所述输出M2的关系式为M2=BN(Conv(M1),var
(3)
,E
(3)

(3)

(3)
)+BN(Conv(M1*W1),var
(1)
,E
(1)

(1)

(...

【专利技术属性】
技术研发人员:周祖煜白博文林波陈煜人张澎彬莫志敏张浩李天齐刘俊
申请(专利权)人:杭州领见数字农业科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1