一种基于半监督生成对抗网络的通用图像分类方法和装置制造方法及图纸

技术编号:20076015 阅读:18 留言:0更新日期:2019-01-15 00:54
本发明专利技术公开了一种基于半监督生成对抗网络的通用图像分类方法和装置,涉及图像分类技术,所述方法包括:步骤1:训练得到深度卷积生成对抗网络DCGAN,所述DCGAN包括生成网络和判别网络,所述判别网络包括依次连接的卷积神经网络和Softmax多分类器;步骤2:将待分类的图像输入至所述卷积神经网络,得到图像特征;步骤3:将得到的图像特征输入所述Softmax多分类器,得到分类结果。本发明专利技术中,由于判别网络是在DCGAN中训练得到,故提高了判别网络的泛化性能和分类准确率,通过该判别网络得到图像特征后,再结合Softmax多分类器,能够极大提高图像分类的准确率。

A Universal Image Classification Method and Device Based on Semi-supervised Generating Countermeasure Network

The invention discloses a general image classification method and device based on semi-supervised generation antagonism network, which relates to image classification technology. The method includes: step 1: training to obtain deep convolution generation antagonism network DCGAN, which includes generation network and discrimination network, and the discrimination network includes convolution neural network connected sequentially and Softmax multi-classifier. The classified image is input to the convolution neural network to obtain the image features; Step 3: Input the obtained image features into the Softmax multi-classifier to obtain the classification results. In the present invention, since the discriminant network is trained in DCGAN, the generalization performance and classification accuracy of the discriminant network are improved. After obtaining image features through the discriminant network and combining with the Softmax multi-classifier, the accuracy of image classification can be greatly improved.

【技术实现步骤摘要】
一种基于半监督生成对抗网络的通用图像分类方法和装置
本专利技术属于深度学习领域,涉及图像分类技术,具体涉及一种基于半监督生成对抗网络的通用图像分类方法和装置。
技术介绍
图像分类技术,是计算机视觉和模式识别领域的主要分支之一。图像分类就是根据各自在图像信息中反映的不同特征,将图像中不同类别的目标区分开来的图像处理方法。图像分类是利用计算机对图像进行定量分析,把图像或图像中的某个区域划为若干类别中的某一种,以代替人的视觉判读。随着大数据时代的到来,数据在计算机视觉的任务中越来越明显,在数据足够多的情况下,可以使用基础的模型、算法,比如KNN(k-NearestNeighbor,最近邻分类),NaiveBayes(朴素贝叶斯)就能得到比较好的结果。图像分类在很多领域都得到广泛应用,包括安防领域的人脸识别、行为检测等,交通领域的车辆识别、车牌检测等,以及互联网领域的图像检索等。本专利技术关注的是深度学习领域的图像分类,即利用卷积神经网络进行图像分类。早期的图像分类目标主要集中在一些较为简单的任务,例如,形状分类、OCR(OpticalCharacterRecognition,光学字符识别)等。其中,在OCR中,手写数字识别是一个广泛研究的课题,与此相关的最著名的数据库是MNIST(MixedNationalInstituteofStandardsandTechnology)数据库,MNIST是手写数字识别领域的标准测试数据集,大小是60,000,一共包含10类阿拉伯数字,每一类有5,000张图像进行训练,1,000张图像进行测试。MNIST的图像大小为28*28,即784维,该数据集中的图像手写数字,存在较大的形变。CIFAR-10数据集也是使用比较广泛的一个数据集,该数据集共有60,000张彩色图像,这些图像大小是32*32,分为10个类,每类6,000张图。CIFAR-10数据集中有50,000张用于训练,构成了5个训练批,每一批10,000张图;另外10,000张用于测试,单独构成一批。测试批的数据是取自10类中的每一类,每一类随机取1,000张,剩下的图像就随机排列组成了训练批。本专利技术使用了MNIST和CIFAR-10两类数据集验证方法的有效性。半监督学习旨在缓解标签样本数量不够时的小样本问题,半监督学习方法大致可以分为四种:(1)生成式模型,通过预测条件概率密度来得到未标记样本的标签;(2)基于图的方法,利用标记样本和非标记样本来构建图模型;(3)低密度分离,旨在将边界放置在几乎没有标签或无标签数据的区域;(4)基于包装的方法,这种方法利用有监督方法并且迭代地标记未标记的数据。
技术实现思路
本专利技术要解决的技术问题是提供一种基于半监督生成对抗网络的通用图像分类方法和装置,以提高图像分类的准确率。为解决上述技术问题,本专利技术提供技术方案如下:一方面,提供一种基于半监督生成对抗网络的通用图像分类方法,包括:步骤1:训练得到深度卷积生成对抗网络DCGAN,所述DCGAN包括生成网络和判别网络,所述判别网络包括依次连接的卷积神经网络和Softmax多分类器;步骤2:将待分类的图像输入至所述卷积神经网络,得到图像特征;步骤3:将得到的图像特征输入所述Softmax多分类器,得到分类结果。进一步的,所述步骤1包括:步骤10:在DCGAN框架下,将随机噪声输入生成网络,根据真实数据的分布拟合噪声的分布,得到和真实数据分布相近的分布,生成无标记样本图像;步骤11:将真实图像作为有标记样本图像与所述无标记样本图像一起输入判别网络,以供判别网络学习两种类型的数据分布;步骤12:根据判别网络对输入的样本图像的真假性判别结果,计算梯度,固定生成网络的参数,通过反向传播算法更新判别网络的节点的权重系数;步骤13:根据判别网络的反馈结果,固定判别网络参数,通过反向传播算法更新生成网络的节点的权重系数;步骤14:判断判别网络的分类准确率误差是否小于预设阈值,如果否,则转至步骤10,重复进行训练,如果是,则训练结束,得到训练完成的DCGAN。进一步的,所述步骤10中,所述随机噪声服从高斯分布。进一步的,所述步骤11中,所述真实图像经过高斯滤波预处理后作为所述有标记样本图像。进一步的,所述步骤14中,所述分类准确率误差为使用对数似然函数的损失函数,计算公式如下:loss=-lnap,其中,ap代表类别p对应的分类概率。另一方面,提供一种基于半监督生成对抗网络的通用图像分类装置,包括:网络训练模块,用于训练得到深度卷积生成对抗网络DCGAN,所述DCGAN包括生成网络和判别网络,所述判别网络包括依次连接的卷积神经网络和Softmax多分类器;图像特征获取模块,用于将待分类的图像输入至所述卷积神经网络,得到图像特征;分类模块,用于将得到的图像特征输入所述Softmax多分类器,得到分类结果。进一步的,所述网络训练模块包括:第一输入子模块,用于在DCGAN框架下,将随机噪声输入生成网络,根据真实数据的分布拟合噪声的分布,得到和真实数据分布相近的分布,生成无标记样本图像;第二输入子模块,用于将真实图像作为有标记样本图像与所述无标记样本图像一起输入判别网络,以供判别网络学习两种类型的数据分布;第一更新子模块,用于根据判别网络对输入的样本图像的真假性判别结果,计算梯度,固定生成网络的参数,通过反向传播算法更新判别网络的节点的权重系数;第二更新子模块,用于根据判别网络的反馈结果,固定判别网络参数,通过反向传播算法更新生成网络的节点的权重系数;判断子模块,用于判断判别网络的分类准确率误差是否小于预设阈值,如果否,则转至第一输入子模块,重复进行训练,如果是,则训练结束,得到训练完成的DCGAN。进一步的,所述第一输入子模块中,所述随机噪声服从高斯分布。进一步的,所述第二输入子模块中,所述真实图像经过高斯滤波预处理后作为所述有标记样本图像。进一步的,所述判断子模块中,所述分类准确率误差为使用对数似然函数的损失函数,计算公式如下:loss=-lnap,其中,ap代表类别p对应的分类概率。本专利技术具有以下有益效果:上述方案中,由于判别网络是在DCGAN中训练得到,故提高了判别网络的泛化性能和分类准确率,通过该判别网络得到图像特征后,再结合Softmax多分类器,能够极大提高图像分类的准确率。附图说明图1为本专利技术的基于半监督生成对抗网络的通用图像分类方法的流程示意图;图2是本专利技术方法中在MNIST数据集上由生成网络生成的图像,其中(a)图是生成网络使用半监督方法生成的样本图像,(b)图是生成网络使用无监督方法生成的样本图像;图3是本专利技术方法中在CIFAR-10数据集上由生成网络生成的图像,其中(a)图是生成网络使用半监督方法生成的样本图像,(b)图是生成网络使用无监督方法生成的样本图像;图4为本专利技术的基于半监督生成对抗网络的通用图像分类装置的结构示意图。具体实施方式为使本专利技术要解决的技术问题、技术方案和优点更加清楚,下面将结合附图及具体实施例进行详细描述。一方面,本专利技术提供一种基于半监督生成对抗网络的通用图像分类方法,如图1所示,包括:步骤101:训练得到深度卷积生成对抗网络(DeepConvolutionalGenerativeAdversarialNet本文档来自技高网
...

【技术保护点】
1.一种基于半监督生成对抗网络的通用图像分类方法,其特征在于,包括:步骤1:训练得到深度卷积生成对抗网络DCGAN,所述DCGAN包括生成网络和判别网络,所述判别网络包括依次连接的卷积神经网络和Softmax多分类器;步骤2:将待分类的图像输入至所述卷积神经网络,得到图像特征;步骤3:将得到的图像特征输入所述Softmax多分类器,得到分类结果。

【技术特征摘要】
1.一种基于半监督生成对抗网络的通用图像分类方法,其特征在于,包括:步骤1:训练得到深度卷积生成对抗网络DCGAN,所述DCGAN包括生成网络和判别网络,所述判别网络包括依次连接的卷积神经网络和Softmax多分类器;步骤2:将待分类的图像输入至所述卷积神经网络,得到图像特征;步骤3:将得到的图像特征输入所述Softmax多分类器,得到分类结果。2.根据权利要求1所述的方法,其特征在于,所述步骤1包括:步骤10:在DCGAN框架下,将随机噪声输入生成网络,根据真实数据的分布拟合噪声的分布,得到和真实数据分布相近的分布,生成无标记样本图像;步骤11:将真实图像作为有标记样本图像与所述无标记样本图像一起输入判别网络,以供判别网络学习两种类型的数据分布;步骤12:根据判别网络对输入的样本图像的真假性判别结果,计算梯度,固定生成网络的参数,通过反向传播算法更新判别网络的节点的权重系数;步骤13:根据判别网络的反馈结果,固定判别网络参数,通过反向传播算法更新生成网络的节点的权重系数;步骤14:判断判别网络的分类准确率误差是否小于预设阈值,如果否,则转至步骤10,重复进行训练,如果是,则训练结束,得到训练完成的DCGAN。3.根据权利要求2所述的方法,其特征在于,所述步骤10中,所述随机噪声服从高斯分布。4.根据权利要求3所述的方法,其特征在于,所述步骤11中,所述真实图像经过高斯滤波预处理后作为所述有标记样本图像。5.根据权利要求1-4中任一所述的方法,其特征在于,所述步骤14中,所述分类准确率误差为使用对数似然函数的损失函数,计算公式如下:loss=-lnap,其中,ap代表类别p对应的分类概率。6.一种基于半监督生成对抗网络的通用图像分类装置,其...

【专利技术属性】
技术研发人员:苏磊凌平张万才
申请(专利权)人:国网上海市电力公司华东电力试验研究院有限公司
类型:发明
国别省市:上海,31

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1