模型训练方法及相关装置、可读存储介质制造方法及图纸

技术编号:28676827 阅读:19 留言:0更新日期:2021-06-02 02:53
本申请实施例公开了一种模型训练方法及模型训练装置,用于提升图像翻译模型的推理速度。本发明专利技术实施例方法包括:采用第一数据和第二数据对原始图像翻译模型中的原始生成模型进行训练,以得到当前帧的第一生成图片,第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,第二数据包括当前帧的原始图片和前两帧的原始图片;对原始生成模型执行fine‑tune微调操作,得到一代生成模型,微调操作包括根据预设的损失函数,计算当前帧的原始图片与当前帧的第一生成图片的第一损失,根据第一损失及反向传播算法,对原始生成模型中卷积层的权重进行梯度更新,一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,一代生成模型的图像生成质量不大于预设的FID值。

【技术实现步骤摘要】
模型训练方法及相关装置、可读存储介质
本专利技术涉及图像翻译
,尤其涉及模型训练方法及相关装置、可读存储介质。
技术介绍
所谓图像翻译,指从一副图像到另一副图像的转换。可以类比机器翻译,将一种语言转换为另一种语言。现有技术中较为经典的图像翻译模型有pix2pix,pix2pixHD,vid2vid。pix2pix提出了一个统一的框架解决了各类图像翻译问题,pix2pixHD则在pix2pix的基础上,较好的解决了高分辨率图像转换(翻译)的问题,vid2vid则在pix2pixHD的基础上,较好的解决了高分辨率的视频转换问题。但目前的vid2vid模型,如Nvidia的vid2vid中的头部姿态翻译模型,在实际训练过程中,因为其采用的GAN模型数据计算量大,如目前的头部姿态翻译模型需要输入第一部分数据和第二部分数据,其中,第一部分数据包括当前帧和前两帧的轮廓线,以及当前帧和前两帧的distanceMap数据,进一步,每一帧的轮廓线为1维数据,则当前帧和前两帧的轮廓线共3维数据,而每一帧的distanceMap包括4维数据,则本文档来自技高网...

【技术保护点】
1.一种模型训练方法,其特征在于,所述方法包括:/n采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;/n对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一...

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:
采用第一数据和第二数据对原始图像翻译模型中的GAN网络中的原始生成模型进行训练,以得到当前帧的第一生成图片,所述第一数据包括当前帧的轮廓线数据和前两帧的轮廓线数据,所述第二数据包括当前帧的原始图片和前两帧的原始图片;
对所述原始生成模型执行fine-tune微调操作,直至得到一代生成模型,所述微调操作包括根据预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第一生成图片之间的第一损失,根据所述第一损失及反向传播算法,对所述原始生成模型中卷积层的权重进行梯度更新,其中,所述一代生成模型为一代图像翻译模型中的GAN网络中的生成模型,且所述一代生成模型的图像生成质量不大于预设的FID值。


2.根据权利要求1所述的模型训练方法,其特征在于,所述方法还包括:
在对所述原始生成模型中卷积层的权重进行梯度更新的过程中,减少对所述原始生成模型中卷积层的学习率。


3.根据权利要求2所述的方法,其特征在于,所述方法还包括:
采用第三数据和第四数据对所述一代生成模型进行训练,以得到当前帧的第二生成图片,所述第三数据包括当前帧的轮廓线数据和前一帧的轮廓线数据,所述第四数据包括当前帧的原始图片和前一帧的原始图片;
对所述一代生成模型执行fine-tune微调操作,直至得到二代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第二生成图片之间的第二损失,根据所述第二损失及所述反向传播算法,对所述一代生成模型中卷积层的权重进行梯度更新,其中,所述二代生成模型为二代图像翻译模型中的GAN网络中的生成模型,且所述二代生成模型的图像生成质量不大于所述预设的FID值。


4.根据权利要求3所述的方法,其特征在于,所述方法还包括:
在对所述一代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述一代生成模型中卷积层的学习率。


5.根据权利要求4所述的方法,其特征在于,所述方法还包括:
采用第五数据和第六数据对所述二代生成模型进行训练,以得到当前帧的第三生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第六数据包括当前帧的原始图片和第一帧的原始图片;
对所述二代生成模型执行fine-tune微调操作,直至得到三代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第三生成图片之间的第三损失,根据所述第三损失及所述反向传播算法,对所述二代生成模型中卷积层的权重进行梯度更新,其中,所述三代生成模型为三代图像翻译模型中的GAN网络中的生成模型,且所述三代生成模型的图像生成质量不大于所述预设的FID值。


6.根据权利要求5所述的方法,其特征在于,所述方法还包括:
在对所述二代生成模型中卷积层的权重进行梯度更新的过程中,减少对所述二代生成模型中卷积层的学习率。


7.根据权利要求6所述的方法,其特征在于,所述方法还包括:
采用所述第五数据和第七数据对所述三代生成模型进行训练,以得到当前帧的第四生成图片,所述第五数据包括当前帧的轮廓线数据和第一帧的轮廓线数据,所述第七数据包括当前帧的原始图片,及降低像素后的第一帧的图片;
对所述三代生成模型执行fine-tune微调操作,直至得到四代生成模型,所述微调操作包括根据所述预设的损失函数,计算所述当前帧的原始图片与所述当前帧的第四生成图片之间的第四损...

【专利技术属性】
技术研发人员:王鑫宇刘志远杨国基刘炫鹏陈泷翔
申请(专利权)人:深圳追一科技有限公司
类型:发明
国别省市:广东;44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1