神经网络模型的训练、数据处理方法、装置以及介质制造方法及图纸

技术编号:39593236 阅读:10 留言:0更新日期:2023-12-03 19:48
本公开提供了一种神经网络模型的训练方法、数据处理方法、装置以及存储介质,其中的训练方法包括:基于训练集对神经网络初始模型进行初始训练,获得神经网络初始最优模型;对神经网络初始最优模型进行测试处理,获得神经网络初始最优模型中各个网络层的资源消耗信息;基于资源消耗信息在各个网络层中选取待剪枝层,并确定待剪枝层的剪枝策略;根据与待剪枝层对应的剪枝策略,对待剪枝层进行剪枝处理,生成神经网络剪枝模型;基于训练集对神经网络剪枝模型进行调整,获得神经网络目标模型。本公开避免了在剪枝处理中对人工经验的依赖,训练得到的模型精度高,并减少模型处理时占用的计算资源和内存负载,避免了计算资源的过度消耗。耗。耗。

【技术实现步骤摘要】
神经网络模型的训练、数据处理方法、装置以及介质


[0001]本公开涉及人工智能
,尤其涉及一种神经网络模型的训练方法、数据处理方法、装置以及存储介质。

技术介绍

[0002]随着人工智能技术的发展,在计算机视觉、机器翻译、语音识别等领域中广泛采用神经网络模型对人脸图像、文本、语音等数据进行处理。随着神经网络技术的发展,神经网络结构的设计更加复杂,卷积核对应的权重矩阵中包含的权重数量也越来越多,增加了神经网络的运算工作量和参数量。神经网络模型随着网络深度的增加能够获得更好的性能,但是会增加计算量和参数量,降低处理速度,并且,神经网络难以部署在移动终端等计算能力和存储空间有限的硬件资源上。目前,可以在移动终端上部署轻量化模型,或者将进行了剪枝处理的神经网络模型部署在移动终端上。但是,现有的轻量化模型在移动终端仅能满足基本的实时性需求,且计算消耗及内存占用仍未达到理想状态,并且,现有的剪枝处理方法根据依赖人工经验确定剪枝率等,使剪枝后的神经网络模型的精度大幅下降。

技术实现思路

[0003]有鉴于此,本专利技术要解决的一个技术问题是提供一种神经网络模型的训练方法、数据处理方法、装置以及存储介质。
[0004]根据本公开的第一方面,提供一种神经网络模型的训练方法,包括:基于训练集对神经网络初始模型进行初始训练,获得神经网络初始最优模型;对所述神经网络初始最优模型进行测试处理,获得所述神经网络初始最优模型中各个网络层的资源消耗信息;基于所述资源消耗信息在所述各个网络层中选取待剪枝层,并确定所述待剪枝层的剪枝策略;根据与所述待剪枝层对应的剪枝策略,对所述待剪枝层进行剪枝处理,生成神经网络剪枝模型;基于所述训练集对所述神经网络剪枝模型进行调整,获得神经网络目标模型。
[0005]可选地,所述待剪枝层包括:卷积层;所述剪枝策略包括:所述待剪枝层的目标剪枝率;所述确定所述待剪枝层的剪枝策略包括:对所述待剪枝层进行剪枝实验,用以确定所述待剪枝层的目标剪枝率。
[0006]可选地,对所述待剪枝层进行剪枝实验包括:在剪枝实验启动之后,基于初始剪枝率对所述待剪枝层进行剪枝处理,并确定所述神经网络初始最优模型的模型精度损失;在所述模型精度损失小于精度损失阈值的情况下,基于实验剪枝率对所述待剪枝层进行至少一次剪枝处理,直至所述模型精度损失大于或等于所述精度损失阈值,则终止剪枝实验。
[0007]可选地,对所述待剪枝层进行剪枝处理包括:基于所述初始剪枝率或所述实验剪枝率确定实验剪除通道数量;确定所述待剪枝层中各个卷积核的权重值和;在所述各个卷积核中选取所述权重值和的绝对值最小、并且数量等于所述实验剪除通道数量的卷积核作为待剪枝卷积核;对所述待剪枝卷积核进行剪除处理。
[0008]可选地,基于所述初始剪枝率以及剪枝率增长值确定所述实验剪枝率;其中,所述
实验剪枝率随着所述待剪枝层进行剪枝处理的次数增加而增加。
[0009]可选地,所述确定所述待剪枝层的目标剪枝率包括:在基于所述初始剪枝率对所述待剪枝层进行剪枝处理之后,如果确定所述模型精度损失大于或等于所述精度损失阈值,则确定所述目标剪枝率为0。
[0010]可选地,所述确定所述待剪枝层的目标剪枝率包括:在基于所述初始剪枝率对所述待剪枝层进行剪枝处理之后,如果确定所述模型精度损失小于所述精度损失阈值,则获取与对所述待剪枝层进行的倒数第二次剪枝处理相对应的所述初始剪枝率或所述实验剪枝率,作为所述目标剪枝率。
[0011]可选地,所述根据与所述待剪枝层对应的剪枝策略,对所述待剪枝层进行剪枝处理包括:基于所述目标剪枝率确定所述待剪枝层的目标剪除通道数量;确定所述待剪枝层中各个卷积核的权重值和;在所述各个卷积核中选取所述权重值和的绝对值最小、并且数量等于所述目标剪除通道数量的卷积核作为目标剪枝卷积核;对所述目标剪枝卷积核进行剪除处理。
[0012]可选地,所述资源消耗信息包括:网络层运行消耗时长;所述对所述神经网络初始最优模型进行测试处理,获得所述神经网络初始最优模型中各个网络层的资源消耗信息包括:将所述神经网络初始最优模型部署在目标设备上;使用模型测速工具对所述神经网络初始最优模型进行测试处理,获得所述神经网络初始最优模型中各个网络层的网络层运行消耗时长。
[0013]可选地,所述基于所述资源消耗信息在所述各个网络层中选取待剪枝层包括:设置网络层运行消耗时长占所述各个网络层的运行消耗总时长的比例阈值;基于所述比例阈值确定运行消耗时长阈值;在所述各个网络层中选取所述网络层运行消耗时长大于所述运行消耗时长阈值的网络层,作为所述待剪枝层。
[0014]可选地,所述基于训练集对神经网络初始模型进行初始训练,获得神经网络初始最优模型包括:使用训练集对所述神经网络初始模型进行训练,通过损失函数获得所述神经网络初始模型的损失信息;根据所述神经网络初始模型的损失信息并通过梯度反向传播算法对所述神经网络初始模型的参数进行调整,获得所述神经网络初始最优模型。
[0015]可选地,所述基于所述训练集对所述神经网络剪枝模型进行调整,获得神经网络目标模型包括:使用所述训练集对所述神经网络剪枝模型进行训练,通过损失函数获得所述神经网络剪枝模型的损失信息;根据所述神经网络剪枝模型的损失信息对所述神经网络剪枝模型的参数进行调整,获得所述神经网络目标模型。
[0016]可选地,所述训练集中的训练样本包括:人脸图像样本。
[0017]根据本公开的第二方面,提供一种数据处理方法,包括:获取神经网络目标模型;其中,所述神经网络目标模型是通过如上所述的训练方法训练得到;使用所述神经网络目标模型对待处理数据进行处理,获得处理结果。
[0018]可选地,所述待处理数据包括:人脸图像数据;所述处理结果包括:人脸检测结果。
[0019]根据本公开的第三方面,提供一种神经网络模型的训练装置,包括:初始训练模块,用于基于训练集对神经网络初始模型进行初始训练,获得神经网络初始最优模型;模型测试模块,用于对所述神经网络初始最优模型进行测试处理,获得所述神经网络初始最优模型中各个网络层的资源消耗信息;策略确定模块,用于基于所述资源消耗信息在所述各
个网络层中选取待剪枝层,并确定所述待剪枝层的剪枝策略;模型剪枝模块,用于根据与所述待剪枝层对应的剪枝策略,对所述待剪枝层进行剪枝处理,生成神经网络剪枝模型;模型调整模块,用于基于所述训练集对所述神经网络剪枝模型进行调整,获得神经网络目标模型。
[0020]根据本公开的第四方面,提供一种数据处理装置,包括:获取模块,用于获取神经网络目标模型;其中,所述神经网络目标模型是通过如上所述的训练方法训练得到;处理模块,用于使用所述神经网络目标模型对待处理数据进行处理,获得处理结果。
[0021]根据本公开的第五方面,提供一种神经网络模型的训练装置,包括:存储器;以及耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行如上所述的神经网络模型的训练方法。
[0本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种神经网络模型的训练方法,包括:基于训练集对神经网络初始模型进行初始训练,获得神经网络初始最优模型;对所述神经网络初始最优模型进行测试处理,获得所述神经网络初始最优模型中各个网络层的资源消耗信息;基于所述资源消耗信息在所述各个网络层中选取待剪枝层,并确定所述待剪枝层的剪枝策略;根据与所述待剪枝层对应的剪枝策略,对所述待剪枝层进行剪枝处理,生成神经网络剪枝模型;基于所述训练集对所述神经网络剪枝模型进行调整,获得神经网络目标模型。2.如权利要求1所述的方法,其中,所述待剪枝层包括:卷积层;所述剪枝策略包括:所述待剪枝层的目标剪枝率;所述确定所述待剪枝层的剪枝策略包括:对所述待剪枝层进行剪枝实验,用以确定所述待剪枝层的目标剪枝率。3.如权利要求2所述的方法,对所述待剪枝层进行剪枝实验包括:在剪枝实验启动之后,基于初始剪枝率对所述待剪枝层进行剪枝处理,并确定所述神经网络初始最优模型的模型精度损失;在所述模型精度损失小于精度损失阈值的情况下,基于实验剪枝率对所述待剪枝层进行至少一次剪枝处理,直至所述模型精度损失大于或等于所述精度损失阈值,则终止剪枝实验。4.如权利要求3所述的方法,对所述待剪枝层进行剪枝处理包括:基于所述初始剪枝率或所述实验剪枝率确定实验剪除通道数量;确定所述待剪枝层中各个卷积核的权重值和;在所述各个卷积核中选取所述权重值和的绝对值最小、并且数量等于所述实验剪除通道数量的卷积核作为待剪枝卷积核;对所述待剪枝卷积核进行剪除处理。5.一种数据处理方法,包括:获取神经网络目标模型;其中,所述神经网络目...

【专利技术属性】
技术研发人员:金勇逸甘启张璐陶明
申请(专利权)人:上海任意门科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1