一种模型的训练方法及装置制造方法及图纸

技术编号:32121551 阅读:20 留言:0更新日期:2022-01-29 19:08
本发明专利技术实施例提供了一种模型的训练方法及装置,应用于包括任务节点设备以及参数节点设备的分布式系统,该方法包括:在任务节点设备新增训练数据的情况下,将目标数据标识发送至参数节点设备;在参数节点设备未查到目标数据标识对应的模型参数的情况下,针对目标数据标识生成对应的目标模型参数,将目标数据标识以及目标模型参数采用非连续存储的方式存储在参数节点设备的内存中;通过参数节点设备将目标模型参数发送至任务节点设备,进而根据目标数据标识对应训练数据以及目标模型参数计算得到目标梯度值。然后基于目标梯度值更新目标模型参数。本发明专利技术在新增训练数据的情况下,可以保证训练正常进行,从而无需终止训练。从而无需终止训练。从而无需终止训练。

【技术实现步骤摘要】
一种模型的训练方法及装置


[0001]本专利技术涉及深度学习模型训练领域,尤其涉及一种模型的训练方法及装置。

技术介绍

[0002]模型训练涉及大量的数据以及计算,整个过程十分复杂,导致训练周期通常很长。目前,分布式训练以其极高的效率受到大家的追捧。其中,PS(Parameter Server)架构为代表,该架构包括参数服务器节点和任务节点;其中,参数服务器节点的主要功能是初始化和保存模型参数、接受任务节点计算出的局部梯度、汇总计算全局梯度,并更新模型参数。任务节点的主要功能是各自保存部分训练数据,初始化模型,从参数服务器节点拉取最新的模型参数,再读取模型参数,根据训练数据计算局部梯度,然后将局部梯度上传给参数服务器节点。两部分节点各司其职,避免数据与计算过程集中在一处。
[0003]现有深度学习平台采用PS架构进行模型训练的过程中,需要预先准备好训练数据,并基于训练数据对应设置模型参数。启动训练之后,基于准备好的训练数据对模型进行训练。
[0004]然而,在上述训练过程中,若额外增加新的训练数据,会导致训练终止,从而只能重新开始训练,耽误训练进程。

技术实现思路

[0005]鉴于上述问题,本专利技术实施例提供一种模型的训练方法及装置,以解决现有技术中新增训练数据时,需要停止训练导致影响训练效率的问题。
[0006]在本专利技术实施的第一方面,提供了一种模型的训练方法,应用于包括任务节点设备以及参数节点设备的分布式系统,其中,所述任务节点设备存储有训练模型所需的训练数据,所述参数节点设备存储有每条训练数据的数据标识以及每个数据标识对应的模型参数,所述方法包括:
[0007]在所述任务节点设备新增训练数据的情况下,将目标数据标识发送至参数节点设备,其中,所述目标数据标识为新增训练数据中的数据标识;
[0008]在所述参数节点设备未查到所述目标数据标识对应的模型参数的情况下,针对所述目标数据标识生成对应的目标模型参数,将所述目标数据标识以及所述目标模型参数采用非连续存储的方式存储在所述参数节点设备的内存中;
[0009]通过所述参数节点设备将所述目标模型参数发送至所述任务节点设备,以使所述任务节点设备根据所述目标数据标识对应训练数据以及所述目标模型参数计算得到目标梯度值;
[0010]通过所述任务节点设备将所述目标梯度值发送至所述参数节点设备,以使所述参数节点设备基于所述目标梯度值更新所述目标模型参数。
[0011]可选地,在所述参数节点设备基于所述目标梯度值更新所述目标模型参数之后,所述方法还包括:
[0012]通过所述参数节点设备基于所述内存中的数据标识以及模型参数,确定实际数据范围值,其中,所述实际数据范围值包括:所述内存中数据标识和模型参数的每一维度的元素数量;
[0013]在所述实际数据范围值与预设数据范围值不同的情况下,按所述实际数据范围值将所述内存中数据标识和模型参数存储至所述参数节点设备的外存中,其中,所述预设数据范围值包括开始训练前基于训练数据确定的数据标识和模型参数的每一维度的元素数量。
[0014]可选地,所述按所述实际数据范围值将所述内存中数据标识和模型参数存储至所述参数节点设备的外存中,包括:
[0015]分别根据所述内存中的数据标识以及模型参数,生成标识张量和参数张量;
[0016]按照所述标识张量和所述参数张量的每一维度的元素数量,分别将所述标识张量和所述参数张存储至所述参数节点设备的外存中。
[0017]可选地,所述参数节点设备基于所述目标梯度值更新所述目标模型参数,包括:
[0018]基于所述目标梯度值采用目标优化器对所述目标模型参数进行更新,其中,所述目标优化器为针对非连续性存储的变量设计的优化器。
[0019]可选地,在新增训练数据包括多条训练数据,且所述参数节点设备的数量为多个的情况下,所述将目标数据标识发送至参数节点设备包括:
[0020]分别将所述多条训练数据中各自的目标数据标识发送至不同的参数节点设备。
[0021]在本专利技术实施的第二方面,还提供了一种模型的训练装置,应用于包括任务节点设备以及参数节点设备的分布式系统,其中,所述任务节点设备存储有训练模型所需的训练数据,所述参数节点设备存储有每条训练数据的数据标识以及每个数据标识对应的模型参数,所述方法包括:接收模块,用于接收用户输入的目标信息,其中,所述目标信息包括在所述多个不同应用各自对应的数据源中查询数据所需的信息;所述装置包括:
[0022]发送模块,用于在所述任务节点设备新增训练数据的情况下,将目标数据标识发送至参数节点设备,其中,所述目标数据标识为新增训练数据中的数据标识;
[0023]参数返回模块,用于在所述参数节点设备未查到所述目标数据标识对应的模型参数的情况下,针对所述目标数据标识生成对应的目标模型参数,将所述目标数据标识以及所述目标模型参数采用非连续存储的方式存储在所述参数节点设备的内存中;
[0024]梯度模块,用于通过所述参数节点设备将所述目标模型参数发送至所述任务节点设备,以使所述任务节点设备根据所述目标数据标识对应训练数据以及所述目标模型参数计算得到目标梯度值;
[0025]更新模块,用于通过所述任务节点设备将所述目标梯度值发送至所述参数节点设备,以使所述参数节点设备基于所述目标梯度值更新所述目标模型参数。
[0026]可选地,所述装置还包括:
[0027]判断模块,用于通过所述参数节点设备基于所述内存中的数据标识以及模型参数,确定实际数据范围值,其中,所述实际数据范围值包括:所述内存中数据标识和模型参数的每一维度的元素数量;
[0028]存储模块,用于在所述实际数据范围值与预设数据范围值不同的情况下,按所述实际数据范围值将所述内存中数据标识和模型参数存储至所述参数节点设备的外存中,其
中,所述预设数据范围值包括开始训练前基于训练数据确定的数据标识和模型参数的每一维度的元素数量。
[0029]可选地,所述存储模块,包括:
[0030]生成单元,用于分别根据所述内存中的数据标识以及模型参数,生成标识张量和参数张量;
[0031]存储单元,用于按照所述标识张量和所述参数张量的每一维度的元素数量,分别将所述标识张量和所述参数张存储至所述参数节点设备的外存中。
[0032]可选地,所述更新模块,具体用于基于所述目标梯度值采用目标优化器对所述目标模型参数进行更新,其中,所述目标优化器为针对非连续性存储的变量设计的优化器。
[0033]可选地,在新增训练数据包括多条训练数据,且所述参数节点设备的数量为多个的情况下,所述发送模块,具体用于分别将所述多条训练数据中各自的目标数据标识发送至不同的参数节点设备。
[0034]在本专利技术实施的第三方面,还提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
[0035]存储器本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型的训练方法,其特征在于,应用于包括任务节点设备以及参数节点设备的分布式系统,其中,所述任务节点设备存储有训练模型所需的训练数据,所述参数节点设备存储有每条训练数据的数据标识以及每个数据标识对应的模型参数,所述方法包括:在所述任务节点设备新增训练数据的情况下,将目标数据标识发送至参数节点设备,其中,所述目标数据标识为新增训练数据中的数据标识;在所述参数节点设备未查到所述目标数据标识对应的模型参数的情况下,针对所述目标数据标识生成对应的目标模型参数,将所述目标数据标识以及所述目标模型参数采用非连续存储的方式存储在所述参数节点设备的内存中;通过所述参数节点设备将所述目标模型参数发送至所述任务节点设备,以使所述任务节点设备根据所述目标数据标识对应训练数据以及所述目标模型参数计算得到目标梯度值;通过所述任务节点设备将所述目标梯度值发送至所述参数节点设备,以使所述参数节点设备基于所述目标梯度值更新所述目标模型参数。2.根据权利要求1所述的方法,其特征在于,在所述参数节点设备基于所述目标梯度值更新所述目标模型参数之后,所述方法还包括:通过所述参数节点设备基于所述内存中的数据标识以及模型参数,确定实际数据范围值,其中,所述实际数据范围值包括:所述内存中数据标识和模型参数的每一维度的元素数量;在所述实际数据范围值与预设数据范围值不同的情况下,按所述实际数据范围值将所述内存中数据标识和模型参数存储至所述参数节点设备的外存中,其中,所述预设数据范围值包括开始训练前基于训练数据确定的数据标识和模型参数的每一维度的元素数量。3.根据权利要求2所述的方法,其特征在于,所述按所述实际数据范围值将所述内存中数据标识和模型参数存储至所述参数节点设备的外存中,包括:分别根据所述内存中的数据标识以及模型参数,生成标识张量和参数张量;按照所述标识张量和所述参数张量的每一维度的元素数量,分别将所述标识张量和所述参数张存储至所述参数节点设备的外存中。4.根据权利要求1所述的方法,其特征在于,所述参数节点设备基于所述目标梯度值更新所述目标模型参数,包括:基于所述目标梯度值采用目标优化器对所述目标模型参数进行更新,其中,所述目标优化器为针对非连续性存储的变量设计的优化器。5.根据权利要求1所述的方法,其特征在于,在新增训练数据包括多条训练数据,且所述参数节点设备的数量为多个的情况下,所述将目标数据标识发送至参数节点设备包括:分别将所述多条训练数据中各自的目标数据标识发送至不同的参数节点设备。6.一种模型的训练装置,其特征在于,应用...

【专利技术属性】
技术研发人员:孙宇
申请(专利权)人:北京奇艺世纪科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1