图处理网络模型的训练方法、装置、电子设备和存储介质制造方法及图纸

技术编号:29050009 阅读:20 留言:0更新日期:2021-06-26 06:12
本公开公开了一种图处理网络模型的训练方法、装置、电子设备和存储介质,尤其涉及深度学习、计算机视觉等人工智能技术领域。其中,具体实现方案为:将训练样本分别输入学生网络和教师网络,以获取学生网络的第i层输出的第一特征图和所述教师网络的第i层输出的第二特征图;根据所述第一特征图与所述第二特征图间的差异,确定所述学生网络对应的第一修正梯度;获取所述学生网络输出的第一软标签及所述教师网络输出的第二软标签;根据所述第一软标签与所述第二软标签的差异,确定所述学生网络中对应的第二修正梯度;基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正。由此,提高了学生网络的学习能力和性能。提高了学生网络的学习能力和性能。提高了学生网络的学习能力和性能。

【技术实现步骤摘要】
图处理网络模型的训练方法、装置、电子设备和存储介质


[0001]本公开涉及计算机
,具体涉及深度学习、计算机视觉等人工智能
,尤其涉及一种图处理网络模型的训练方法、装置、电子设备和存储介质。

技术介绍

[0002]随着计算机技术的发展,深度学习在各个领域取得了卓越的突破,由于人工神经网络具有强大的自学习能力,其在模式识别、智能机器人、自动控制、生物、医学和经济等领域应用越来越广泛。不过,越先进的网络模型所需的模型参数越多,占用的存储空间和计算资源也越大,由此,知识蒸馏(也称为教师

学生网络)应运而生。在使用教师网络训练学生网络时,如何提高学生网络的效果,成为目前亟待解决的问题。

技术实现思路

[0003]本公开提供了一种图处理网络模型的训练方法、装置、电子设备和存储介质。
[0004]本公开一方面,提供了一种图处理网络模型的训练方法,包括:
[0005]将训练样本分别输入学生网络和教师网络,以获取所述学生网络的第i层输出的第一特征图和所述教师网络的第i层输出的第二特征图,其中,i为大于或等于1且小于或等于N的整数,其中,N为所述学生网络及所述教师网络中包含的网络层的数量;
[0006]根据所述第一特征图与所述第二特征图间的差异,确定所述学生网络对应的第一修正梯度;
[0007]获取所述学生网络输出的第一软标签及所述教师网络输出的第二软标签;
[0008]根据所述第一软标签与所述第二软标签的差异,确定所述学生网络中对应的第二修正梯度
[0009]基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正。
[0010]本公开的另一方面,提供了一种图处理网络模型的训练装置,包括:
[0011]第一获取模块,用于将训练样本分别输入学生网络和教师网络,以获取所述学生网络的第i层输出的第一特征图和所述教师网络的第i层输出的第二特征图,其中,i为大于或等于1且小于或等于N的整数,其中,N为所述学生网络及所述教师网络中包含的网络层的数量;
[0012]第一确定模块,用于根据所述第一特征图与所述第二特征图间的差异,确定所述学生网络对应的第一修正梯度;
[0013]第二获取模块,用于获取所述学生网络输出的第一软标签及所述教师网络输出的第二软标签;
[0014]第二确定模块,用于根据所述第一软标签与所述第二软标签的差异,确定所述学生网络中对应的第二修正梯度;
[0015]修正模块,用于基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正。
[0016]本公开的另一方面,提供了一种电子设备,包括:
[0017]至少一个处理器;以及
[0018]与所述至少一个处理器通信连接的存储器;其中,
[0019]所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述一方面实施例所述的图处理网络模型的训练方法。
[0020]本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其上存储有计算机程序,所述计算机指令用于使所述计算机执行上述一方面实施例所述的图处理网络模型的训练方法。
[0021]本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现上述一方面实施例所述的图处理网络模型的训练方法。
[0022]本公开提供的图处理网络模型的训练方法、装置、电子设备和存储介质,可以先将训练样本分别输入学生网络和教师网络,以获取学生网络的第i层输出的第一特征图和教师网络的第i层输出的第二特征图,之后根据第一特征图与第二特征图间的差异,确定学生网络对应的第一修正梯度,再获取学生网络输出的第一软标签及教师网络输出的第二软标签,之后根据第一软标签与第二软标签的差异,确定学生网络中对应的第二修正梯度,从而基于第一修正梯度及第二修正梯度,对学生网络进行修正。由此,在对学生网络进行训练时,既考虑到了局部信息,又关注到了全局信息,从而可使得训练生成的学生网络,具有和教师网络更加相似的特征,并且具有更好的学习能力和特征表达能力,从而学生网络的效果和性能得到了极大提升。
[0023]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0024]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0025]图1为本公开一实施例提供的一种图处理网络模型的训练方法的流程示意图;
[0026]图2为本公开另一实施例提供的一种图处理网络模型的训练方法的流程示意图;
[0027]图3为本公开一实施例提供的一种图处理网络模型的训练装置的结构示意图;
[0028]图4为用来实现本公开实施例的图处理网络模型的训练方法的电子设备的框图。
具体实施方式
[0029]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0030]人工智能是研究使计算机来模拟人的某些思维过程和智能行为(如学习、推理、思考、规划等)的学科,既有硬件层面的技术也有软件层面的技术。人工智能硬件技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理等技术;人工智能软件技术主要包括计算机视觉技术、语音识别技术、自然语言处理技术以及机器学习、深度学习、
大数据处理技术、知识图谱技术等几大方向。
[0031]深度学习是指多层的人工神经网络和训练它的方法。一层神经网络会把大量矩阵数字作为输入,通过非线性激活方法取权重,再产生另一个数据集合作为输出。通过合适的矩阵数量,多层组织链接一起,形成神经网络“大脑”进行精准复杂的处理,就像人们识别物体标注图片一样。
[0032]计算机视觉是一个跨学科的科学领域,研究如何让计算机从数字图像或视频中获得高水平的理解。从工程学的角度来看,它寻求人类视觉系统能够完成的自动化任务。计算机视觉任务包括获取、处理、分析和理解数字图像的方法,以及从现实世界中提取高维数据以便例如以决策的形式产生数字或符号信息的方法。
[0033]下面参考附图描述本公开实施例的图处理网络模型的训练方法、装置、电子设备和存储介质。
[0034]本公开实施例的图处理网络模型的训练方法,可由本公开实施例提供的图处理网络模型的训练装置执行,该装置可配置于电子设备中。
[0035]图1为本公开实施例提供的一种图处理网络模型的训练方法的流程示意图。
[0036]如图1所示,该图处理网络模型的训练方法,可以包括以下步骤:
[0037]步骤10本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种图处理网络模型的训练方法,包括:将训练样本分别输入学生网络和教师网络,以获取所述学生网络的第i层输出的第一特征图和所述教师网络的第i层输出的第二特征图,其中,i为大于或等于1且小于或等于N的整数,其中,N为所述学生网络及所述教师网络中包含的网络层的数量;根据所述第一特征图与所述第二特征图间的差异,确定所述学生网络对应的第一修正梯度;获取所述学生网络输出的第一软标签及所述教师网络输出的第二软标签;根据所述第一软标签与所述第二软标签的差异,确定所述学生网络对应的第二修正梯度;基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正。2.如权利要求1所述的方法,其中,在所述获取所述学生网络的第i层输出的第一特征图和所述教师网络的第i层输出的第二特征图之后,还包括:将所述第一特征图输入判别网络,以获取所述判别网络输出的判别结果;根据所述判别网络输出的判别结果,确定所述学生网络对应的第三修正梯度;所述基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正,包括:基于所述第一修正梯度、所述第二修正梯度及所述第三修正梯度,对所述学生网络进行修正。3.如权利要求2所述的方法,其中,所述根据所述判别网络输出的判别结果,确定所述学生网络对应的第三修正梯度,包括:在所述判别网络输出的判别结果指示所述第一特征图为所述学生网络生成的情况下,根据所述判别结果的置信度确定所述第三修正梯度。4.如权利要求2所述的方法,其中,所述根据所述判别网络输出的判别结果,确定所述学生网络对应的第三修正梯度,包括:在所述判别网络输出的判别结果指示所述第一特征图为所述教师网络生成的情况下,确定所述第三修正梯度为零。5.如权利要求1

4任一所述的方法,其中,所述根据所述第一特征图与所述第二特征图间的差异,确定所述学生网络对应的第一修正梯度,包括:根据所述第一特征图中每个像素点与所述第二特征图中对应的像素点间的欧氏距离,确定所述学生网络对应的第一修正梯度。6.如权利要求1

4任一所述的方法,其中,所述基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正,包括:基于所述第二修正梯度及所述第一修正梯度,对所述学生网络中的第i层的网络参数进行修正;基于所述第二修正梯度,对所述学生网络中的其余各层的网络参数进行修正。7.如权利要求1

4任一所述的方法,其中,所述基于所述第一修正梯度及所述第二修正梯度,对所述学生网络进行修正,包括:基于所述第二修正梯度及所述第一修正梯度,对所述学生网络中的第i层及所述第i层至输入层间的各个网络层的网络参数进行修正;基于所述第二修正梯度,对所述学生网络中的第i+1层至输出层间的各个网络层的网
络参数进行修正。8.一种图处理网络模型的训练装置,包括:第一获取模块,用于将训练样本分别输入学生网络和教师网络,以获取所述学生网络的第i层输出的第一特征图和所述教师网络的第i层输...

【专利技术属性】
技术研发人员:杨喜鹏蒋旻悦谭啸孙昊
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1