【技术实现步骤摘要】
分布式机器学习模型的训练方法、装置、设备及介质
[0001]本专利技术实施例涉及人工智能
,尤其涉及一种分布式机器学习模型的训练方法、装置、设备及介质。
技术介绍
[0002]随着人工智能技术的发展,采用机器学习模型对数据进行分类的应用场景也不断丰富。如推荐系统的应用场景,自然语言处理的应用场景等。由于对机器学习模型进行训练时,模型参数规模和样本数据的规模都是十分巨大的。所以单个计算节点设备已无法满足要求。所以分布式机器学习便成为模型训练的主要方式。
[0003]目前在采用分布式机器学习对模型进行训练时,一般由计算节点设备采用朴素梯度下降算法和训练样本对机器学习模型进行多次迭代训练后将训练后的模型参数发送给参数服务器,在参数服务器对多个训练后的模型参数进行处理获得下一轮模型参数后发送给多个计算节点设备,再次进行下一轮的迭代训练,以此类推,直到将机器学习模型训练至收敛。
[0004]在实现本专利技术过程中,专利技术人发现现有技术中至少存在如下问题:由于采用朴素梯度下降算法对机器学习模型进行迭代训练,导致迭代训练较多次才会使机器学习模型达到收敛。并且由于迭代训练较多次导致计算节点设备与参数服务器之间的通信交互较多,最终导致了机器学习模型的训练效率较低。
技术实现思路
[0005]本专利技术实施例提供一种分布式机器学习模型的训练方法、装置、设备及介质,用以解决现有技术中机器学习模型的训练效率较低的技术问题。
[0006]第一方面,本专利技术实施例提供分布式机器学习模型的训练方法 ...
【技术保护点】
【技术特征摘要】
1.一种分布式机器学习模型的训练方法,其特征在于,所述方法应用于计算节点设备,所述方法包括:获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据;采用本地自适应随机梯度下降算法和所述训练样本数据对所述预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数;将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器,以指示所述参数服务器根据本轮迭代训练后的模型参数及自适应学习率关联参数计算下一轮迭代训练时的模型参数及自适应学习率关联参数。2.根据权利要求1所述的方法,其特征在于,所述采用本地自适应随机梯度下降算法和所述训练样本数据对所述预设机器学习模型进行本轮多次迭代训练,以获得本轮迭代训练后的模型参数及自适应学习率关联参数,包括:在本轮每次迭代训练时执行以下操作:获取预设机器学习模型的本轮本次迭代训练时的模型参数和本轮上一次迭代训练时的自适应学习率关联参数;从所述训练样本数据中随机获取任意一个训练样本数据;采用本地自适应随机梯度下降算法、所述任意一个训练样本数据、本轮本次迭代训练时的模型参数和本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数;根据本轮本次迭代训练时的模型参数及本轮本次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练后的模型参数。3.根据权利要求2所述的方法,其特征在于,所述本地自适应随机梯度下降算法包括:本地随机梯度计算算法及自适应梯度下降算法;所述采用本地自适应随机梯度下降算法、所述任意一个训练样本数据、本轮本次迭代训练时的模型参数和本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数,包括:采用任意一个训练样本数据和所述本轮本次迭代训练时的模型参数计算所述计算节点设备对应的损失函数;采用所述本地随机梯度计算算法对所述损失函数进行反向传播操作,以计算获得本轮本次迭代训练时的随机梯度;采用所述自适应梯度下降算法、所述本轮本次迭代训练时的随机梯度及本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数。4.根据权利要求3所述的方法,其特征在于,所述自适应梯度下降算法为学习率通用框架梯度下降AMSGrad算法,所述自适应学习率关联参数包括:动量及二阶动量;所述采用所述自适应梯度下降算法、所述本轮本次迭代训练时的随机梯度及本轮上一次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率关联参数,包括:将所述本轮本次迭代训练时的随机梯度及本轮上一次迭代训练时的动量输入到AMSGrad算法中,通过所述AMSGrad算法计算本轮本次迭代训练时的动量;将所述本轮本次迭代训练时的随机梯度及所述本轮上一次迭代训练时的二阶动量输
入到AMSGrad算法中,通过所述AMSGrad算法计算本轮本次迭代训练时的二阶动量。5.根据权利要求2
‑
4任一项所述的方法,其特征在于,所述根据本轮本次迭代训练时的模型参数及本轮本次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练后的模型参数,包括:获取所述本地自适应随机梯度下降算法对应的预设固定学习率;根据所述预设固定学习率和所述本轮本次迭代训练时的自适应学习率关联参数计算本轮本次迭代训练时的自适应学习率;根据所述本轮本次迭代训练时的自适应学习率对所述本轮本次迭代训练时的模型参数进行更新,以获得本轮本次迭代训练后的模型参数。6.根据权利要求1
‑
4任一项所述的方法,其特征在于,所述将本轮迭代训练后的模型参数及自适应学习率关联参数发送给参数服务器之后,还包括:判断是否接收到参数服务器发送的下一轮迭代训练时的模型参数及自适应学习率关联参数;若接收到参数服务器发送的下一轮迭代训练时的模型参数及自适应学习率关联参数,则继续执行获取对预设机器学习模型进行训练时计算节点设备对应的多个训练样本数据的步骤。7.根据权利要求1
‑
4任一项所述的方...
【专利技术属性】
技术研发人员:沈力,孙昊,陶大程,
申请(专利权)人:京东科技信息技术有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。