当前位置: 首页 > 专利查询>中南大学专利>正文

一种基于WGAN-GP和过采样的不平衡学习方法技术

技术编号:21225365 阅读:44 留言:0更新日期:2019-05-29 06:06
本发明专利技术公开了一种基于WGAN‑GP和过采样的不平衡学习方法,包括:生成器网络,由三层全连接网络组成并且每一层的输出都应用了Batch Normalization(BN)归一化去防止梯度弥散,所述生成器网络由于最终需要产生特定标签的样本,将高斯随机噪声z和样本标签yi∈{0,1}组合成新的向量作为输入,输出样本标签yi对应的生成样本x;同样由三层全连接网络组成的判别器网络,所述判别器网络的输入为样本标签yi以及真实样本或生成样本,输出结果为判断样本是真实的或者为被生成的。本发明专利技术能大大减少噪声样本的产生,优化生成少数样本的质量,最终提升原始不平衡数据集上训练的分类器的泛化性能。

An Unbalanced Learning Method Based on WGAN-GP and Oversampling

The invention discloses an unbalanced learning method based on WGAN GP and oversampling, which comprises a generator network consisting of three layers of fully connected network and the output of each layer is normalized by Batch Normalization (BN) to prevent gradient dispersion. The generator network generates a sample of a specific tag, and sets Gauss random noise Z and sample tag Yi <{0,1} due to the ultimate need to generate a sample of a specific tag. The new vector is synthesized as input, and the output sample label Yi corresponds to the generated sample X. The discriminator network is also composed of three layers of fully connected network. The input of the discriminator network is the sample label Yi and the real sample or generated sample, and the output result is to judge that the sample is real or generated. The method can greatly reduce the generation of noise samples, optimize the quality of generating a few samples, and ultimately improve the generalization performance of the classifier trained on the original unbalanced data set.

【技术实现步骤摘要】
一种基于WGAN-GP和过采样的不平衡学习方法
本专利技术涉及计算机
,尤其涉及一种基于WGAN-GP和过采样的不平衡学习方法。
技术介绍
在不平衡数据集上进行建模学习对于学术界和产业界来说都是一个非常重要的问题。不平衡的学习问题可以定义为来自二分类或多分类数据集的学习问题,其中一类的实例数(成为多数类)明显高于其余类的实例数(成为少数类)。不平衡率(IR),定义为多数类与每一个少数类之间的比例,这个比例对于不同的应用是不同的,对于二分类问题,不平衡率在100到100000之间。不平衡问题是多种实际应用场景普遍存在的,例如:异常检测、故障诊断、电子邮件归档、人脸识别、欺诈检测。标准的机器学习方法在不平衡数据集中表现不佳,因为他们会更偏向于多数类,以准确率为导向的学习过程将缺乏对少数类的关注,因此很容易将少数类误判成多数类。然而对人们有价值的数据通常是少数类,这些少数类的错分代价往往非常大,有效提高少数类的分类精度具有实际社会、经济、技术价值。目前,人们也提出许多方法来解决不平衡分类问题。Galar等人系统地将现有工作分为四类:数据级方法,成本敏感法,算法级方法和集成学习方法。数据级方法通过基于特征空间中最近邻居的过采样或欠采样将不平衡数据转换为平衡数据;成本敏感学习调整传统方法的初始权重,以更多地关注少数人;算法级方法直接修改现有的学习算法,以减轻对多数对象的偏见;最后,集合方法将采样或其他技术与传统的集合方法(如bagging或boosting)相结合,这种方法对于困难的数据具有高度竞争性和鲁棒性综上,不平衡数据集的传统机器学习方法的难点在于:少数类相比多数类的绝对数量劣势,使它不能完整学习少数类的分布。如果数据集中存在子集群,则聚类是学习此类特征的典型可行方法。
技术实现思路
本专利技术的重点是过采样技术,这将为少数类生成人工数据,从而将不平衡数据转换为平衡数据。标准的过采样方法受到合成少数类的过采样技术(SMOTE)算法的启发,沿着连接少数类样本的线生成合成样本。数据生成过程的直接方法是使用捕获实际数据分布的生成模型。生成对抗网络(GAN)是一种使用神经网络创建生成模型的最新方法。条件生成对抗网络(cGAN)通过加入类别信息的训练来扩展GAN模型。在专利技术中,考虑到现实生活中的大部分场景下,数据更一般的形式是结构化存储的,所以我们将针对性的提出一种改进的cGAN,针对结构化数据中的离散category特征,我们会利用Embedding层将其转换成稠密的嵌入向量表示,并且为了解决原始cGAN的训练稳定性问题,修改了原模型的生成器和判别器目标函数。最终的生成器用于为少数类创建人工数据,即生成器对应于过采样算法。本专利技术旨在至少解决现有技术中存在的技术问题。为此,本专利技术公开了一种基于WGAN-GP和过采样的不平衡学习方法,包括:生成器网络,由三层全连接网络组成并且每一层的输出都应用了BatchNormalization(BN)归一化去防止梯度弥散,所述生成器网络由于最终需要产生特定标签的样本,将高斯随机噪声z和样本标签yi∈{0,1}组合成新的向量作为输入,输出样本标签yi对应的生成样本x;同样由三层全连接网络组成的判别器网络,所述判别器网络的输入为样本标签yi以及真实样本或生成样本,输出结果为判断样本是真实的或者为被生成的。更进一步地,包括以下步骤:S1.获取原始数据;S2.将原始数据,分别将少数类样本随机采样,并且与预设量的多数类样本组成多个不同IR率的数据集;S3.依次训练每一个二分类的不平衡数据集其中n表示当前数据集的样本数,yi∈{0,1},y=1表示的是少数类样本的标签,ConditionalWGAN-GP中生成器的输入为样本标签和输入的随机噪声,输出一个生成样本;同时判别器的输入为真实样本与对应标签,或生成样本与其输入标签,输出判断时真实样本还是生成器生成样本,然后用生成器和判别器的损失函数计算损失,利用梯度下降优化模型参数,得到一个可以产生以假乱真样本的生成器;S4.对步骤S3中得到的生成器,向生成器输入随机噪声和少数类标签(z,y=1),生成多个少数类样本,知道使不平衡数据集变成平衡为止;S5.将得到的平衡数据集用五种不同的分类器进行训练并得到在测试集上的预测结果,将原始数据集利用其它几种对比过采样算法进行过采样操作得到平衡数据集,并同样用五种分类器得到测试集上的预测结果。更进一步地,判别器的损失函数,如下:其中,D(·)、G(·)分别表示判别器和生成器模型的函数表达式,Pr表示真实样本的数据分布,Pg表生成器生成样本的数据分布,指的是判别器D(x)的梯度,L(·)表示损失函数。更进一步地,和生成器的损失函数如下:其中,D(·)、G(·)分别表示判别器和生成器模型的函数表达式,Pr表示真实样本的数据分布,Pg表生成器生成样本的数据分布,其中指的是判别器D(x)的梯度。本专利技术是基于对抗神经网络(GAN)一种改进应用,利用本专利技术在不平衡数据上训练得到一个可以生成指定标签样本的生成器,从而实现对不平衡数据的过采样使其变成平衡数据集。本专利技术方法包含一个生成器一个判别器,且都是简单的单隐藏层全连接网络,可以用任何编程语言实现部署。在本专利技术中,将WGAN-GP对抗神经网络模型应用到了不平衡,并且取得了比较好的效果;针对结构化数据中存在大量类别特征的特点,本专利技术将每个真实样本的类别特征利用Embeddinglayer映射到高维稠密空间中,然后再连同其他数值特征一起输入到判别器进行训练,能有效提高模型的性能。针对实际问题的具体应用方式如下:S1.获取原始数据(原始数据可以为任一二分类问题存在两个标签的数据,分别为多数类和少数类);S2.在计算机上实现本专利技术的ConditionalWGAN-GP网络;S2.将本专利技术的ConditionalWGAN-GP中生成器的输入为样本标签和输入的随机噪声,输出一个生成样本;同时判别器的输入为真实样本与对应标签,或生成样本与其输入标签,输出判断时真实样本还是生成器生成样本。然后用生成器和判别器的损失函数计算损失,利用梯度下降优化模型参数,得到一个可以产生以假乱真样本的生成器。S4.对步骤S3中得到的生成器,向生成器输入随机噪声和少数类标签(z,y=1),生成多个少数类样本,知道使不平衡数据集变成平衡为止;S5.通过上述步骤得到平衡数据集之后,就可以拿来训练普通的分类器,这样得到的分类效果往往比直接使用原始不平衡数据的效果要好许多。综上所述,本专利技术的有益效果:我们受对抗神经网络的优点启发,基于对WGAN-GP的研究,提出了一种用于不平衡数据集上的过采样方法。相较于传统的过采样方法简单的利用原始数据的统计特性去人工生成少数类样本,本专利技术通过生成器和判别器的对抗训练过程,可以通过真实数据的本质特征,刻画出样本的数据分布特征,学习从随机高斯噪声到不同类别原始数据的数据分布映射,从而自动生成符合少数类样本真实分布的样本,能大大减少噪声样本的产生,优化生成样本的质量,最终提升原始不平衡数据集上训练的分类器的泛化性能。附图说明从以下结合附图的描述可以进一步理解本专利技术。图中的部件不一定按比例绘制,而是将重点放在示出实施例的原理上。在图中,在不同的视图中,相同的附图标记指定对应的本文档来自技高网
...

【技术保护点】
1.一种基于WGAN‑GP和过采样的不平衡学习方法,其特征在于,包括:生成器网络,由三层全连接网络组成并且每一层的输出都应用了Batch Normalization(一种在训练神经网络时,对每批训练数据进行归一化处理的技术)归一化去防止梯度弥散,所述生成器网络由于最终需要产生特定标签的样本,将高斯随机噪声z和样本标签yi∈{0,1}组合成新的向量作为输入,输出样本标签yi对应的生成样本x;同样由三层全连接网络组成的判别器网络,所述判别器网络的输入为样本标签yi以及真实样本或生成样本,输出结果为判断样本是真实的或者为被生成的。

【技术特征摘要】
1.一种基于WGAN-GP和过采样的不平衡学习方法,其特征在于,包括:生成器网络,由三层全连接网络组成并且每一层的输出都应用了BatchNormalization(一种在训练神经网络时,对每批训练数据进行归一化处理的技术)归一化去防止梯度弥散,所述生成器网络由于最终需要产生特定标签的样本,将高斯随机噪声z和样本标签yi∈{0,1}组合成新的向量作为输入,输出样本标签yi对应的生成样本x;同样由三层全连接网络组成的判别器网络,所述判别器网络的输入为样本标签yi以及真实样本或生成样本,输出结果为判断样本是真实的或者为被生成的。2.如权利要求1所述的一种基于WGAN-GP和过采样的不平衡学习方法,其特征在于,包括以下步骤:S1.获取原始数据;S2.将原始数据,分别将少数类样本随机采样,并且与预设量的多数类样本组成多个不同IR率的数据集;S3.依次训练每一个二分类的不平衡数据集其中n表示当前数据集的样本数,yi∈{0,1},y=1表示的是少数类样本的标签,ConditionalWGAN-GP中生成器的输入为样本标签和输入的随机噪声,输出一个生成样本;同时判别器的输入为真实样本与对应标签,或生成样本与其输入标签,输出判断时真实样本...

【专利技术属性】
技术研发人员:邓晓衡黄戎沈海澜
申请(专利权)人:中南大学
类型:发明
国别省市:湖南,43

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1