超网络的训练方法和装置制造方法及图纸

技术编号:25346961 阅读:16 留言:0更新日期:2020-08-21 17:06
本申请涉及人工智能领域,公开了超网络的训练方法和装置。该方法包括:获取样本数据;将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作直到当前超网络的各个特征提取层保留的连接数均为1;响应于确定裁剪完成的超网络未达到预设的收敛条件,基于样本数据对裁剪完成的超网络进行训练;裁剪训练操作包括:对当前超网络进行训练;利用训练后的当前超网络对图像数据进行特征提取得到第一特征图;对训练后的超网络中的特征提取层分别进行N次裁剪,利用裁剪后的超网络分别对图像数据进行特征提取得到N组第二特征图;确定与第一特征图的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络。该方法提升了超网络的准确性。

【技术实现步骤摘要】
超网络的训练方法和装置
本公开的实施例涉及计算机
,具体涉及人工智能
,尤其涉及超网络的训练方法和装置。
技术介绍
随着人工智能技术和数据存储技术的发展,深度神经网络在许多领域的任务中取得了重要的成果。神经网络模型的结构对其性能具有重要的影响,传统的神经网络模型结构的设计依赖于专家知识。NAS(NeuralArchitectureSearch,网络结构搜索)是通过评估不同的网络结构的性能来自动搜索出最优的网络结果的技术。NAS需要独立评估每个子网络的性能,因此搜索效率较低。为了解决NAS搜索效率的问题,可以训练一个包含多个完整神经网络结构的超网络,超网络所有网络结构共享超网络的参数。然而,由于超网络中所有网络结构是共存的,训练超网络过程中不同的网络结构的性能存在着互斥的问题。超网络的训练虽然解决了网络结构搜索效率的问题,但是基于超网络训练得到的子网络的性能与独立训练的子网络的性能之间存在差异,从而导致基于超网络不能准确地搜索出最优的模型结构。
技术实现思路
本公开的实施例提供了超网络的训练方法和装置、电子设备以及计算机可读存储介质。根据第一方面,提供了一种超网络的训练方法,包括:获取样本数据;将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;响应于确定裁剪完成的超网络未达到预设的收敛条件,基于样本数据对裁剪完成的超网络进行训练;其中,裁剪训练操作包括:基于样本数据对当前超网络进行训练;利用训练后的当前超网络对待处理的图像数据进行特征提取,得到第一特征图;针对训练后的当前超网络的每一个特征提取层,确定特征提取层包含的连接数量N,对训练后的超网络中的特征提取层分别进行N次裁剪,得到N个裁剪后的超网络,并利用N个裁剪后的超网络分别对待处理的图像数据进行特征提取,得到对应的N组第二特征图,其中每一次裁剪中分别裁剪特征提取层包含的N个连接中的一个;确定N组第二特征图中与第一特征图之间的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络;响应于确定新的当前超网络中的特征提取层的连接数量大于1,执行下一次裁剪训练操作。根据第二方面,提供了一种超网络的训练装置,包括:获取单元,被配置为获取样本数据;第一训练单元,被配置为将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;第二训练单元,被配置为响应于确定裁剪完成的超网络未达到预设的收敛条件,基于样本数据对裁剪完成的超网络进行训练;其中,第一训练单元包括:训练子单元,被配置为执行裁剪训练操作中的如下步骤:基于样本数据对当前超网络进行训练;特征提取子单元,被配置为执行裁剪训练操作中的如下步骤:利用训练后的当前超网络对待处理的图像数据进行特征提取,得到第一特征图;裁剪子单元,被配置为执行裁剪训练操作中的如下步骤:针对训练后的当前超网络的每一个特征提取层,确定特征提取层包含的连接数量N,对训练后的超网络中的特征提取层分别进行N次裁剪,得到N个裁剪后的超网络,并利用N个裁剪后的超网络分别对待处理的图像数据进行特征提取,得到对应的N组第二特征图,其中每一次裁剪中分别裁剪特征提取层包含的N个连接中的一个;确定子单元,被配置为执行裁剪训练操作中的如下步骤:确定N组第二特征图中与第一特征图之间的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络;迭代子单元,被配置为执行裁剪训练操作中的如下步骤:响应于确定新的当前超网络中的特征提取层的连接数量大于1,执行下一次裁剪训练操作。根据第三方面,提供了一种电子设备,包括:至少一个处理器;以及与至少一个处理器通信连接的存储器;其中,存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行第一方面提供的超网络的训练方法。根据第四方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,计算机指令用于使计算机执行第一方面提供的超网络的训练方法。根据本申请的方法提升了超网络的准确性。应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。附图说明通过阅读参照以下附图所作的对非限制性实施例所作的详细描述,本公开的其它特征、目的和优点将会变得更明显:图1是本公开的超网络的训练方法的一个实施例的流程图;图2示出了超网络的特征提取层的裁剪后结构的示意图;图3是本公开的超网络的训练方法的另一个实施例的流程图;图4是本公开的超网络的训练装置的一个实施例的结构示意图;图5是用来实现本公开的实施例的超网络的训练方法的电子设备的框图。具体实施方式下面结合附图和实施例对本公开作进一步的详细说明。可以理解的是,此处所描述的具体实施例仅仅用于解释相关专利技术,而非对该专利技术的限定。另外还需要说明的是,为了便于描述,附图中仅示出了与有关专利技术相关的部分。需要说明的是,在不冲突的情况下,本公开中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本公开。以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。本公开的方法或装置可以应用于终端设备或服务器,或者可以应用于包括终端设备、网络和服务器的系统架构。其中,网络用以在终端设备和服务器之间提供通信链路的介质,可以包括各种连接类型,例如有线、无线通信链路或者光纤电缆等等。终端设备可以是用户端设备,其上可以安装有各种客户端应用。例如,图像处理类应用、搜索应用、语音服务类应用等。终端设备可以是硬件,也可以是软件。当终端设备为硬件时,可以是各种电子设备,包括但不限于智能手机、平板电脑、电子书阅读器、膝上型便携计算机和台式计算机等等。当终端设备为软件时,可以安装在上述所列举的电子设备中。其可以实现成多个软件或软件模块,也可以实现成单个软件或软件模块。在此不做具体限定。服务器可以是运行各种服务的服务器,例如运行基于图像、视频、语音、文本、数字信号等数据的目标检测与识别、文本或语音识别、信号转换等服务的服务器。服务器可以获取各种媒体数据作为深度学习任务的训练样本数据,如图像数据、音频数据、文本数据等。服务器还可以根据具体的深度学习任务,利用训练样本数据训练超网络,并从超网络中采样出子网络进行评估,根据各子网络的评估结果确定用于执行上述深度学习任务的神经网络模型的结构和参数。服务器还可以将确定出的神经网络模型的结构和参数等数据发送至终端设备。终端设备根据接收到的数据在本地部署并运行神经网络模型,以执行相应的深本文档来自技高网...

【技术保护点】
1.一种超网络的训练方法,包括:/n获取样本数据;/n将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到所述当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;/n响应于确定所述裁剪完成的超网络未达到预设的收敛条件,基于样本数据对所述裁剪完成的超网络进行训练;/n其中,所述裁剪训练操作包括:/n基于样本数据对当前超网络进行训练;/n利用训练后的当前超网络对待处理的图像数据进行特征提取,得到第一特征图;/n针对训练后的所述当前超网络的每一个特征提取层,确定所述特征提取层包含的连接数量N,对所述训练后的超网络中的所述特征提取层分别进行N次裁剪,得到N个裁剪后的超网络,并利用N个裁剪后的超网络分别对待处理的图像数据进行特征提取,得到对应的N组第二特征图,其中每一次裁剪中分别裁剪所述特征提取层包含的N个连接中的一个;/n确定所述N组第二特征图中与所述第一特征图之间的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络;/n响应于确定所述新的当前超网络中的所述特征提取层的连接数量大于1,执行下一次裁剪训练操作。/n

【技术特征摘要】
1.一种超网络的训练方法,包括:
获取样本数据;
将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到所述当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;
响应于确定所述裁剪完成的超网络未达到预设的收敛条件,基于样本数据对所述裁剪完成的超网络进行训练;
其中,所述裁剪训练操作包括:
基于样本数据对当前超网络进行训练;
利用训练后的当前超网络对待处理的图像数据进行特征提取,得到第一特征图;
针对训练后的所述当前超网络的每一个特征提取层,确定所述特征提取层包含的连接数量N,对所述训练后的超网络中的所述特征提取层分别进行N次裁剪,得到N个裁剪后的超网络,并利用N个裁剪后的超网络分别对待处理的图像数据进行特征提取,得到对应的N组第二特征图,其中每一次裁剪中分别裁剪所述特征提取层包含的N个连接中的一个;
确定所述N组第二特征图中与所述第一特征图之间的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络;
响应于确定所述新的当前超网络中的所述特征提取层的连接数量大于1,执行下一次裁剪训练操作。


2.根据权利要求1所述的方法,其中,所述方法还包括:
响应于确定所述裁剪完成的超网络达到预设的收敛条件且所述裁剪完成的超网络中各个特征提取层的连接数均为1,基于所述裁剪完成的超网络的各个特征提取层构建目标神经网络模型。


3.根据权利要求1所述的方法,其中,所述裁剪训练操作还包括:
响应于确定所述新的当前超网络中的所述特征提取层的连接数量为1,保存所述新的当前超网络中的所述特征提取层中的连接对应的权重参数;以及
所述基于样本数据对所述裁剪完成的超网络进行训练,包括:
将所述裁剪完成的超网络中的各个特征提取层中的连接对应的权重参数作为所述裁剪完成的超网络中的初始权重参数,基于所述样本数据对所述裁剪完成的超网络中的权重参数进行迭代更新。


4.根据权利要求1所述的方法,其中,所述裁剪训练操作还包括:
确定所述N组第二特征图中与所述第一特征图之间的距离最小的一组第二特征图为目标第二特征图,保存所述目标第二特征图对应的裁剪后的超网络中被裁剪掉的一个连接对应的权重参数。


5.根据权利要求4所述的方法,其中,所述方法还包括:
根据保存的所述待训练的超网络的各个特征提取层中被裁剪掉的连接对应的权重参数、以及所述裁剪完成的超网络中各个特征提取层被保留的连接在所述裁剪完成的超网络训练完成后对应的权重参数,生成训练完成的超网络。


6.一种超网络的训练方法装置,包括:
获取单元,被配置为获取样本数据;
第一训练单元,被配置为将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到所述当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;
第二训练单元,被配置为响应于确定所述裁剪完成的超网络未达到预设的收敛条件,基于样本数据对所述裁剪完成的超网络进行训练;
其中,所述第一训练单元包括:
训练子单元,被配置为执行所述裁剪训练操作中的如...

【专利技术属性】
技术研发人员:希滕张刚温圣召
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1