一种基于迁移学习及生成对抗网络的图像数据扩增方法技术

技术编号:37597906 阅读:7 留言:0更新日期:2023-05-18 11:46
本发明专利技术公开一种基于迁移学习及生成对抗网络的图像数据扩增方法,包括以下步骤:利用GAN的第一图像生成器收集输入的隐向量与输出图像的样本对数据集,称为第一数据集;把第一数据集中的每一对数据融合,得到第二数据集;构建第一编码器,其输入、输出格式分别与第一数据集的图像、隐向量相同;构建第一向量判别器,其输入格式与第二数据集的数据相同;利用第一数据集以及第二数据集,对第一编码器以及第一向量判别器进行训练;收集第三数据集;对第三数据集预处理获得第四数据集;将第四数据集的图像依次输入训练后的第一编码器,得到相同数量的隐向量数据集,称为第五数据集;对GAN网络模型进行训练;利用训练后GAN网络模型进行数据扩增。行数据扩增。行数据扩增。

【技术实现步骤摘要】
一种基于迁移学习及生成对抗网络的图像数据扩增方法


[0001]本专利技术涉及图像生成对抗网络
,尤其涉及一种基于迁移学习及生成对抗网络的图像数据扩增方法。

技术介绍

[0002]近年来,深度学习飞速发展,为图像分类、目标检测、图像分割等领域带来了巨大的性能提升。然而,利用深度学习来完成这些任务常常需要依赖于大规模的训练数据集,例如常见的ImageNet、COCO数据集分别有约1400万、33万张图像。在许多现实场景中,受限于时间、成本、采集难度等因素,这样大规模的数据难以被采集。因此,很多数据扩增(即增加可用于训练的图像数量)的方法被提出,例如平移、翻转、裁剪、加噪声、颜色抖动等简单的像素层面的操作,以及利用生成对抗网络(Generative Adversarial Network,GAN,其主要特征为具备两个部分:生成器以及判别器,此二部分相互进行对抗训练来提升生成器的性能)来生成全新样本等方法。其中,像素层面的数据扩增方法带来的提升相对有限,因为其并未能有效带来更多、更全面的信息;而利用GAN在数据量样本有限的情况下,现有方法难以产出较高质量的图像。如何有效地对小数据集进行数据扩增,是一个业界广泛关注的问题。
[0003]现有的方法中,存在一种思路,其利用一个网络(其作者称其为Miner)对输入隐向量进行变换,即相当于将原本的输入空间变换为另一空间,从而使得输入隐向量更利于预训练生成器产生与目标数据集相像的图像;并将Miner与预训练GAN在目标数据集上进行进一步联合训练,来提升对于目标图像的生成能力。遗憾的是,这样的方式比较简单,难以获得令人满意的性能。本专利技术基于同样的出发点,公开一种更有效的对产生输入隐向量并用于在目标数据集上训练GAN,最终有效对目标数据集进行扩增的方式。

技术实现思路

[0004]本专利技术的目的是利用预训练GAN(生成对抗网络,Generative Adversarial Network)以及已有的少量目标数据样本,借助所公开的方法更有效地对输入空间进行变换,提供给生成器更好的输入隐向量,从而提高对于目标数据集图像的生成能力,进而达到对目标数据集进行有效数据扩增的目标。
[0005]本专利技术至少通过如下技术方案之一实现。
[0006]一种基于迁移学习及生成对抗网络的图像数据扩增方法,包括以下步骤:
[0007]构建第一GAN,其生成器和判别器分别称为第一图像生成器和第一图像判别器;
[0008]利用第一图像生成器收集输入的隐向量与输出图像的样本对数据集,称为第一数据集;
[0009]把第一数据集中的每一对数据融合,得到的数据集称为第二数据集;
[0010]构建第一编码器,其输入、输出格式分别与第一数据集的图像、隐向量相同;
[0011]构建第一向量判别器,其输入格式与第二数据集的数据相同;
[0012]利用第一数据集以及第二数据集,对第一编码器以及第一向量判别器进行训练;
[0013]收集目标数据集,作为第三数据集;对第三数据集预处理获得第四数据集;将第四数据集的图像依次输入训练后的第一编码器,得到相同数量的隐向量数据集,称为第五数据集;
[0014]使用第四数据集以及第五数据集对GAN网络模型进行训练;利用训练后GAN网络模型进行数据扩增,从而有效达到图像数据扩增的效果。
[0015]进一步地,所述融合包括以下步骤:
[0016]获取在数据集上预训练过的、具备从图像数据提取特征能力的网络,并去除该网络末端的若干层,只保留前端的特征提取部分,称保留下来的部分为第一特征提取器;
[0017]对于第一数据集中的样本对,将样本对中的图像输入第一特征提取器,得到该图像对应的特征图;将此特征图的元素重新排列为一维向量,并与样本对中的隐向量进行拼接,即完成该样本对的融合。
[0018]进一步地,对第三数据集进行预处理的步骤包括:
[0019]对数据进行缩放、裁切,使数据尺寸与第一图像生成器输出图像的尺寸相吻合;
[0020]将数据进行归一化,使其取值范围在

1至1之间,范围包含

1以及1;
[0021]为每一个数据各创建一份副本,并将此副本水平翻转,将翻转后的副本也加入到数据集。
[0022]进一步地,第一编码器的训练方式为监督训练与对抗训练相结合,第一向量判别器的训练方式为对抗训练;
[0023]第一编码器包括监督训练的损失函数和对抗训练的损失函数,对抗训练的损失函数包括将第一编码器输出的向量以及对应的输入到第一编码器的图像,经过融合的图像

隐向量融合数据,被第一向量判别器归属于真融合数据;
[0024]第一向量判别器包括抗训练方面的损失函数,将第二数据集中的数据归属于真融合数据,并且将第一编码器得到的融合数据归属于假融合数据;归属是指第一向量判别器接收到某样本对后,输出的值与预设的期望输出值接近;
[0025]训练过程中,第一编码器和第一向量判别器根据上述损失函数进行交替优化。
[0026]进一步地,用于对第一编码器进行训练的监督训练的损失函数为L2损失函数,具体形式为:
[0027][0028]其中,m为训练过程中迭代批次的样本数量,i为该批次中样本的序号;z_01_pred为第一数据集中图像输入第一编码器后得到的预测隐向量;z_01为第一数据集中相应图像对应的隐向量。
[0029]进一步地,用于对第一编码器以及第一向量判别器进行对抗训练的损失函数具体形式为:
[0030][0031][0032]其中,L
enc
为作用于第一编码器的对抗损失函数,L
disc_vec
为作用于第一向量判别器的对抗损失函数;m为训练过程中迭代批次的样本数量,i、j、k为批次中样本的序号;img_01为第一数据集中的图像;z_01为第一数据集中的隐向量;E()表示第一向量判别器的前向传播过程;F()表示融合方式;D
vec
()表示第一向量判别器的前向传播过程。
[0033]进一步地,所述GAN网络模型包括修饰器和第二GAN,第四数据集上进行GAN网络模型训练的步骤包括:
[0034]构建修饰器,其输入与输出的格式均为隐向量的格式;建立一个与第一GAN的结构相同的第二GAN,并用第一GAN的可训练参数的值对第二GAN进行初始化;第二GAN的生成器部分称为第二图像生成器,判别器部分称为第二图像判别器。
[0035]进一步地,修饰器以及第二图像生成器的前向传播过程包括以下步骤:
[0036]首先,获取修饰器的输入:从第五数据集中采样隐向量,称为z_enc;随机采样和第一图像生成器预训练时的输入同分布的隐向量,为称z_noise01;对z_enc和z_noise01求加权和,得到修饰器的输入隐向量z_inp01,即z_inp01=a*z_enc+b*z_noise01,其中,a、b为任意实数;...

【技术保护点】

【技术特征摘要】
1.一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,包括以下步骤:构建第一GAN,其生成器和判别器分别称为第一图像生成器和第一图像判别器;利用第一图像生成器收集输入的隐向量与输出图像的样本对数据集,称为第一数据集;把第一数据集中的每一对数据融合,得到的数据集称为第二数据集;构建第一编码器,其输入、输出格式分别与第一数据集的图像、隐向量相同;构建第一向量判别器,其输入格式与第二数据集的数据相同;利用第一数据集以及第二数据集,对第一编码器以及第一向量判别器进行训练;收集目标数据集,作为第三数据集;对第三数据集预处理获得第四数据集;将第四数据集的图像依次输入训练后的第一编码器,得到相同数量的隐向量数据集,称为第五数据集;使用第四数据集以及第五数据集对GAN网络模型进行训练;利用训练后GAN网络模型进行数据扩增,从而有效达到图像数据扩增的效果。2.根据权利要求1所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,所述融合包括以下步骤:获取在数据集上预训练过的、具备从图像数据提取特征能力的网络,并去除该网络末端的若干层,只保留前端的特征提取部分,称保留下来的部分为第一特征提取器;对于第一数据集中的样本对,将样本对中的图像输入第一特征提取器,得到该图像对应的特征图;将此特征图的元素重新排列为一维向量,并与样本对中的隐向量进行拼接,即完成该样本对的融合。3.根据权利要求2所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,对第三数据集进行预处理的步骤包括:对数据进行缩放、裁切,使数据尺寸与第一图像生成器输出图像的尺寸相吻合;将数据进行归一化,使其取值范围在

1至1之间,范围包含

1以及1;为每一个数据各创建一份副本,并将此副本水平翻转,将翻转后的副本也加入到数据集。4.根据权利要求3所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,第一编码器的训练方式为监督训练与对抗训练相结合,第一向量判别器的训练方式为对抗训练;第一编码器包括监督训练的损失函数和对抗训练的损失函数,对抗训练的损失函数包括将第一编码器输出的向量以及对应的输入到第一编码器的图像,经过融合的图像

隐向量融合数据,被第一向量判别器归属于真融合数据;第一向量判别器包括抗训练方面的损失函数,将第二数据集中的数据归属于真融合数据,并且将第一编码器得到的融合数据归属于假融合数据;归属是指第一向量判别器接收到某样本对后,输出的值与预设的期望输出值接近;训练过程中,第一编码器和第一向量判别器根据上述损失函数进行交替优化。5.根据权利要求4所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,用于对第一编码器进行训练的监督训练的损失函数为L2损失函数,具体形式为:
其中,m为训练过程中迭代批次的样本数量,i为该批次中样本的序号;z_01_pred为第一数据集中图像输入第一编码器后得到的预测隐向量;z_01为第一数据集中相应图像对应的隐向量。6.根据权利要求5所述的一种基于迁移学习及生成对抗网络的图像数据扩增方法,其特征在于,用于对第一编码器以及第一向量判别器进行对抗训练的损失函数具体形式为:特征在于,用于对第一编码器以及第一向量判别器进行对抗训练的损失函数具体形式为:其中,L
enc
为作用于第一编码器的对抗损失函数,L
disc_vec
为作用于第一向量判别器的对抗损失函数;m为训练过程中迭代批次的样本数量,i、j、k为批次中样本的序号;img_01为第一数据集中的图像;z_01为第一数据集中的隐向量;E()表示第一向量判别器的前向传播过程...

【专利技术属性】
技术研发人员:周智恒李志豪陶希远曹英烈杨俊怡
申请(专利权)人:华南理工大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1