模型训练方法、设备和存储介质技术

技术编号:37165594 阅读:8 留言:0更新日期:2023-04-20 22:38
本公开提供了一种模型训练方法、设备和存储介质,涉及深度学习、机器学习等人工智能技术领域。具体实现方案为:获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型;获取目标任务对应的样本图像集合;针对样本图像集合中的各个样本图像,确定源GAN模型的生成器在生成样本图像时所使用的目标噪声变量;确定目标噪声变量所服从的数据分布;根据数据分布和样本图像集合对目标GAN模型进行训练,由此,基于样本图像集合在源GAN模型中所学习到的数据分布来对目标GAN模型进行训练,可更好地利用源GAN模型的信息,实现对源GAN模型的信息的继承以及目GAN标模型的自适应调整,避免目标GAN模型出现过拟合,提高目标GAN模型的泛化能力。化能力。化能力。

【技术实现步骤摘要】
模型训练方法、设备和存储介质


[0001]本公开涉及计算机
,具体涉及深度学习、机器学习等人工智能
,尤其涉及模型训练方法、设备和存储介质。

技术介绍

[0002]目前,通常采用成式对抗网络(Generative Adversarial Network,GAN)模型来生成图像。其中,通常训练GAN模型通常需要依赖巨大的训练数据,然而,很多实际任务中通常只有非常有限的样本,例如,罕见的物体、特殊风格的图像等。
[0003]相关技术中,在目标任务所对应的样本量较少的情况下,通常采用基于大数据训练所得到的源GAN模型的网络参数对目标任务所对应的目标GAN模型进行初始化,并基于目标任务对应的样本图像数据对目标GAN模型进行训练。然而,上述方式训练得到的目标GAN模型的泛化能力较差,容易出现过拟合现象。

技术实现思路

[0004]本公开提供了一种用于模型训练方法、设备和存储介质。
[0005]根据本公开的一方面,提供了一种模型训练方法,包括:获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型;获取目标任务对应的样本图像集合;针对所述样本图像集合中的各个样本图像,确定所述源GAN模型的生成器在生成所述样本图像时所使用的目标噪声变量;确定所述目标噪声变量所服从的数据分布;根据所述数据分布和所述样本图像集合对所述目标GAN模型进行训练。
[0006]根据本公开的另一方面,提供了一种模型训练装置,包括:第一获取模块,用于获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型;第二获取模块,用于获取目标任务对应的样本图像集合;第一确定模块,用于针对所述样本图像集合中的各个样本图像,确定所述源GAN模型的生成器在生成所述样本图像时所使用的目标噪声变量;第二确定模块,用于确定所述目标噪声变量所服从的数据分布;训练模块,用于根据所述数据分布和所述样本图像集合对所述目标GAN模型进行训练。
[0007]根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开的模型训练方法。
[0008]根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行本公开实施例公开的模型训练方法。
[0009]根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现本公开的模型训练方法。
[0010]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0011]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0012]图1是根据本公开第一实施例的示意图;
[0013]图2是根据本公开第一实施例的示意图;
[0014]图3是根据本公开第三实施例的示意图;
[0015]图4是根据本公开第四实施例的示意图;
[0016]图5是根据本公开第五实施例的示意图;
[0017]图6是根据本公开第六实施例的示意图;
[0018]图7是根据本公开一个实施例中的模型训练方法的示例图;
[0019]图8是根据本公开第七实施例的示意图;
[0020]图9是根据本公开第八实施例的示意图;
[0021]图10是用来实现本公开实施例的模型训练方法的电子设备的框图。
具体实施方式
[0022]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0023]下面参考附图描述本公开实施例的模型训练方法、设备和存储介质。
[0024]图1是根据本公开第一实施例的示意图。
[0025]如图1所示,该模型训练方法可以包括:
[0026]步骤101,获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型。
[0027]其中,需要说明的是,上述模型训练方法的执行主体为模型训练装置,该模型训练装置可以由软件和/或硬件的方式实现,该实施例中的模型训练装置可以为电子设备,或者,可以配置在电子设备中。
[0028]其中,本示例实施例中电子设备可以包括但不限于终端设备和服务器等设备,该实施例对电子设备不作限定。
[0029]其中,本示例中的源GAN模型是基于大量样本图像训练而得到的。
[0030]可以理解的是,本示例中的目标GAN模型的模型结构以及模型参数均是与源GAN模型相同的。
[0031]步骤102,获取目标任务对应的样本图像集合。
[0032]其中,本示例中的样本图像集合中样本图像的数量是有限的。
[0033]其中,本示例中的样本图像用于对目标模型进行训练。
[0034]在一些示例中,样本图像集合中各个样本图像所对应的类型是相同的。
[0035]步骤103,针对样本图像集合中的各个样本图像,确定源GAN模型的生成器在生成样本图像时所使用的目标噪声变量。
[0036]在一些示例性的实施方式中,针对各个样本图像,可基于预先保存的样本图像和对应的噪声变量之间的对应关系中,获取源GAN模型的生成器在生成样本图像时所使用的目标噪声变量。
[0037]其中,对应关系中的噪声变量是源GAN模型的生成器在生成样本图像时所使用的噪声变量。
[0038]步骤104,确定目标噪声变量所服从的数据分布。
[0039]在一些示例性的实施方式中,在确定出各个样本图像对应的目标噪声变量后,可对所有目标噪声变量进行拟合,以得到目标噪声变量所服从的数据分布。
[0040]在确定出数据分布后,可确定出该数据分布所对应的标准差以及期望值等。
[0041]其中,通过对数据分布进行分析,可知,本示例中数据分布为混合高斯分布。
[0042]步骤105,根据数据分布和样本图像集合对目标GAN模型进行训练。
[0043]在一些示例性的实施方式中,在目标任务所对应的样本图像集合中各样本图像的类别是相同的,并且,源GAN模型对一个类别的图像进行处理过程的情况下,可直接基于数据分布和样本图像集合对目标GAN模型进行训练。
[0044]在另一些示例性的实施方式中,在目标任务所对应的样本图像集合中的各样本图像的类别是相同的,而源GAN模型可对多个类别的图像进行处理过程的情况下,在根据数据分布和样本图像集合对目标GAN模型进行训练之前,还可以将目标GAN模型的生成器中的多个类别嵌入层替换为单个类别本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,包括:获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型;获取目标任务对应的样本图像集合;针对所述样本图像集合中的各个样本图像,确定所述源GAN模型的生成器在生成所述样本图像时所使用的目标噪声变量;确定所述目标噪声变量所服从的数据分布;根据所述数据分布和所述样本图像集合对所述目标GAN模型进行训练。2.根据权利要求1所述的方法,其中,所述针对所述样本图像集合中的各个样本图像,确定所述源GAN模型的生成器在生成所述样本图像时所使用的目标噪声变量,包括:针对所述样本图像集合中的各个样本图像,获取所述源GAN模型的生成器初始噪声变量为所述样本图像所生成的第一生成图像,其中,所述初始噪声变量是从服从标准正态分布的噪声数据中随机采样得到的;根据各个样本图像和其对应的第一生成图像之间的像素级差异,确定损失值;根据所述损失值,对所述初始噪声变量进行优化,直至所述损失值满足预设结束条件,其中,在对所述初始噪声变量进行优化的过程中,所述源GAN模型生成器的参数是固定的;将在所述损失值满足所述预设结束条件时,所述源GAN模型的生成器在生成各个样本图像所对应的第一生成图像时所使用的噪声变量,作为所述源GAN模型的生成器在生成所述样本图像时所使用的目标噪声变量。3.根据权利要求1所述的方法,其中,在所述根据所述数据分布和所述样本图像集合对所述目标GAN模型进行训练之前,所述方法还包括:将所述目标GAN模型的生成器中的多个类别嵌入层替换为单个类别嵌入层,并对所述单个类别嵌入层进行随机初始化;将所述目标GAN模型的判别器中的多个全连接层替换为单个全连接层,并对所述单个全连接层进行随机初始化。4.根据权利要求1所述的方法,其中,所述根据所述数据分布和所述样本图像集合对所述目标GAN模型进行训练,包括:对所述数据分布进行多次随机采样,并获取所述目标GAN的生成器基于多次随机采样所得到的噪声变量所生成的生成图像集合;将所述生成图像集合中的各个第二生成图像和所述样本图像集合中各个样本图像输入到所述目标GAN模型的判别器中,以得到各个第二生成图像的第一分类结果和各个样本图像的第二分类结果,其中,所述第一分类结果用于表示所述生成图像来自所述样本图像集合或者所述生成图像集合,所述第二分类结果用于表示所述样本图像来自所述样本图像集合或者所述生成图像集合;从所述判别器中获得各个第二生成图像的特征向量;根据所述第一分类结果、所述特征向量和所述第二分类结果,对所述目标GAN模型的生成器和判别器进行交替训练,直至满足训练结束条件。5.根据权利要求4所述的方法,其中,所述根据所述第一分类结果、所述特征向量和所述第二分类结果,对所述目标GAN模型的生成器和判别器进行交替训练,直至满足训练结束条件,包括:
根据所述第一分类结果、所述特征向量和所述第二分类结果,确定所述目标GAN模型的总损失值;根据所述总损失值,对所述目标GAN模型的判别器进行训练,其中,在对所述目标GAN模型的判别器进行训练的过程中,所述目标GAN模型的生成器的参数不变;根据所述总损失值,对所述目标GAN模型的生成器进行训练,其中,在对所述目标GAN模型的生成器进行训练的过程中,所述目标GAN模型的判别器的参数不变;交替执行对所述目标GAN模型的判别器和生成器进行训练的步骤,直至满足训练结束条件。6.根据权利要求4所述的方法,其中,所述根据所述第一分类结果、所述特征向量和所述第二分类结果,确定所述目标GAN模型的总损失值,包括:根据所述第一分类结果和所述第二分类结果,确定所述目标GAN模型的第一损失值;根据所述特征向量,确定所述生成图像集合的特征矩阵;根据所述特征矩阵,确定所述目标GAN模型的第二损失值;根据所述第一损失值和所述第二损失值,确定所述目标GAN模型的总损失值。7.根据权利要求6所述的方法,其中,所述根据所述特征矩阵,确定所述目标GAN模型的第二损失值,包括:对所述特征矩阵进行奇异值分解,以得到所述特征矩阵所对应的多个奇异值;按照从小到大的顺序对所述多个奇异值进行排序,以得到排序结果;从排序结果中获取排序在前K位的目标奇异值,其中,K为大于或者等于1的整数,并且,K小于或者等于N,N为所述多个奇异值的总数;根据排序在前K位的目标奇异值,确定所述目标GAN模型的第二损失值。8.根据权利要求1

7中任一项所述的方法,其中,在对所述目标GAN进行T轮训练时,针对第t轮训练,在根据所述数据分布和所述样本图像集合对所述目标GAN模型进行训练之前,所述方法还包括:根据t和T,对所述数据分布的均值和标准差进行调整,以对所述数据分布进行更新,其中,t为大于1的整数,并且,t小于或者等于T。9.一种模型训练装置,包括:第一获取模块,用于获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型;第二获取模块,用于获取目标任务对应的样本图像集合;第一确定模块,用于针对所述样本图像集合中的各个样本图像,确定所述源GAN...

【专利技术属性】
技术研发人员:李兴建张泽人窦德景
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1