【技术实现步骤摘要】
超网络的训练方法和装置
本公开的实施例涉及计算机
,具体涉及人工智能
,尤其涉及超网络的训练方法和装置。
技术介绍
随着人工智能技术和数据存储技术的发展,深度神经网络在许多领域的任务中取得了重要的成果。神经网络模型的结构对其性能具有重要的影响,传统的神经网络模型结构的设计依赖于专家知识。NAS(NeuralArchitectureSearch,网络结构搜索)是通过评估不同的网络结构的性能来自动搜索出最优的网络结果的技术。NAS需要独立评估每个子网络的性能,因此搜索效率较低。为了解决NAS搜索效率的问题,可以训练一个包含多个完整神经网络结构的超网络,超网络所有网络结构共享超网络的参数。然而,由于超网络中所有网络结构是共存的,训练超网络过程中不同的网络结构的性能存在着互斥的问题。超网络的训练虽然解决了网络结构搜索效率的问题,但是基于超网络训练得到的子网络的性能与独立训练的子网络的性能之间存在差异,从而导致基于超网络不能准确地搜索出最优的模型结构。
技术实现思路
本公开的实施例提供了超网络的训练方法和装置、电子设备以及计算机可读存储介质。根据第一方面,提供了一种超网络的训练方法,包括:获取样本数据;将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;响应于确定裁剪完成的超网络未达到预设的收敛条件,基于样本数据对裁剪完成的超网络进行训练;其中,裁剪训练操作包括:基于样本数据对当前超网络进行训练; ...
【技术保护点】
1.一种超网络的训练方法,包括:/n获取样本数据;/n将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到所述当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;/n响应于确定所述裁剪完成的超网络未达到预设的收敛条件,基于样本数据对所述裁剪完成的超网络进行训练;/n其中,所述裁剪训练操作包括:/n基于样本数据对当前超网络进行训练;/n利用训练后的当前超网络对待处理的图像数据进行特征提取,得到第一特征图;/n针对训练后的所述当前超网络的每一个特征提取层,确定所述特征提取层包含的连接数量N,对所述训练后的超网络中的所述特征提取层分别进行N次裁剪,得到N个裁剪后的超网络,并利用N个裁剪后的超网络分别对待处理的图像数据进行特征提取,得到对应的N组第二特征图,其中每一次裁剪中分别裁剪所述特征提取层包含的N个连接中的一个;/n确定所述N组第二特征图中与所述第一特征图之间的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络;/n响应于确定所述新的当前超网络中的所述特征提取层的连接数量大于1,执行下一次裁剪训练操作。/n
【技术特征摘要】
1.一种超网络的训练方法,包括:
获取样本数据;
将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到所述当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;
响应于确定所述裁剪完成的超网络未达到预设的收敛条件,基于样本数据对所述裁剪完成的超网络进行训练;
其中,所述裁剪训练操作包括:
基于样本数据对当前超网络进行训练;
利用训练后的当前超网络对待处理的图像数据进行特征提取,得到第一特征图;
针对训练后的所述当前超网络的每一个特征提取层,确定所述特征提取层包含的连接数量N,对所述训练后的超网络中的所述特征提取层分别进行N次裁剪,得到N个裁剪后的超网络,并利用N个裁剪后的超网络分别对待处理的图像数据进行特征提取,得到对应的N组第二特征图,其中每一次裁剪中分别裁剪所述特征提取层包含的N个连接中的一个;
确定所述N组第二特征图中与所述第一特征图之间的距离最小的一组第二特征图对应的裁剪后的超网络为新的当前超网络;
响应于确定所述新的当前超网络中的所述特征提取层的连接数量大于1,执行下一次裁剪训练操作。
2.根据权利要求1所述的方法,其中,所述方法还包括:
响应于确定所述裁剪完成的超网络达到预设的收敛条件且所述裁剪完成的超网络中各个特征提取层的连接数均为1,基于所述裁剪完成的超网络的各个特征提取层构建目标神经网络模型。
3.根据权利要求1所述的方法,其中,所述裁剪训练操作还包括:
响应于确定所述新的当前超网络中的所述特征提取层的连接数量为1,保存所述新的当前超网络中的所述特征提取层中的连接对应的权重参数;以及
所述基于样本数据对所述裁剪完成的超网络进行训练,包括:
将所述裁剪完成的超网络中的各个特征提取层中的连接对应的权重参数作为所述裁剪完成的超网络中的初始权重参数,基于所述样本数据对所述裁剪完成的超网络中的权重参数进行迭代更新。
4.根据权利要求1所述的方法,其中,所述裁剪训练操作还包括:
确定所述N组第二特征图中与所述第一特征图之间的距离最小的一组第二特征图为目标第二特征图,保存所述目标第二特征图对应的裁剪后的超网络中被裁剪掉的一个连接对应的权重参数。
5.根据权利要求4所述的方法,其中,所述方法还包括:
根据保存的所述待训练的超网络的各个特征提取层中被裁剪掉的连接对应的权重参数、以及所述裁剪完成的超网络中各个特征提取层被保留的连接在所述裁剪完成的超网络训练完成后对应的权重参数,生成训练完成的超网络。
6.一种超网络的训练方法装置,包括:
获取单元,被配置为获取样本数据;
第一训练单元,被配置为将待训练的超网络作为初始的当前超网络,迭代执行多次裁剪训练操作,直到所述当前超网络的各个特征提取层保留的连接数均为1,得到裁剪完成的超网络;
第二训练单元,被配置为响应于确定所述裁剪完成的超网络未达到预设的收敛条件,基于样本数据对所述裁剪完成的超网络进行训练;
其中,所述第一训练单元包括:
训练子单元,被配置为执行所述裁剪训练操作中的如...
【专利技术属性】
技术研发人员:希滕,张刚,温圣召,
申请(专利权)人:北京百度网讯科技有限公司,
类型:发明
国别省市:北京;11
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。