一种模型训练方法及模型训练装置制造方法及图纸

技术编号:29405908 阅读:12 留言:0更新日期:2021-07-23 22:44
本发明专利技术实施例公开了一种模型训练方法及训练装置,用于在图像翻译模型的训练数据较少时,提升图像翻译模型的图像翻译质量。本发明专利技术实施例方法包括:利用训练数集对图像翻译模型的生成器和判别器做训练,并将训练后的图像翻译模型视为老师模型,训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和目标帧图像的前N帧图像数据;利用训练数集中的第一数据对图像翻译模型的生成器和判别器进行训练,并将训练后的图像翻译模型视为学生模型,第一数据包括目标帧图像、目标帧图像的轮廓线数据和目标帧图像的前M帧图像数据,M为大于等于1且小于等于N的整数;利用老师模型对学生模型进行知识蒸馏,得到知识蒸馏后的学生模型。

【技术实现步骤摘要】
一种模型训练方法及模型训练装置
本专利技术涉及图像翻译
,尤其涉及一种模型训练方法及模型训练装置。
技术介绍
所谓图像翻译,指从一副图像到另一副图像的转换。可以类比机器翻译,将一种语言转换为另一种语言。现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid。pix2pix提出了一个统一的框架解决了各类图像翻译问题,pix2pixHD则在pix2pix的基础上,较好地解决了高分辨率图像转换(翻译)的问题,vid2vid则在pix2pixHD的基础上,较好地解决了高分辨率的视频转换问题。数字人,是一种利用信息科学的方法对真实人体的形态和功能进行姿态仿真的虚拟人。目前的图像翻译模型,可以对图像中的数字人进行虚拟仿真,但现有技术中的图像翻译模型,如果训练数据的类型较少,会导致训练得到的图像翻译模型在数字人姿态仿真(或数字人姿态生成)时准确率较低。
技术实现思路
本专利技术实施例提供了一种模型训练方法及模型训练装置,用于在图像翻译模型的训练数据较少时,也可以提升图像翻译模型的图像翻译质量,从而使得图像翻译模型在实现数字人姿态仿真时,提升数字人姿态仿真的准确率。本申请实施例第一方面提供了一种模型训练方法,包括:利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。优选的,所述目标帧图像的前M帧图像数据,包括:所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。优选的,所述目标帧图像的前M帧图像数据,包括:所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。优选的,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述利用所述老师模型对所述学生模型进行知识蒸馏,包括:将所述老师模型中的判别器作为所述学生模型中的判别器;根据所述学生模型的损失函数,计算所述学生模型的第一损失;计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。优选的,所述根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新,包括:获取所述第一损失、所述第二损失和所述第三损失对应的权重;根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。本申请实施例第二方面提供了一种模型训练装置,包括:第一训练单元,用于利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;第二训练单元,用于利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;知识蒸馏单元,用于利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。优选的,所述目标帧图像的前M帧图像数据,包括:所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。优选的,所述目标帧图像的前M帧图像数据,包括:所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。优选的,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述知识蒸馏单元,包括:设置模块,用于将所述老师模型中的判别器作为所述学生模型中的判别器;第一计算模块,用于根据所述学生模型的损失函数,计算所述学生模型的第一损失;第二计算模块,用于计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;第三计算模块,用于计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;更新模块,用于根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。具体的,所述更新模块具体包括:获取子模块,用于获取所述第一损失、所述第二损失和所述第三损失对应的权重;计算子模块,用于根据所述第一损失、所述第二损失和所述第三损失以及对应的权重,计算目标损失;更新子模块,用于根据所述目标损失和反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。优选的,所述图像翻译模型包括pix2pix模型、pix2pixHD模型和vid2vid模型中的至少一种。本申请实施例第三方面提供了一种计算机装置,包括处理器,所述处理器在执行存储于存储器上的计算机程序时,用于实现本申请实施例第一方面所述的模型训练方法。本申请实施例第四方面提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执本文档来自技高网...

【技术保护点】
1.一种模型训练方法,其特征在于,所述方法包括:/n利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;/n利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;/n利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。/n

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:
利用训练数集对图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为老师模型,其中,所述图像翻译模型为生成对抗网络模型,所述训练数集包括目标帧图像、目标帧图像的轮廓线数据、目标帧图像的距离图像数据和所述目标帧图像的前N帧图像数据,其中,所述前N帧图像数据包括前N帧图像、所述前N帧图像的轮廓线数据和所述前N帧图像的距离图像数据,所述N为大于等于2的整数,所述目标帧图像为所述训练数集中的除前两帧以外的任意一帧或任意多帧图像;
利用所述训练数集中的第一数据对所述图像翻译模型的生成器和判别器分别进行训练,并将训练后的图像翻译模型确定为学生模型,其中,所述第一数据包括所述目标帧图像、所述目标帧图像的轮廓线数据和所述目标帧图像的前M帧图像数据,其中,所述M为大于等于1且小于等于N的整数;
利用所述老师模型对所述学生模型进行知识蒸馏,以得到知识蒸馏后的学生模型。


2.根据权利要求1所述的方法,其特征在于,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧图像和所述前M帧图像的轮廓线数据。


3.根据权利要求1所述的方法,其特征在于,所述目标帧图像的前M帧图像数据,包括:
所述目标帧图像的前M帧降低像素后的图像和所述前M帧图像的轮廓线数据。


4.根据权利要求1-3中任一项所述的方法,其特征在于,所述图像翻译模型中的生成器为编码模型-解码模型结构,所述利用所述老师模型对所述学生模型进行知识蒸馏,包括:
将所述老师模型中的判别器作为所述学生模型中的判别器;
根据所述学生模型的损失函数,计算所述学生模型的第一损失;
计算所述老师模型中第一隐藏变量与所述学生模型中第二隐藏变量之间的第二损失,其中,所述第一隐藏变量为所述老师模型的编码模型与解码模型之间的隐藏变量,所述第二隐藏变量为所述学生模型的编码模型与解码模型之间的隐藏变量;
计算所述老师模型的生成器和所述学生模型的生成器在输入相同的目标帧图像时,得到两个目标帧生成图像之间的第三损失;
根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新。


5.根据权利要求4所述的方法,其特征在于,所述根据所述第一损失、所述第二损失和所述第三损失中的至少一项,及反向传播算法,对所述学生模型中生成器的卷积层的权重进行梯度更新,包括:
获取所述第一损失、所述第二损失和所述第三损失对应的权重;
根据所述第一损失、所述第二损失和所述第三损失...

【专利技术属性】
技术研发人员:王鑫宇刘炫鹏杨国基刘致远刘云峰
申请(专利权)人:深圳追一科技有限公司
类型:发明
国别省市:广东;44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1