【技术实现步骤摘要】
一种神经网络构建方法以及装置
本申请涉及人工智能领域,尤其涉及一种神经网络构建方法以及装置。
技术介绍
在人工智能领域中,神经网络尤其是深度神经网络近年来在计算视觉领域取得了巨大的成功。受益于计算力的增加与越来越多不同的组成元件的提出,神经网络结构朝着越来越复杂的方向发展。在构建神经网络时,首先人工设计给定的超网络,该超网络可以包括多个节点,每个节点之间通过一种或者多种基础运算连接。然后通过bi-level优化方法来优化超网络,分别在训练数据集和验证数据集上优化权重参数和结构参数更新直到搜索结束,得到一个构建单元。之后对该构建单元进行离散化基础运算,即删除构建单元中的部分基础运算,保留结构参数最大的基础运算,且每个节点保留结构参数最大的两条边,得到的子网络作为最后的输出。最终得到的输出网络由该子网络堆叠得到。然而,超网络中每个单元中的节点之间,仅保留结构参数最大的基础运算,得到构建单元,然后由最终得到的构建单元堆叠得到最终的输出网络。因此,在构建神经网络时,在结构参数和权重参数更新完成之后,再进行离散化基础运算,最终得到的输出网络仅由相同的构建单元堆叠得到,仅能适用简单的场景。
技术实现思路
本申请提供一种神经网络构建方法以及装置,用于准确高效地构建目标神经网络,构建出的输出网络输出的准确度高,还可以应用于不同的应用场景中,泛化能力强。有鉴于此,第一方面,本申请提供一种神经网络构建方法,包括:获取初始超网络,初始超网络包括多个节点,多个节点之间通过至少一种基础运算连接,其中,多个节 ...
【技术保护点】
1.一种神经网络构建方法,其特征在于,包括:/n获取初始超网络,所述初始超网络包括多个节点,所述多个节点之间通过至少一种基础运算连接,其中,所述多个节点中的任意每两个节点包括第一节点和第二节点,每个基础运算的输出作为所述第二节点的输入,所述第一节点的输出经所述第一节点和所述第二节点之间连接的每个基础运算进行运算后得到的输出为所述每个基础运算的输出;/n通过预设的训练集对所述超网络进行至少一次迭代更新,得到至少一个输出网络,其中,在所述至少一次迭代更新中的任意一次迭代更新中,通过所述预设的训练集更新上一次迭代更新得到的输出网络中每个基础运算对应的结构参数,得到更新后的网络,删除所述更新后的网络中低于预设值的结构参数对应的基础运算,得到当前次迭代的输出网络,所述结构参数包括所述第一节点和所述第二节点之间连接的每个基础运算的输出占所述第二节点的输入的权重。/n
【技术特征摘要】
1.一种神经网络构建方法,其特征在于,包括:
获取初始超网络,所述初始超网络包括多个节点,所述多个节点之间通过至少一种基础运算连接,其中,所述多个节点中的任意每两个节点包括第一节点和第二节点,每个基础运算的输出作为所述第二节点的输入,所述第一节点的输出经所述第一节点和所述第二节点之间连接的每个基础运算进行运算后得到的输出为所述每个基础运算的输出;
通过预设的训练集对所述超网络进行至少一次迭代更新,得到至少一个输出网络,其中,在所述至少一次迭代更新中的任意一次迭代更新中,通过所述预设的训练集更新上一次迭代更新得到的输出网络中每个基础运算对应的结构参数,得到更新后的网络,删除所述更新后的网络中低于预设值的结构参数对应的基础运算,得到当前次迭代的输出网络,所述结构参数包括所述第一节点和所述第二节点之间连接的每个基础运算的输出占所述第二节点的输入的权重。
2.根据权利要求1所述的方法,其特征在于,在所述任意一次迭代更新中,所述通过所述预设的训练集更新上一次迭代更新得到的输出网络的每个基础运算对应的结构参数,得到更新后的网络,包括:
通过所述预设的数据集,结合损失函数更新上一次迭代更新得到的输出网络的每个基础运算对应的权重参数和结构参数,得到所述更新后的网络,所述权重参数为所述多个节点之间连接的每个基础运算内使用的参数,所述损失函数中包括针对结构参数的约束函数;
在得到当前次迭代的输出网络之后,所述方法还包括:
判断所述前次迭代的输出网络是否符合预设的剪枝条件,并根据判断结果调整所述约束函数在所述损失函数中所占的权重。
3.根据权利要求2所述的方法,其特征在于,所述根据判断结果调整所述约束函数在所述损失函数中所占的权重,包括:
若所述当前次迭代的输出网络符合所述预设的剪枝条件,则增加所述约束函数在所述损失函数中所占的权重;
若所述当前次迭代的输出网络不符合所述预设的剪枝条件,则减少所述约束函数在所述损失函数中所占的权重。
4.根据权利要求2或3所述的方法,其特征在于,所述预设的剪枝条件包括以下一项或者多项:所述前次迭代更新过程中删除的基础运算的数量大于预设数量、所述前次迭代更新过程中删除的基础运算占所述更新后的网络中的基础运算的比例大于预设比例、所述前次迭代更新过程中删除基础运算后得到的网络的分类精度的下降值超过预设精度值、或者,所述前次迭代更新过程中删除基础运算后得到的网络的每秒浮点运算次数FLOPS减少的值大于第一预设次数。
5.根据权利要求1-4中任一项所述的方法,其特征在于,在所述任意一次迭代更新中,所述通过所述预设的训练集更新上一次迭代更新得到的输出网络的权重参数和结构参数,包括:
将所述预设的数据集作为所述上一次迭代更新得到的输出网络的输入,得到所述上一次迭代更新得到的输出网络的输出数据;
通过损失函数以及所述上一次迭代更新得到的输出网络的输出数据,得到所述上一次迭代更新得到的输出网络的损失值;
根据所述损失值对所述上一次迭代更新得到的输出网络的权重参数更新,以及对所述上一次迭代更新得到的输出网络中的每个构建单元的结构参数分别进行更新,得到所述更新后的网络,所述更新后的网络中的每个构建单元的结构参数不完全相同。
6.根据权利要求1-5中任一项所述的方法,其特征在于,在所述任意一次迭代更新中,所述删除所述更新后的网络中低于预设值的结构参数对应的基础运算,得到当前次迭代的输出网络,包括:
删除所述更新后的网络中低于预设值的结构参数对应的基础运算,得到当前次迭代的网络;
若所述当前次迭代的网络符合预设的稳定条件,则将所述当前次迭代的网络作为当前次迭代的输出网络。
7.根据权利要求6所述的方法,其特征在于,所述预设的稳定条件包括以下一项或者多项:
所述当前次迭代的网络的输出的精度大于第一阈值,或者,所述当前次迭代的网络的输出的平均精度大于第二阈值,或者,所述当前次迭代的网络损失值不大于第三阈值,或者,所述当前次迭代的网络的推理时长不大于第四阈值,或者,所述当前次迭代的网络的FLOPS不大于第五阈值,所述平均精度为对所述当前次迭代的网络进行多次评估得到的多个精度的平均值,所述推理时长为从所述当前次迭代的网络根据输入得到输出结果的时长。
8.根据权利要求1-7中任一项所述的方法,其特征在于,所述删除所述更新后的网络中结构参数低于预设值的结构参数对应的基础运算,包括:
获取所述更新后的网络中结构参数最小的n个结构参数对应的基础运算,得到第一基础运算集合,所述n为正数;
从所述第一基础运算集合中获取结构参数低于b的结构参数对应的基础运算,以及获取所述更新后的网络中结构参数低于a的基础运算,得到第二基础运算集合,所述a、b为正数,且b>a。
9.根据权利要求1-8中任一项所述的方法,其特征在于,
所述通过预设的训练集对所述初始超网络进行迭代更新的次数为根据预设的终止条件确定,其中,当所述至少一次迭代更新中任意一次迭代更新得到的输出网络满足所述预设的终止条件,则终止所述迭代更新。
10.根据权利要求9所述的方法,其特征在于,所述预设的终止条件,包括以下一项或者多项:
所述输出网络的计算量低于预设计算量、所述输出网络的FLOPS小于第二预设次数、或者针对所述超网络进行迭代更新的时长大于预设时长。
11.一种神经网络构建装置,其特征在于,包括:
获取单元,用于获取初始超网络,所述初始超网络包括...
【专利技术属性】
技术研发人员:毕恺峰,魏龙辉,陈鑫,谢凌曦,田奇,
申请(专利权)人:华为技术有限公司,
类型:发明
国别省市:广东;44
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。