当前位置: 首页 > 专利查询>浙江大学专利>正文

一种基于类扩张学习的神经网络模型优化方法技术

技术编号:23344937 阅读:77 留言:0更新日期:2020-02-15 04:24
本发明专利技术公开了一种基于类扩张学习的神经网络模型优化方法。具体包括步骤如下:获取用于训练的含有多种类别的样本的数据集,并定义算法目标;用通用模型提取数据集中每类图片的特征,并根据每类特征的分布情况评估各个类别易错的程度;将易错程度最高的几类数据加入训练池,并用训练池中的数据优化神经网络;优化完成后,将剩余易错程度最高的几类数据加入训练池,扩张训练池中的类别,并在上一次训练得到的神经网络基础上用训练池进一步优化神经网络;不断对训练池进行类扩张,直至整个数据集进入训练池,得到最终的优化的神经网络模型。本发明专利技术适用于监督学习中的基于多类别数据集的神经网络模型优化,面对各类复杂的情况具有较佳的效果和鲁棒性。

A neural network model optimization method based on class expansion learning

【技术实现步骤摘要】
一种基于类扩张学习的神经网络模型优化方法
本专利技术属于计算机视觉领域,特别地涉及一种基于类扩张学习的神经网络模型优化方法。
技术介绍
神经网络模型的优化方法是人工智能的底层技术,常作为高层视觉任务的基础,例如物体检测,目标识别,语意分割等。然而,受到计算机计算资源和内存资源的限制,目前的神经网络模型的优化方法依赖批随机梯度下降。这种方法是一种迭代式的,批层次的学习模型,每次的训练无法利用全局的数据,只能利用其中一批数据。由于每次训练的数据通常分布在极为稀疏和分散的空间上,神经网络模型的优化难度很大,并且在优化的同时会受大部分简单数据的影响,忽略少量复杂数据的信息。基于人类的认知学理论,目前课程学习和自步学习逐渐用于解决这种问题。现有的学习方法主要采用的是深度学习框架,输入一个数据集,通过特定的评判指标选出部分数据,然后在这部分数据上训练。在不断的迭代过程中,选出的部分数据会越来越多,直至包含整个数据集,从而达到渐进式地优化神经网络的效果。然而,这类优化方法细节多,实现成本高,复现困难;另一方面,这类方法往往是针对特定任务设计特定评判指本文档来自技高网...

【技术保护点】
1.一种基于类扩张学习的神经网络模型优化方法,其特征在于,以类扩张的模式优化神经网络,包括以下步骤:/nS1、获取用于训练的含有多种类别的图片样本的数据集,并定义算法目标;/nS2、用通用模型提取数据集中每类图片的特征,并根据每类特征的分布情况评估各个类别易错的程度;/nS3、将整个数据集中的易错程度最高的若干类数据加入预先置空的训练池,并用训练池中的数据优化神经网络;/nS4、上一轮优化完成后,将未加入训练池的剩余数据集中易错程度最高的若干类数据继续加入训练池,扩张训练池中的类别,并在上一轮训练得到的神经网络的基础上用扩张后的训练池进一步优化神经网络;/nS5、不断重复步骤S4对训练池进行类...

【技术特征摘要】
1.一种基于类扩张学习的神经网络模型优化方法,其特征在于,以类扩张的模式优化神经网络,包括以下步骤:
S1、获取用于训练的含有多种类别的图片样本的数据集,并定义算法目标;
S2、用通用模型提取数据集中每类图片的特征,并根据每类特征的分布情况评估各个类别易错的程度;
S3、将整个数据集中的易错程度最高的若干类数据加入预先置空的训练池,并用训练池中的数据优化神经网络;
S4、上一轮优化完成后,将未加入训练池的剩余数据集中易错程度最高的若干类数据继续加入训练池,扩张训练池中的类别,并在上一轮训练得到的神经网络的基础上用扩张后的训练池进一步优化神经网络;
S5、不断重复步骤S4对训练池进行类扩张和神经网络优化,直至整个数据集都被加入训练池并完成最后一轮神经网络优化,得到最终优化后的神经网络模型。


2.如权利要求1所述的基于类扩张学习的神经网络模型优化方法,其特征在于,步骤S1的具体实现步骤包括:
S11:获取包含M个类别的多类别数据集D:
D=C1∪C2...∪CM
其中,Cm表示第m个类别的数据,m=1,2,…,m;
每个类别数据Cm中包含Nm个图片样本x以及他们对应的标签y:



其中,表示第m个类别的数据Cm中第i个图片样本,ym表示第m个类别的数据Cm的标签,i∈{1,2,...,Nm};
S12:定义的算法目标为:通过优化loss函数l(·,·)得到神经网络模型f(·;θ)在数据集D上的最优参数θ*:
θ*=argminθ∑(x,y)∈Dl(f(x;θ),y)。


3.如权利要求2所述的基于类扩张学习的神经网络模型优化方法,其特征在于,步骤S2具体实现步骤包括:
S21、用一个通用模型g(·)提取数据集中每类图片中每幅图片的特征:



其中,表示第m个类别中图片样本x的特征;
S22、计算出每个类别中所有图片的特征的均值:



其中,um表示第m个类别中所有图片样本的特征均值;
S...

【专利技术属性】
技术研发人员:汪慧朱文武赵涵斌李玺
申请(专利权)人:浙江大学
类型:发明
国别省市:浙江;33

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1