模型训练方法、装置、设备及存储介质制造方法及图纸

技术编号:34851067 阅读:30 留言:0更新日期:2022-09-08 07:51
本公开实施例公开了一种模型训练方法、装置、设备及存储介质,其中,所述方法包括:损失函数缩放器创建接口基于缩放器构造参数,创建损失函数缩放器;在对深度学习模型进行的每一轮次混合精度训练的前向传播过程中,损失函数缩放器缩放接口基于所述损失函数缩放器,对所述深度学习模型在当前轮次混合精度训练中输出的损失值进行放大处理,得到放大后的所述损失值;在每一所述轮次混合精度训练的反向传播过程中,损失函数优化器迭代接口利用设定的优化器和所述损失函数缩放器,基于放大后的所述损失值,对所述深度学习模型中的网络参数进行更新,得到训练后的所述深度学习模型。得到训练后的所述深度学习模型。得到训练后的所述深度学习模型。

【技术实现步骤摘要】
模型训练方法、装置、设备及存储介质


[0001]本公开涉及但不限于人工智能
,尤其涉及一种模型训练方法、装置、设备及存储介质。

技术介绍

[0002]通常在对深度学习模型进行训练的过程中,深度学习模型中的网络参数的数据类型是统一的,一般为32位浮点(fp32)类型,即全精度浮点类型。出于减少显存消耗或者提高训练速度的考虑,可以将深度学习模型中的一些网络参数采用16位浮点(fp16)类型,即半精度浮点类型。但是如果模型中所有的参数都是fp16,则容易出现精度不足、数据溢出等问题,影响模型训练的正常进行或者影响模型训练的效果。相关技术中,通过采用混合精度训练技术对深度学习模型进行训练,把深度学习模型中的部分网络参数转换为fp16类型,在训练时一部分计算使用fp32精度进行,一部分计算使用fp16精度进行,以提高模型精度的同时降低模型训练过程中的存储空间占用以及执行时间。
[0003]但是,相关技术中对深度学习模型进行混合精度训练通常需要用户手动在算法中开发代码实现,或由第三方工具实现,或由训练框架实现,从而导致实现方式多样,难以统一和本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:损失函数缩放器创建接口基于缩放器构造参数,创建损失函数缩放器;在对深度学习模型进行的每一轮次混合精度训练的前向传播过程中,损失函数缩放器缩放接口基于所述损失函数缩放器,对所述深度学习模型在当前轮次混合精度训练中输出的损失值进行放大处理,得到放大后的所述损失值;在每一所述轮次混合精度训练的反向传播过程中,损失函数优化器迭代接口利用设定的优化器和所述损失函数缩放器,基于放大后的所述损失值,对所述深度学习模型中的网络参数进行更新,得到训练后的所述深度学习模型。2.根据权利要求1所述的方法,其特征在于,所述损失函数缩放器中包括目标缩放值,所述缩放器构造参数包括所述目标缩放值的初始值和所述目标缩放值的缩放更新参数;所述损失函数缩放器创建接口基于缩放器构造参数,创建损失函数缩放器,包括:损失函数缩放器创建接口基于所述初始值和所述缩放更新参数,创建所述损失函数缩放器;其中,所述缩放更新参数包括以下至少之一:成长周期、成长系数和回退系数。3.根据权利要求2所述的方法,其特征在于,所述方法还包括以下至少之一:所述损失函数缩放器缩放接口基于所述缩放更新参数,对当前的所述目标缩放值进行更新,得到更新后的所述目标缩放值;所述损失函数优化器迭代接口基于所述缩放更新参数,对当前的所述目标缩放值进行更新,得到更新后的所述目标缩放值。4.根据权利要求3所述的方法,其特征在于,在所述缩放更新参数包括成长周期和成长系数的情况下,所述损失函数缩放器缩放接口基于所述缩放更新参数,对当前的所述目标缩放值进行更新,得到更新后的所述目标缩放值,包括:所述损失函数缩放器缩放接口确定当前连续得到的放大后的第二损失值均未发生溢出的混合精度训练的第一轮数;所述损失函数缩放器缩放接口在所述第一轮数达到所述成长周期的情况下,基于所述成长系数对当前的所述目标缩放值进行放大处理,得到放大后的所述目标缩放量。5.根据权利要求3或4所述的方法,其特征在于,在所述缩放更新参数包括回退系数的情况下,所述损失函数缩放器缩放接口基于所述缩放更新参数,对当前的所述目标缩放值进行更新,得到更新后的所述目标缩放值,包括:所述损失函数缩放器缩放接口在当前轮次混合精度训练中放大后的第二损失值发生溢出的情况下,基于所述回退系数对当前的所述目标缩放值进行缩小处理,得到缩小后的所述目标缩放量,并在当前轮次混合精度训练中停止对所述深度学习模型中的网络参数进行更新。6.根据权利要求3至5中任一项所述的方法,其特征在于,在所述缩放更新参数包括成长周期和成长系数的情况下,所述损失函数优化器迭代接口基于所述缩放更新参数,对当前的所述目标缩放值进行更新,得到更新后的所述目标缩放值,包括:所述损失函数优化器迭代接口确定当前连续对所述深度学习模型中的网络参数进行更新的过程中所述网络参数均未发生溢出的混合精度训练的第二轮数;所述损失函数优化器迭代接口在所述第二轮数达到所述成长周期的情况下,基于所述
成长系数对当前的所述目标缩放值进行放大处理,得到放大后的所述目标缩放量。7.根据权利要求3至6中任一项所述的方法,其特征在于,在所述缩放更新参数包括回退系数的情况下,所述损失函数优化器迭代接口基于所述缩放更新参数,对当前的所述目标缩放值进行更新,得到更新后的所述目标缩放值,包括:所述损失函数优化器迭代接口在当前轮次混合精度训练中对所述深度学习模型中的网络参数进行更新的过程中所述网络参数发生溢出的情况下,基于所述回退系数对当前的所述目标缩放值进行缩小处理,得到缩小后的所述目标缩放量,并在当前轮次混合精度训练中停止对所述深度学习模型中的网络参数进行更新。8.根据权利要求1至7中任一项所述的方法,其特征在于,所述方法还包括:精度转换接口基于设定的混合精度开关参数,将所述深度学习模型中设定的网络层的全精度网络参数转换为半精度网络参数,得到精度转换后的所述深度学习模型,并对所述全精度网络参数进行备份;所述损失函数缩放器缩放接口基于所述损失函数缩放器,对所述深度学习模型在当前轮次混合精度训练中输出的损失值进行放大处理,得到放大...

【专利技术属性】
技术研发人员:罗培超张行程
申请(专利权)人:上海商汤智能科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1