【技术实现步骤摘要】
模型训练方法、装置以及设备
[0001]本公开涉及人工智能
,具体涉及深度学习等
技术介绍
[0002]随着深度学习技术的发展,训练更大的模型成为一种提高模型精度性能的主流趋势。数据并行训练是一种分布式训练技术。每个设备维护相同的模型状态,但输入不同的数据进行计算,通过增加设备数量提升训练整体的计算吞吐,进而提升模型训练速度。
[0003]然而,更大的模型意味着更多的模型状态,更多的模型状态意味着更大的显存需求。但当前主流训练设备显存容量的增长远远跟不上模型大小增长对显存的需求。大模型训练过程中的显存需求远大于单个训练设备的实际物理显存大小,传统的模型训练技术在大模型场景下会出现过显存溢出问题。
技术实现思路
[0004]本公开实施例提出了一种模型训练方法、装置、设备、存储介质以及程序产品。
[0005]第一方面,本公开实施例提出了一种模型训练方法,包括:将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,模型包含1个未切分模型状态集合和n个已切分模型状态集合,i和n为正整数,n大于1,i不小于1且不大于n;获取第i个训练样本集,以及通过设备间通信从其他n
‑
1个设备获取其他n
‑
1个已切分模型状态集合,对模型进行训练,得到模型的各个模型状态的第一梯度;基于模型的各个模型状态的第一梯度更新未切分模型状态集合和第i个已切分模型状态集合。
[0006]第二方面,本公开实施例提出了一种语音降噪方法,包 ...
【技术保护点】
【技术特征摘要】
1.一种模型训练方法,包括:将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,所述模型包含1个未切分模型状态集合和n个已切分模型状态集合,i和n为正整数,n大于1,i不小于1且不大于n;获取第i个训练样本集,以及通过设备间通信从其他n
‑
1个设备获取其他n
‑
1个已切分模型状态集合,对所述模型进行训练,得到所述模型的各个模型状态的第一梯度;基于所述模型的各个模型状态的第一梯度更新所述未切分模型状态集合和所述第i个已切分模型状态集合。2.根据权利要求1所述的方法,其中,所述方法还包括:通过设备间通信将所述其他n
‑
1个已切分模型状态集合中的模型状态的第一梯度对应传递给所述其他n
‑
1个设备。3.根据权利要求1或2所述的方法,其中,所述方法还包括:利用设备间通信获取所述其他n
‑
1个设备传递的所述未切分模型状态集合和所述其他n
‑
1个已切分模型状态集合中的模型状态的第二梯度;以及所述基于所述模型的各个模型状态的第一梯度更新所述未切分模型状态集合和所述第i个已切分模型状态集合,包括:对于所述未切分模型状态集合和所述第i个已切分模型状态集合中的模型状态,基于该模型状态的第一梯度和第二梯度,计算该模型状态的目标梯度,以及基于该模型状态的目标梯度更新该模型状态。4.根据权利要求1
‑
3中任一项所述的方法,其中,所述对所述模型进行训练,得到所述模型的各个模型状态的第一梯度,包括:基于所述第i个训练样本集,对所述模型进行前向传播,计算得到中间变量和损失;基于所述中间变量和所述损失,对所述模型进行反向传播,计算得到所述模型的各个模型状态的第一梯度;以及所述基于所述模型的各个模型状态的第一梯度更新所述未切分模型状态集合和所述第i个已切分模型状态集合,包括:对于所述未切分模型状态集合和所述第i个已切分模型状态集合中的模型状态,基于该模型状态的第一梯度和所述模型的优化器状态更新该模型状态。5.根据权利要求1
‑
4中任一项所述的方法,其中,在所述将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中之前,还包括:基于所述模型的模型状态总量和所述n个设备的可用显存量,对所述模型的模型状态进行切分,得到所述已切分模型状态集合和所述n个已切分模型状态集合。6.根据权利要求5所述的方法,其中,所述基于所述模型的模型状态总量和所述n个设备的可用显存量,对所述模型的模型状态进行切分,得到所述已切分模型状态集合和所述n个已切分模型状态集合,包括:获取所述n个设备的剩余显存量;基于所述模型的模型状态总量和所述n个设备的剩余显存量,确定模型状态切分数量;基于所述模型状态切分数量选取出需要进行切分的模型状态,其中,剩余的模型状态组成所述未切分模型状态集合;
基于所述n个设备的剩余显存量对所选取的模型状态依次切分,得到所述n个已切分模型状态集合。7.根据权利要求6所述的方法,其中,所述获取所述n个设备的剩余显存量,包括:通过所述第i个设备的应用程序编程接口获取所述第i个设备的可用显存量;基于所述第i个设备的可用显存量,确定所述第i个设备的剩余显存量;通过设备间同步获取所述其他n
‑
1个设备的剩余显存量。8.根据权利要求7所述的方法,其中,所述基于所述第i个设备的可用显存量,确定所述第i个设备的剩余显存量,包括:获取所述第i个设备的深度学习计算图,其中,所述深度学习计算图可以表达训练过程中的计算和数据;基于所述深度学习计算图计算中间变量的占用显存量;计算所述第i个设备的可用显存量与所述中间变量的占用显存量之差,得到所述第i个设备的剩余显存量;以及在所述基于所述n个设备的剩余显存量对所选取的模型状态依次切分,得到所述n个已切分模型状态集合之后,还包括:将所述其他n
‑
1个模型状态集合中的模型状态从所述第i个设备的深度学习计算图上删除。9.一种模型训练装置,包括:存储模块,被配置成将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,所述模型包含1个未切分模型状态集合和n...
【专利技术属性】
技术研发人员:梁建中,敖玉龙,于佃海,
申请(专利权)人:北京百度网讯科技有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。