模型训练方法、装置以及设备制造方法及图纸

技术编号:36227011 阅读:9 留言:0更新日期:2023-01-04 12:26
本公开提供了一种模型训练方法、装置以及设备,涉及人工智能技术领域,具体涉及深度学习等技术领域。该方法的一具体实施方式包括:将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中;获取第i个训练样本集,以及通过设备间通信从其他n

【技术实现步骤摘要】
模型训练方法、装置以及设备


[0001]本公开涉及人工智能
,具体涉及深度学习等


技术介绍

[0002]随着深度学习技术的发展,训练更大的模型成为一种提高模型精度性能的主流趋势。数据并行训练是一种分布式训练技术。每个设备维护相同的模型状态,但输入不同的数据进行计算,通过增加设备数量提升训练整体的计算吞吐,进而提升模型训练速度。
[0003]然而,更大的模型意味着更多的模型状态,更多的模型状态意味着更大的显存需求。但当前主流训练设备显存容量的增长远远跟不上模型大小增长对显存的需求。大模型训练过程中的显存需求远大于单个训练设备的实际物理显存大小,传统的模型训练技术在大模型场景下会出现过显存溢出问题。

技术实现思路

[0004]本公开实施例提出了一种模型训练方法、装置、设备、存储介质以及程序产品。
[0005]第一方面,本公开实施例提出了一种模型训练方法,包括:将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,模型包含1个未切分模型状态集合和n个已切分模型状态集合,i和n为正整数,n大于1,i不小于1且不大于n;获取第i个训练样本集,以及通过设备间通信从其他n

1个设备获取其他n

1个已切分模型状态集合,对模型进行训练,得到模型的各个模型状态的第一梯度;基于模型的各个模型状态的第一梯度更新未切分模型状态集合和第i个已切分模型状态集合。
[0006]第二方面,本公开实施例提出了一种语音降噪方法,包括:存储模块,被配置成将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,模型包含1个未切分模型状态集合和n个已切分模型状态集合,i和n为正整数,n大于1,i不小于1且不大于n;训练模块,被配置成获取第i个训练样本集,以及通过设备间通信从其他n

1个设备获取其他n

1个已切分模型状态集合,对模型进行训练,得到模型的各个模型状态的第一梯度;更新模块,被配置成基于模型的各个模型状态的第一梯度更新未切分模型状态集合和第i个已切分模型状态集合。
[0007]第三方面,本公开实施例提出了一种电子设备,包括:至少两个处理器;以及与至少两个处理器通信连接的存储器;其中,存储器存储有可被至少两个处理器执行的指令,指令被至少两个处理器执行,以使至少两个处理器能够执行如第一方面中任一实现方式描述的方法。
[0008]第四方面,本公开实施例提出了一种存储有计算机指令的非瞬时计算机可读存储介质,计算机指令用于使计算机执行如第一方面中任一实现方式描述的方法。
[0009]第五方面,本公开实施例提出了一种计算机程序产品,包括计算机程序,计算机程序在被处理器执行时实现如第一方面中任一实现方式描述的方法。
[0010]本公开实施例提供的模型训练方法,通过模型状态切分对数据并行训练技术进行
优化。通过在集群硬件设备上切分模型状态,让不同设备维护和存储不同的模型状态,在大模型训练过程中每个硬件上的模型状态的显存需求都不会超出设备实际显存的大小,让训练能够正常进行。并且,仅对部分模型状态进行切分,减少了模型状态切分数量,从而减少了训练过程中需要进行通信的模型状态的数量,进而减少了模型状态通信时间,优化训练速度,满足更好更快训练深度学习大模型的需求。
[0011]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0012]通过阅读参照以下附图所作的对非限制性实施例所作的详细描述,本公开的其它特征、目的和优点将会变得更明显。附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0013]图1是根据本公开的模型训练方法的一个实施例的流程图;
[0014]图2是根据本公开的模型训练方法的又一个实施例的流程图;
[0015]图3是根据本公开的模型状态切分方法的一个实施例的流程图;
[0016]图4是根据本公开的模型状态切分方法的又一个实施例的流程图;
[0017]图5是前向计算的深度学习计算图;
[0018]图6是反向计算的深度学习计算图;
[0019]图7是优化器更新的深度学习计算图;
[0020]图8是模型训练过程中的参数通信图;
[0021]图9是模型状态切分图;
[0022]图10是根据本公开的模型训练装置的一个实施例的结构示意图;
[0023]图11是用来实现本公开实施例的模型训练方法的电子设备的框图。
具体实施方式
[0024]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0025]需要说明的是,在不冲突的情况下,本公开中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本公开。
[0026]图1示出了根据本公开的模型训练方法的一个实施例的流程100。该模型训练方法包括以下步骤:
[0027]步骤101,将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中。
[0028]在本实施例中,第i个设备可以将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中。
[0029]通常,为了提高模型训练速度,采用数据并行训练方法对模型进行训练。即,通过n个设备共同完成训练任务。这里,模型的模型状态可以被划分成1个未切分模型状态集合和
n个已切分模型状态集合。i和n为正整数,n>1,1≤i≤n。其中,每个设备的显存中可以存储未切分模型状态集合和1个已切分模型状态集合,使得这n个设备存储模型的所有模型状态。
[0030]为了避免显存较小的设备的显存溢出,可以考虑设备间显存大小差异。在对模型状态切分时,可以根据每个设备不同显存大小分配模型状态。例如,基于模型的模型状态总量和n个设备的可用显存量,对模型的模型状态进行切分,得到已切分模型状态集合和n个已切分模型状态集合。
[0031]步骤102,获取第i个训练样本集,以及通过设备间通信从其他n

1个设备获取其他n

1个已切分模型状态集合,对模型进行训练,得到模型的各个模型状态的第一梯度。
[0032]在本实施例中,第i个设备可以获取第i个训练样本集,以及通过设备间通信从其他n

1个设备获取其他n

1个已切分模型状态集合,对模型进行训练,得到模型的各个模型状态的第一梯度。
[0033]在采用数据并行训练方法对模型进行本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,包括:将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,所述模型包含1个未切分模型状态集合和n个已切分模型状态集合,i和n为正整数,n大于1,i不小于1且不大于n;获取第i个训练样本集,以及通过设备间通信从其他n

1个设备获取其他n

1个已切分模型状态集合,对所述模型进行训练,得到所述模型的各个模型状态的第一梯度;基于所述模型的各个模型状态的第一梯度更新所述未切分模型状态集合和所述第i个已切分模型状态集合。2.根据权利要求1所述的方法,其中,所述方法还包括:通过设备间通信将所述其他n

1个已切分模型状态集合中的模型状态的第一梯度对应传递给所述其他n

1个设备。3.根据权利要求1或2所述的方法,其中,所述方法还包括:利用设备间通信获取所述其他n

1个设备传递的所述未切分模型状态集合和所述其他n

1个已切分模型状态集合中的模型状态的第二梯度;以及所述基于所述模型的各个模型状态的第一梯度更新所述未切分模型状态集合和所述第i个已切分模型状态集合,包括:对于所述未切分模型状态集合和所述第i个已切分模型状态集合中的模型状态,基于该模型状态的第一梯度和第二梯度,计算该模型状态的目标梯度,以及基于该模型状态的目标梯度更新该模型状态。4.根据权利要求1

3中任一项所述的方法,其中,所述对所述模型进行训练,得到所述模型的各个模型状态的第一梯度,包括:基于所述第i个训练样本集,对所述模型进行前向传播,计算得到中间变量和损失;基于所述中间变量和所述损失,对所述模型进行反向传播,计算得到所述模型的各个模型状态的第一梯度;以及所述基于所述模型的各个模型状态的第一梯度更新所述未切分模型状态集合和所述第i个已切分模型状态集合,包括:对于所述未切分模型状态集合和所述第i个已切分模型状态集合中的模型状态,基于该模型状态的第一梯度和所述模型的优化器状态更新该模型状态。5.根据权利要求1

4中任一项所述的方法,其中,在所述将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中之前,还包括:基于所述模型的模型状态总量和所述n个设备的可用显存量,对所述模型的模型状态进行切分,得到所述已切分模型状态集合和所述n个已切分模型状态集合。6.根据权利要求5所述的方法,其中,所述基于所述模型的模型状态总量和所述n个设备的可用显存量,对所述模型的模型状态进行切分,得到所述已切分模型状态集合和所述n个已切分模型状态集合,包括:获取所述n个设备的剩余显存量;基于所述模型的模型状态总量和所述n个设备的剩余显存量,确定模型状态切分数量;基于所述模型状态切分数量选取出需要进行切分的模型状态,其中,剩余的模型状态组成所述未切分模型状态集合;
基于所述n个设备的剩余显存量对所选取的模型状态依次切分,得到所述n个已切分模型状态集合。7.根据权利要求6所述的方法,其中,所述获取所述n个设备的剩余显存量,包括:通过所述第i个设备的应用程序编程接口获取所述第i个设备的可用显存量;基于所述第i个设备的可用显存量,确定所述第i个设备的剩余显存量;通过设备间同步获取所述其他n

1个设备的剩余显存量。8.根据权利要求7所述的方法,其中,所述基于所述第i个设备的可用显存量,确定所述第i个设备的剩余显存量,包括:获取所述第i个设备的深度学习计算图,其中,所述深度学习计算图可以表达训练过程中的计算和数据;基于所述深度学习计算图计算中间变量的占用显存量;计算所述第i个设备的可用显存量与所述中间变量的占用显存量之差,得到所述第i个设备的剩余显存量;以及在所述基于所述n个设备的剩余显存量对所选取的模型状态依次切分,得到所述n个已切分模型状态集合之后,还包括:将所述其他n

1个模型状态集合中的模型状态从所述第i个设备的深度学习计算图上删除。9.一种模型训练装置,包括:存储模块,被配置成将模型的未切分模型状态集合和第i个已切分模型状态集合存储到第i个设备的显存中,其中,所述模型包含1个未切分模型状态集合和n...

【专利技术属性】
技术研发人员:梁建中敖玉龙于佃海
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1