计算方法、装置、计算机设备和存储介质制造方法及图纸

技术编号:38946713 阅读:13 留言:0更新日期:2023-09-25 09:43
本公开涉及一种计算方法、装置、计算机设备和存储介质。该方法包括:基于生成网络模型对样本图像进行恢复处理,生成第一输出图像;根据第一输出图像和第一损失函数计算出第一梯度;根据判别网络模型对第一输出图像的第一判别结果,计算出第二梯度;根据判别网络模型对与样本图像对应的标签图像的第二判别结果和第二损失函数,计算出第二梯度偏置;将第二梯度与第二梯度偏置的差值作为第三梯度;根据第一梯度和第三梯度,计算出目标梯度;根据目标梯度对生成网络模型的参数进行更新,以完成生成网络模型的模型训练。可以提高模型训练速度,平衡训练过程中两损失函数的优化方向的冲突,且可以提高训练得到的模型的精度和准确度。度。度。

【技术实现步骤摘要】
计算方法、装置、计算机设备和存储介质


[0001]本公开涉及计算机
,特别是涉及一种计算方法、装置、计算机设备和存储介质。

技术介绍

[0002]随着神经网络技术的发展,神经网络技术在图像重建恢复中也得到了广泛的发展和应用。为得到更好的能够对输入图像进行恢复输出更接近于真实图像的生成器(generator),通过同步对生成器和判别器(discriminator,也称鉴别器)进行训练的方式,得到所需的生成器。相关技术中,利用重建损失函数和对抗损失函数(adversarial loss)进行模型训练得到的生成器,但是训练好的生成器的输出图像存在边缘变形、出现错误纹理、偏色等问题。

技术实现思路

[0003]基于此,有必要针对上述技术问题,提供一种计算方法、装置、计算机设备和存储介质。
[0004]根据本公开的一方面,提供了一种计算方法,所述方法包括:
[0005]基于生成网络模型对样本图像进行恢复处理,生成第一输出图像;
[0006]根据所述第一输出图像和第一损失函数计算出第一梯度;
[0007]根据判别网络模型对所述第一输出图像的第一判别结果,计算出第二梯度;
[0008]根据判别网络模型对与所述样本图像对应的标签图像的第二判别结果和第二损失函数,计算出第二梯度偏置;
[0009]将所述第二梯度与所述第二梯度偏置的差值作为第三梯度;
[0010]根据所述第一梯度和所述第三梯度,计算出目标梯度;
[0011]根据所述目标梯度对所述生成网络模型的参数进行更新,以完成所述生成网络模型的模型训练。
[0012]在一种可能的实现方式中,根据所述第一梯度和所述第三梯度,计算出目标梯度,包括:
[0013]根据所述第一梯度的第一p

范数和所述第三梯度的第三p

范数,对所述第一梯度和/或所述第三梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致;
[0014]将p

范数一致的第一梯度和第三梯度的和作为所述目标梯度。
[0015]在一种可能的实现方式中,根据所述第一梯度的第一p

范数和所述第三梯度的第三p

范数,对所述第一梯度和/或所述第三梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致,包括:
[0016]根据所述第一p

范数和所述第三p

范数,确定出目标p

范数,所述目标p

范数大于或等于所述第一p

范数和所述第三p

范数中的最小值、且小于或等于所述第一p

范数和所述第三p

范数中的最大值;
[0017]将所述第三梯度和所述第一梯度中p

范数与所述目标p

范数不一致的梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致。
[0018]在一种可能的实现方式中,根据所述第一梯度和所述第三梯度,计算出目标梯度,包括:
[0019]在根据所述第一梯度和所述第三梯度,计算出目标梯度之前,根据梯度调整模型对所述第一梯度和所述第三梯度进行调整。
[0020]在一种可能的实现方式中,所述方法包括:
[0021]根据所述目标梯度对所述生成网络模型的参数进行更新之后,基于更新后生成网络模型对样本图像进行恢复处理,生成第二输出图像;
[0022]根据所述第二输出图像和第一损失函数,计算出第一损失值;
[0023]根据判别网络模型对所述第二输出图像的判别结果和第二损失函数,计算出更新用第二损失值;
[0024]根据所述更新用第一损失值和所述更新用第二损失值计算出所述梯度值调整模型的第三损失值;
[0025]根据所述第三损失值对所述梯度调整模型进行更新,已完成所述损失值调整模型的模型训练。
[0026]在一种可能的实现方式中,第一损失函数包括重建损失函数,第二损失函数包括对抗损失函数。
[0027]根据本公开的另一方面,提供了一种计算装置,所述装置包括:
[0028]第一图像获取模块,用于基于生成网络模型对样本图像进行恢复处理,生成第一输出图像;
[0029]第一梯度获取模块,用于根据所述第一输出图像和第一损失函数计算出第一梯度;
[0030]第二梯度获取模块,用于根据判别网络模型对所述第一输出图像的第一判别结果,计算出第二梯度;
[0031]偏置获取模块,用于根据判别网络模型对与所述样本图像对应的标签图像的第二判别结果和第二损失函数,计算出第二梯度偏置;
[0032]第三梯度获取模块,用于将所述第二梯度与所述第二梯度偏置的差值作为第三梯度;
[0033]目标梯度获取模块,用于根据所述第一梯度和所述第三梯度,计算出目标梯度;
[0034]第一更新模块,用于根据所述目标梯度对所述生成网络模型的参数进行更新,以完成所述生成网络模型的模型训练。
[0035]在一种可能的实现方式中,所述目标梯度获取模块,包括:
[0036]第一梯度调整子模块,用于根据所述第一梯度的第一p

范数和所述第三梯度的第三p

范数,对所述第一梯度和/或所述第三梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致;
[0037]梯度计算子模块,用于将p

范数一致的第一梯度和第三梯度的和作为所述目标梯度。
[0038]在一种可能的实现方式中,所述第一梯度调整子模块,包括:
[0039]目标范数确定子模块,用于根据所述第一p

范数和所述第三p

范数,确定出目标p

范数,所述目标p

范数大于或等于所述第一p

范数和所述第三p

范数中的最小值、且小于或等于所述第一p

范数和所述第三p

范数中的最大值;
[0040]调整子模块,用于将所述第三梯度和所述第一梯度中p

范数与所述目标p

范数不一致的梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致。
[0041]在一种可能的实现方式中,所述目标梯度获取模块,包括:
[0042]第二梯度调整子模块,用于在根据所述第一梯度和所述第三梯度,计算出目标梯度之前,根据梯度调整模型对所述第一梯度和所述第三梯度进行调整。
[0043]在一种可能的实现方式中,所述装置还包括:
[0044]第二图像获取模块,用于根据所述目标梯度对所述生成网络模型的参数进行更新之后,基于更新后生成网络模型本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种计算方法,其特征在于,所述方法包括:基于生成网络模型对样本图像进行恢复处理,生成第一输出图像;根据所述第一输出图像和第一损失函数计算出第一梯度;根据判别网络模型对所述第一输出图像的第一判别结果,计算出第二梯度;根据判别网络模型对与所述样本图像对应的标签图像的第二判别结果和第二损失函数,计算出第二梯度偏置;将所述第二梯度与所述第二梯度偏置的差值作为第三梯度;根据所述第一梯度和所述第三梯度,计算出目标梯度;根据所述目标梯度对所述生成网络模型的参数进行更新,以完成所述生成网络模型的模型训练。2.根据权利要求1所述的方法,其特征在于,根据所述第一梯度和所述第三梯度,计算出目标梯度,包括:根据所述第一梯度的第一p

范数和所述第三梯度的第三p

范数,对所述第一梯度和/或所述第三梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致;将p

范数一致的第一梯度和第三梯度的和作为所述目标梯度。3.根据权利要求2所述的方法,其特征在于,根据所述第一梯度的第一p

范数和所述第三梯度的第三p

范数,对所述第一梯度和/或所述第三梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致,包括:根据所述第一p

范数和所述第三p

范数,确定出目标p

范数,所述目标p

范数大于或等于所述第一p

范数和所述第三p

范数中的最小值、且小于或等于所述第一p

范数和所述第三p

范数中的最大值;将所述第三梯度和所述第一梯度中p

范数与所述目标p

范数不一致的梯度进行调整,以使调整后第一梯度与第三梯度的p

范数一致。4.根据权利要求1

3任意一项所述的方法,其特征在于,根据所述第一梯度和所述第三梯度,计算出目标梯度,包括:在根据所述第一梯度和所述第三梯度,计算出目标梯度之前,根据梯度调整模型对所述第一梯度和所述第三梯度进行调整。5.根据权利要求4所述的方法,其特征在于,所述方法包括:根据所述目标梯度对所述生成网络模型的参数进行更新之后,基于更新后生成网络模型对样本图像进行恢复处理,生成第二输出图像;根据所述第二输出图像和第一损失函数,计算出第一损失值;根据判别网络模型对所述第二输出图像的判别结果和第二损失函数,计算出更新用第二损失值;根据所述更新用第一损失值和所述更新用第二损失值计算出所述梯度值调整模型的第三损失值;根据所述第三损失值对所述梯度调整模型进行更新,已完成所述损失值调整模型的模型训练。6.根据权利要求1

5任意一项所述的方法,其特征在于,第一损失函数包括重建损失函数,第二损失函数包括对抗损失函数。
7.一种计算装置,其特征在于,所述装置包括:第一图像获取模块,用于基于生成网络模型对样本图像进行恢复处理,生成第一输出图像;第一梯度获取模块,用于根据所述第一输出图像和第一损失函数计算出第一梯度;第二梯度获取模块,用于根据判别网络模型对所述第一输出图像的第一判别结果,计算出第二梯度;偏置获取模块,用于根据判别网络模型对与所述样本图像对应的标签图像的第二判别结果和第二损失函数,计算出第二梯度偏置;第三梯度获取模块,用于将所述第二梯度与所述第二梯度偏置的差值作为第三梯度;目标梯度获取模块,用于根据所述第一梯度...

【专利技术属性】
技术研发人员:请求不公布姓名
申请(专利权)人:寒武纪昆山信息科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1