【技术实现步骤摘要】
一种基于迁移学习及生成对抗网络的图像数据扩增方法
[0001]本专利技术涉及图像生成对抗网络
,尤其涉及一种基于迁移学习及生成对抗网络的图像数据扩增方法。
技术介绍
[0002]近年来,深度学习飞速发展,为图像分类、目标检测、图像分割等领域带来了巨大的性能提升。然而,利用深度学习来完成这些任务常常需要依赖于大规模的训练数据集,例如常见的ImageNet、COCO数据集分别有约1400万、33万张图像。在许多现实场景中,受限于时间、成本、采集难度等因素,这样大规模的数据难以被采集。因此,很多数据扩增(即增加可用于训练的图像数量)的方法被提出,例如平移、翻转、裁剪、加噪声、颜色抖动等简单的像素层面的操作,以及利用生成对抗网络(Generative Adversarial Network,GAN,其主要特征为具备两个部分:生成器以及判别器,此二部分相互进行对抗训练来提升生成器的性能)来生成全新样本等方法。其中,像素层面的数据扩增方法带来的提升相对有限,因为其并未能有效带来更多、更全面的信息;而利用GAN在数据量样本有限的情况下,现有方法难以产出较高质量的图像。如何有效地对小数据集进行数据扩增,是一个业界广泛关注的问题。
[0003]现有的方法中,存在一种思路,其利用一个网络(其作者称其为Miner)对输入隐向量进行变换,即相当于将原本的输入空间变换为另一空间,从而使得输入隐向量更利于预训练生成器产生与目标数据集相像的图像;并将Miner与预训练GAN在目标数据集上进行进一步联合训练,来提升对于目标图像的生成能
【技术保护点】
【技术特征摘要】
1.一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,包括以下步骤:构建第一GAN,其生成器和判别器分别称为第一图像生成器和第一图像判别器;利用第一图像生成器收集输入的隐向量与输出图像的样本对数据集,称为第一数据集;把第一数据集中的每一对数据融合,得到的数据集称为第二数据集;构建第一编码器,其输入、输出格式分别与第一数据集的图像、隐向量相同;构建第一向量判别器,其输入格式与第二数据集的数据相同;利用第一数据集以及第二数据集,对第一编码器以及第一向量判别器进行训练;收集目标数据集,作为第三数据集;对第三数据集预处理获得第四数据集;将第四数据集的图像依次输入训练后的第一编码器,得到相同数量的隐向量数据集,称为第五数据集;使用第四数据集以及第五数据集对GAN网络模型进行训练;利用训练后GAN网络模型进行数据扩增,从而有效达到图像数据扩增的效果。2.根据权利要求1所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,所述融合包括以下步骤:获取在数据集上预训练过的、具备从图像数据提取特征能力的网络,并去除该网络末端的若干层,只保留前端的特征提取部分,称保留下来的部分为第一特征提取器;对于第一数据集中的样本对,将样本对中的图像输入第一特征提取器,得到该图像对应的特征图;将此特征图的元素重新排列为一维向量,并与样本对中的隐向量进行拼接,即完成该样本对的融合。3.根据权利要求2所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,对第三数据集进行预处理的步骤包括:对数据进行缩放、裁切,使数据尺寸与第一图像生成器输出图像的尺寸相吻合;将数据进行归一化,使其取值范围在
‑
1至1之间,范围包含
‑
1以及1;为每一个数据各创建一份副本,并将此副本水平翻转,将翻转后的副本也加入到数据集。4.根据权利要求3所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,第一编码器的训练方式为监督训练与对抗训练相结合,第一向量判别器的训练方式为对抗训练;第一编码器包括监督训练的损失函数和对抗训练的损失函数,对抗训练的损失函数包括将第一编码器输出的向量以及对应的输入到第一编码器的图像,经过融合的图像
‑
隐向量融合数据,被第一向量判别器归属于真融合数据;第一向量判别器包括抗训练方面的损失函数,将第二数据集中的数据归属于真融合数据,并且将第一编码器得到的融合数据归属于假融合数据;归属是指第一向量判别器接收到某样本对后,输出的值与预设的期望输出值接近;训练过程中,第一编码器和第一向量判别器根据上述损失函数进行交替优化。5.根据权利要求4所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,用于对第一编码器进行训练的监督训练的损失函数为L2损失函数,具体形式为:
其中,m为训练过程中迭代批次的样本数量,i为该批次中样本的序号;z_01_pred为第一数据集中图像输入第一编码器后得到的预测隐向量;z_01为第一数据集中相应图像对应的隐向量。6.根据权利要求5所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,用于对第一编码器以及第一向量判别器进行对抗训练的损失函数具体形式为:特征在于,用于对第一编码器以及第一向量判别器进行对抗训练的损失函数具体形式为:其中,L
enc
为作用于第一编码器的对抗损失函数,L
disc_vec
为作用于第一向量判别器的对抗损失函数;m为训练过程中迭代批次的样本数量,i、j、k为批次中样本的序号;img_01为第一数据集中的图像;z_01为第一数据集中的隐向量;E()表示第一向量判别器的前向传播过程...
【专利技术属性】
技术研发人员:周智恒,李志豪,陶希远,曹英烈,杨俊怡,
申请(专利权)人:华南理工大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。