分布式机器学习模型的训练方法、装置、设备及介质制造方法及图纸

技术编号:32645349 阅读:20 留言:0更新日期:2022-03-12 18:26
本发明专利技术实施例提供一种分布式机器学习模型的训练方法、装置、设备及介质。该方法包括:获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据;采用本地自适应随机梯度下降算法和训练样本数据对预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数;将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器,以指示参数服务器根据本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。该方法能够有效减少计算节点设备与参数服务器之间的通信交互次数,进而有效提高机器学习模型的训练效率。进而有效提高机器学习模型的训练效率。进而有效提高机器学习模型的训练效率。

【技术实现步骤摘要】
分布式机器学习模型的训练方法、装置、设备及介质


[0001]本专利技术实施例涉及人工智能
,尤其涉及一种分布式机器学习模型的训练方法、装置、设备及介质。

技术介绍

[0002]随着人工智能技术的发展,采用机器学习模型对数据进行分类的应用场景也不断丰富。如推荐系统的应用场景,自然语言处理的应用场景等。由于对机器学习模型进行训练时,模型参数规模和样本数据的规模都是十分巨大的。所以单个计算节点设备已无法满足要求。所以分布式机器学习便成为模型训练的主要方式。
[0003]目前在采用分布式机器学习对模型进行训练时,一般由计算节点设备采用朴素梯度下降算法和训练样本对机器学习模型进行多次迭代训练后将训练后的模型参数发送给参数服务器,在参数服务器对多个训练后的模型参数进行处理获得下一轮模型参数后发送给多个计算节点设备,再次进行下一轮的迭代训练,以此类推,直到将机器学习模型训练至收敛。
[0004]在实现本专利技术过程中,专利技术人发现现有技术中至少存在如下问题:由于采用朴素梯度下降算法对机器学习模型进行迭代训练,导致迭代训练较多次才会使机器学习模型达到收敛。并且由于迭代训练较多次导致计算节点设备与参数服务器之间的通信交互较多,最终导致了机器学习模型的训练效率较低。

技术实现思路

[0005]本专利技术实施例提供一种分布式机器学习模型的训练方法、装置、设备及介质,用以解决现有技术中机器学习模型的训练效率较低的技术问题。
[0006]第一方面,本专利技术实施例提供分布式机器学习模型的训练方法,所述方法应用于计算节点设备,所述方法包括:
[0007]获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据;
[0008]采用本地自适应随机梯度下降算法和所述训练样本数据对所述预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数;
[0009]将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器,以指示所述参数服务器根据本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。
[0010]第二方面,本专利技术实施例提供一种分布式机器学习模型的训练方法,所述方法应用于参数服务器,所述方法包括:
[0011]接收多个计算节点设备发送的对预设机器学习模型进行本轮迭代训练后的模型参数及自适应学习率关联参数;所述本轮迭代训练后的模型参数及自适应学习率关联参数是对应的计算节点设备采用本地自适应随机梯度下降算法和训练样本数据对所述预设机器学习模型进行本轮多次迭代训练后获得的;
[0012]根据多个本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。
[0013]第三方面,本专利技术实施例提供一种分布式机器学习模型的训练装置,所述装置位于计算节点设备中,所述装置包括:
[0014]获取模块,用于获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据;
[0015]迭代训练模块,用于采用本地自适应随机梯度下降算法和所述训练样本数据对所述预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数;
[0016]发送模块,用于将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器,以指示所述参数服务器根据本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。
[0017]第四方面,本专利技术实施例提供一种分布式机器学习模型的训练装置,所述装置位于参数服务器中,所述装置包括:
[0018]接收模块,用于接收多个计算节点设备发送的对预设机器学习模型进行本轮迭代训练后的模型参数及自适应学习率关联参数;所述本轮迭代训练后的模型参数及自适应学习率关联参数是对应的计算节点设备采用本地自适应随机梯度下降算法和训练样本数据对所述预设机器学习模型进行本轮多次迭代训练后获得的;
[0019]计算模块,用于根据多个本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。
[0020]第五方面,本专利技术实施例提供一种计算节点设备,包括:至少一个处理器、存储器及收发器;
[0021]所述处理器、所述存储器及所述收发器电路互联;
[0022]所述存储器存储计算机执行指令;所述收发器用于与参数服务器收发数据;
[0023]所述至少一个处理器执行所述存储器存储的计算机执行指令,使得所述至少一个处理器执行如第一方面任一项所述的方法。
[0024]第六方面,本专利技术实施例提供一种参数服务器,包括:至少一个处理器、存储器及收发器;
[0025]所述处理器、所述存储器及所述收发器电路互联;
[0026]所述存储器存储计算机执行指令;所述收发器用于与多个计算节点设备收发数据;
[0027]所述至少一个处理器执行所述存储器存储的计算机执行指令,使得所述至少一个处理器执行如第二方面任一项所述的方法。
[0028]第七方面,本专利技术实施例提供一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机执行指令,所述计算机执行指令被处理器执行时用于实现如第一方面或第二方面任一项所述的方法。
[0029]第八方面,本专利技术实施例提供一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现第一方面或第二方面任一项所述的方法。
[0030]本专利技术实施例提供的分布式机器学习模型的训练方法、装置、设备及介质,通过获
取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据;采用本地自适应随机梯度下降算法和训练样本数据对预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数;将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器,以指示参数服务器根据本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。由于本地自适应随机梯度下降算法能够综合本地随机梯度算法和自适应梯度下降算法的优势。即该本地自适应随机梯度下降算法在每次进行迭代训练时,既能够根据对应的训练样本数据的特点计算出随机梯度,又能够在随机梯度的基础上确定出自适应学习率关联参数,进而能够自适应的优化模型参数。进而能够有效减少迭代次数。有效减少计算节点设备与参数服务器之间的通信交互次数,进而有效提高机器学习模型的训练效率。并且在稀疏化的深度学习模型的应用场景中具有更快的收敛速度和更好的泛化结果。
附图说明
[0031]此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。
[0032]图1是可以实现本专利技术实施例的分布式机器学习模型的训练方法的一种网络架构图;
[0033]图2是本专利技术一实施例提供的分布式机器学习模型的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种分布式机器学习模型的训练方法,其特征在于,所述方法应用于计算节点设备,所述方法包括:获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据;采用本地自适应随机梯度下降算法和所述训练样本数据对所述预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数;将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器,以指示所述参数服务器根据本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。2.根据权利要求1所述的方法,其特征在于,所述采用本地自适应随机梯度下降算法和所述训练样本数据对所述预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数,包括:在本轮每次迭代训练时执行以下操作:获取预设机器学习模型的本轮本次迭代训练时的模型参数和本轮上一次迭代训练时的自适应学习率关联参数;从所述训练样本数据中随机获取任意一个训练样本数据;采用本地自适应随机梯度下降算法、所述任意一个训练样本数据、本轮本次迭代训练时的模型参数和本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数;根据本轮本次迭代训练时的模型参数及本轮本次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练后的模型参数。3.根据权利要求2所述的方法,其特征在于,所述本地自适应随机梯度下降算法包括:本地随机梯度计算算法及自适应梯度下降算法;所述采用本地自适应随机梯度下降算法、所述任意一个训练样本数据、本轮本次迭代训练时的模型参数和本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数,包括:采用任意一个训练样本数据和所述本轮本次迭代训练时的模型参数计算所述计算节点设备对应的损失函数;采用所述本地随机梯度计算算法对所述损失函数进行反向传播操作,以计算获得本轮本次迭代训练时的随机梯度;采用所述自适应梯度下降算法、所述本轮本次迭代训练时的随机梯度及本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数。4.根据权利要求3所述的方法,其特征在于,所述自适应梯度下降算法为学习率通用框架梯度下降AMSGrad算法,所述自适应学习率关联参数包括:动量及二阶动量;所述采用所述自适应梯度下降算法、所述本轮本次迭代训练时的随机梯度及本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数,包括:将所述本轮本次迭代训练时的随机梯度及本轮上一次迭代训练时的动量输入到AMSGrad算法中,通过所述AMSGrad算法计算本轮本次迭代训练时的动量;将所述本轮本次迭代训练时的随机梯度及所述本轮上一次迭代训练时的二阶动量输
入到AMSGrad算法中,通过所述AMSGrad算法计算本轮本次迭代训练时的二阶动量。5.根据权利要求2

4任一项所述的方法,其特征在于,所述根据本轮本次迭代训练时的模型参数及本轮本次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练后的模型参数,包括:获取所述本地自适应随机梯度下降算法对应的预设固定学习率;根据所述预设固定学习率和所述本轮本次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率;根据所述本轮本次迭代训练时的自适应学习率对所述本轮本次迭代训练时的模型参数进行更新,以获得本轮本次迭代训练后的模型参数。6.根据权利要求1

4任一项所述的方法,其特征在于,所述将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器之后,还包括:判断是否接收到参数服务器发送的下一轮迭代训练时的模型参数及自适应学习率关联参数;若接收到参数服务器发送的下一轮迭代训练时的模型参数及自适应学习率关联参数,则继续执行获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据的步骤。7.根据权利要求1

4任一项所述的方...

【专利技术属性】
技术研发人员:沈力孙昊陶大程
申请(专利权)人:京东科技信息技术有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1