一种RDN超分辨网络的训练方法及图像生成方法技术

技术编号:27938761 阅读:35 留言:0更新日期:2021-04-02 14:19
本发明专利技术公开了一种RDN超分辨网络的训练方法及图形生成方法,其步骤包括:1)将低分辨率样本图像输入到生成器,生成高分辨率图像;2)将生成的高分辨率图像作为假数据,计算该假数据与对应标签数据之间的损失值;3)提取该假数据的特征与对应标签数据的特征,然后计算特征之间的损失值;4)将生成的高分辨率图像及其多个下采样数据作为假数据,根据该假数据与对应真实数据计算生成器中损失函数的损失值;5)将生成的高分辨率图像及其多个下采样数据作为假数据,根据该假数据与对应真实数据计算判别器中损失函数的损失值,进行判别器参数更新;6)将步骤2)~4)所得损失值加权在一起,作为生成器的损失,进行生成器的参数更新。

【技术实现步骤摘要】
一种RDN超分辨网络的训练方法及图像生成方法
本专利技术属于超分辨领域,涉及改善一种与深度学习相结合的数据插值方法,可用于残差密集网络RDN(ResidualDenseNetwork)超分辨网络的训练,具体涉及一种新的RDN超分辨网络训练方法及图像生成方法。
技术介绍
大量的电子图像应用领域,人们经常期望得到高分辨率(简称HR)图像。高分辨率意味着图像中的像素密度高,能够提供更多的细节,而这些细节在许多实际应用中不可或缺。例如,高分辨率医疗图像对于医生做出正确的诊断是非常有帮助的;使用高分辨率卫星图像就很容易从相似物中区别相似的对象;如果能够提供高分辨的图像,计算机视觉中的模式识别的性能就会大大提高。自从上世纪七十年代以来,电荷耦合器件(CCD)、CMOS图像传感器已被广泛用来捕获数字图像。尽管对于大多数的图像应用来说这些传感器是合适的,但是当前的分辨率水平和消费价格不能满足今后的需求。例如,人们希望得到一个便宜的高分辨率数码相机/便携式摄像机,或者期望其价格逐渐下降;科学家通常需要一个非常高的接近35毫米模拟胶片的分辨率水平,这样在放大一个图像的时候就不会有可见的瑕疵。因此,寻找一种增强当前分辨率水平的方法是非常必须的。图像插值,即增加单幅图像的尺寸可以增强图像的分辨率。传统的图像插值有最近邻插值法,双线性插值法,三次内插法等。都在一定程度上完成了图像超分辨率任务。但是传统方法没有解决一个基本问题,就是当图片处于低分辨率的时候,图片本身缺少足够的高频信息。传统方法经过线性差值之后,并没有补全足够的高频信息,这使得图片在放大之后面临,边缘模糊,图像不清晰的问题。深度学习提供了另一种解决思路,通过神经网络强大的学习能力,以及用很深的网络来学习低分辨率图片所能提供的语义信息和边缘信息,通过非线性的方式生成对应的高分辨率图片。深度学习网络在网络特别深的时候,会发生梯度消失问题,RDN网络在加深网络的同时吸取了desnet网络和resnet网络,将网络的中间信息处理层加深到了100多层,并取得了优秀的成绩,但是仍然存在信息填补不能满足需要的问题。本专利技术聚焦RDN在现实应用中所存在的问题,进行了改进,通过使用生成模型的方法来重新训练RDN网络。
技术实现思路
针对现有技术中存在的技术问题,本专利技术的目的在于提供一种RDN超分辨网络的训练方法及图像生成方法,利用训练后的RDN超分辨网络对低分辨图像进行处理可以生成高分辨率图像。本专利技术将RDN超分辨网络放在GAN的框架上进行训练,有助于网络学到真实世界的数据分布。而GAN的框架一共分为两个部分,生成器和判别器。生成器部分就是RDN网络结构不变。在每个卷积层都添加普归一化,以促进网络训练的稳定性。而判别器由本专利技术设计。实验证明,如果判别器的判别能力更强的话,最后生成器生成的结果也会变得更好。为了提升判别起的能力,我们决定采用结合LSGAN的多尺度判别器。传统GAN有模型崩塌以及生成器经常学习不到有用的东西的问题。模型崩塌的后果是生成模型无法生成多种样本。但是我们的任务是超分辨率,目的就是尽可能的接近原图,所以不必考虑模型崩塌的问题。至于生成器学习不到有用的东西的问题,究其原因是因为生成器和判别器学习能力不对等,经常出现判别器学习速度过快,损失很快降到0,从未导致的生成器无法进行梯度更新。LSGAN已经解决了这个问题,将损失函数变为MSE损失,变成与真实标签的距离。因此无论判别器的损失是否为0,生成器都可以进行更新。提升网络的感受野,可以提升判别器的判别能力。最直接的做法就是使用更大的网络卷积核。例如从原来的3*3卷积核提升为5*5或7*7的卷积核。但是这么做的后果就是会使得判别器的网络变大,需要训练更多的参数,整个模型占用的显存更多,因此我们使用另一种方法,将生成器生成的图片进行不同程度的下采样,那么相对应的判别器的感受野也就获得了不同程度的扩大。对每张图像用bicubic插值方法,进行下采样2倍和4倍。在使用GAN架构的同时,我们还引用了perceptualloss作为辅助训练的损失函数,用于计算损失值Lvgg,即perceptualloss=LVGG。Perceptualloss是生成器输出的图片与真实图片数据在经过预训练过得VGG网络的不同层后,输出的特征的曼哈顿距离(L1距离)的加权和。若要逼近生成器生成图与原图的距离,不应只是从像素的程度逼近,更应从特征的角度逼近。特别地,如果可以的话,应该逼近不同特征提取的网络所提取的特征的角度进行逼近。VGGj是指VGG网络第j层的输出,c、w、h是指输出矩阵的信道数、列和行;yg是指网络生成的图片,y是指标签数据即真实数据。本专利技术的技术方案为:一种RDN超分辨网络的训练方法,其特征在于,将RDN超分辨网络放在GAN的框架上进行训练,GAN的框架包括生成器和判别器,生成器为RDN超分辨网络;其步骤包括:1)将训练所用的每对低分辨率样本图像和高分辨率样本图像进行归一化处理;其中低分辨率图像作为输入,高分辨率图像作为标签数据;2)将低分辨率样本图像输入到生成器,提取该低分辨率样本图像的浅层信息;3)提取到的浅层信息输入到生成器的RDB层,提取该低分辨率样本图像的边缘信息;4)生成器将各RDB层提取的信息在信道维度上叠加在一起,然后进行卷积处理、上采样后重构生成高分辨率图像;5)将步骤4)生成的高分辨率图像作为假数据,计算该假数据与对应标签数据之间的曼哈顿距离,得到损失值Lsr;6)利用VGG网络提取该假数据的特征与对应标签数据的特征,然后计算该假数据的特征与对应标签数据的特征之间的损失值Lvgg;7)将步骤4)生成的高分辨率图像及其多个下采样数据作为假数据,将输入的该低分辨率样本图像对应的标签数据及其多个下采样后的数据作为真实数据,根据该假数据与该真实数据计算生成器中损失函数的损失值;8)将步骤4)生成的高分辨率图像及其多个下采样数据作为假数据,将输入的该低分辨率样本图像对应的标签数据及其多个下采样后的数据作为真实数据,根据该假数据与该真实数据计算判别器中损失函数的损失值;9)将步骤5)、6)、7)所得损失值以加权的方式叠加在一起,作为生成器的损失,进行生成器的参数更新;将步骤8)得到的损失值作为判别器的损失,进行判别器参数更新;重复步骤1)~8)的处理,直至达到收敛条件。进一步的,所述判别器采用结合LSGAN的多尺度判别器。进一步的,所述RDN超分辨网络采用kaiming初始化方式进行初始化。进一步的,对每一RDB层的输出进行上采样并将前RDB层的上采样输出作为后一RDB层的条件。进一步的,以步骤5)得到的损失值Lsr为主,步骤6)、7)所得损失值为辅,以加权的方式叠加在一起,作为生成器的损失值;即Lsr权重大于步骤6)、7)所得损失值的权重。进一步的,采用公式损失函数计算得到损失值Lvgg;其中,VGGj是VGG网络第j层的输出,c、w、h是指VGG网络输出本文档来自技高网
...

【技术保护点】
1.一种RDN超分辨网络的训练方法,其特征在于,将RDN超分辨网络放在GAN的框架上进行训练,GAN的框架包括生成器和判别器,生成器为RDN超分辨网络;其步骤包括:/n1)将训练所用的每对低分辨率样本图像和高分辨率样本图像进行归一化处理;其中低分辨率图像作为输入,高分辨率图像作为标签数据;/n2)将低分辨率样本图像输入到生成器,提取该低分辨率样本图像的浅层信息;/n3)提取到的浅层信息输入到生成器的RDB层,提取该低分辨率样本图像的边缘信息;/n4)生成器将各RDB层提取的信息在信道维度上叠加在一起,然后进行卷积处理、上采样后重构生成高分辨率图像;/n5)将步骤4)生成的高分辨率图像作为假数据,计算该假数据与对应标签数据之间的曼哈顿距离,得到损失值L

【技术特征摘要】
1.一种RDN超分辨网络的训练方法,其特征在于,将RDN超分辨网络放在GAN的框架上进行训练,GAN的框架包括生成器和判别器,生成器为RDN超分辨网络;其步骤包括:
1)将训练所用的每对低分辨率样本图像和高分辨率样本图像进行归一化处理;其中低分辨率图像作为输入,高分辨率图像作为标签数据;
2)将低分辨率样本图像输入到生成器,提取该低分辨率样本图像的浅层信息;
3)提取到的浅层信息输入到生成器的RDB层,提取该低分辨率样本图像的边缘信息;
4)生成器将各RDB层提取的信息在信道维度上叠加在一起,然后进行卷积处理、上采样后重构生成高分辨率图像;
5)将步骤4)生成的高分辨率图像作为假数据,计算该假数据与对应标签数据之间的曼哈顿距离,得到损失值Lsr;
6)利用VGG网络提取该假数据的特征与对应标签数据的特征,然后计算该假数据的特征与对应标签数据的特征之间的损失值Lvgg;
7)将步骤4)生成的高分辨率图像及其多个下采样数据作为假数据,将输入的该低分辨率样本图像对应的标签数据及其多个下采样后的数据作为真实数据,根据该假数据与该真实数据计算生成器中损失函数的损失值;
8)将步骤4)生成的高分辨率图像及其多个下采样数据作为假数据,将输入的该低分辨率样本图像对应的标签数据及其多个下采样后的数据作为真实数据,根据该假数据与该真实数据计算判别器中损失函数的损失值;
9)将步骤5)、6)、7)所得损失值以加权的方式叠加在一起,作...

【专利技术属性】
技术研发人员:刘凯刘冠群王雷王鑫刘泽艺郭晓博何原野
申请(专利权)人:中国科学院信息工程研究所
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1