一种分布式训练中的梯度更新方法及相关设备技术

技术编号:19935258 阅读:32 留言:0更新日期:2018-12-29 04:55
本申请公开了一种分布式训练中的梯度更新方法及相关设备,方法应用于计算节点服务器,所述方法包括:监测对目标数据进行分布式训练的全局轮换更新步数;基于所述全局轮换更新步数满足所述计算节点服务器的全局梯度更新条件的判断,将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;基于所述全局轮换更新步数不满足所述计算节点服务器的全局梯度更新条件的判断,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。

【技术实现步骤摘要】
一种分布式训练中的梯度更新方法及相关设备
本申请涉及计算机
,尤其涉及一种分布式训练中的梯度更新方法及相关设备。
技术介绍
目前,在深度学习的分布式训练中,通常会遇到梯度更新过程中计算节点等待轮换梯度更新而导致设备占用率较低的问题。因此,亟需一种能够解决分布式训练中设备占用率较低问题的技术方案。
技术实现思路
有鉴于此,本申请提供一种分布式训练中的梯度更新方法及相关设备,用以解决现有技术中分布式训练中设备占用率较低的技术问题。本申请提供了一种分布式训练中的梯度更新方法,应用于计算节点服务器,所述方法包括:监测对目标数据进行分布式训练的全局轮换更新步数;基于所述全局轮换更新步数满足所述计算节点服务器的全局梯度更新条件的判断,将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;基于所述全局轮换更新步数不满足所述计算节点服务器的全局梯度更新条件的判断,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。上述方法,优选的,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中,包括:基于本地训练模型,计算当前本地梯度;将所述当前本地梯度累加到所述本地梯度累计值中。上述方法,优选的,还包括:基于所述当前本地梯度,对所述计算节点服务器中的本地训练模型进行梯度更新。上述方法,优选的,还包括:接收所述参数服务器传输的全局训练模型,所述参数服务器传输的全局训练模型为经过基于所述本地梯度累计值的梯度更新后的模型;将接收到的全局训练模型更新为所述计算节点服务器的本地训练模型。上述方法,优选的,所述全局梯度更新条件包括:全局轮换更新步数与计算节点服务器的本地轮换标识相对应。本申请还提供了一种分布式训练中的梯度更新装置,应用于计算节点服务器,所述装置包括:步数监测单元,用于监测对目标数据进行分布式训练的全局轮换更新步数;条件判断单元,用于判断所述全局轮换更新步数是否满足所述计算节点服务器的全局梯度更新条件,如果是,触发梯度更新单元,否则,触发梯度累计单元;梯度更新单元,用于将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;梯度累计单元,用于计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。本申请还提供了一种计算节点服务器,包括:存储器,用于存储应用程序及所述应用程序运行所产生的数据;处理器,用于执行所述应用程序,以实现功能:监测对目标数据进行分布式训练的全局轮换更新步数;基于所述全局轮换更新步数满足所述计算节点服务器的全局梯度更新条件的判断,将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;基于所述全局轮换更新步数不满足所述计算节点服务器的全局梯度更新条件的判断,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。本申请还提供了一种训练集群,包括:参数服务器,用于存储全局训练模型;计算节点服务器,用于监测对目标数据进行分布式训练的全局轮换更新步数;基于所述全局轮换更新步数满足所述计算节点服务器的全局梯度更新条件的判断,将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;基于所述全局轮换更新步数不满足所述计算节点服务器的全局梯度更新条件的判断,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。上述训练集群,优选的,所述计算节点服务器具体用于:基于本地训练模型,计算当前本地梯度;将所述当前本地梯度累加到所述本地梯度累计值中。上述训练集群,优选的,所述计算节点服务器还用于:接收所述参数服务器传输的全局训练模型,所述参数服务器传输的全局训练模型为经过基于所述本地梯度累计值的梯度更新后的模型;将接收到的全局训练模型更新为所述计算节点服务器的本地训练模型。。从上述技术方案可以看出,本申请提供的一种分布式训练中的梯度更新方法及相关设备,通过对目标数据在进行分布式训练时的全局轮换更新步数进行监测,从而在满足当前计算节点服务器的全局梯度更新条件时,将计算节点服务器当前的本地梯度累计值传输给参数服务器中,进而参数服务器对全局训练模型基于该本地梯度累计值进行梯度更新,而在不满足当前计算节点服务器的全局梯度更新条件时,当前计算节点服务器计算当前本地梯度并进行累加到本地梯度累计值中。可见,在本申请中每次全局轮换更新步数更新时,判断有没有轮换到计算节点服务器进行梯度更新,而在没有轮换到时,计算节点服务器会持续进行本地梯度计算并进行累加,直到轮换到计算节点服务器进行梯度更新时,计算节点服务器就可以将累加的本地梯度更新到参数服务器,由参数服务器对全局训练模型基于累计值进行梯度更新,由此计算节点服务器不会在没有轮换到梯度更新时只进行等待,而是持续进行本地梯度计算,从而减少计算节点服务器的等待时间,提高计算节点服务器等设备的占用率。附图说明为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。图1为本申请实施例一提供的一种分布式训练中的梯度更新方法的流程图;图2为本申请实施例的应用示例图;图3~图5分别为本申请实施例一的其他流程图;图6为本申请实施例的另一应用示例图;图7为本申请实施例二提供的一种分布式训练中的梯度更新装置的结构示意图;图8为本申请实施例三提供的一种计算节点服务器的结构示意图;图9为本申请实施例四提供的一种训练集群的架构图;图10为本申请实施例中的另一应用示例图。具体实施方式下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。参考图1,为本申请实施例一提供的一种分布式训练中的梯度更新方法的实现流程图,本实施例中的方法适用于能够进行梯度计算的计算节点服务器中,计算节点服务器与参数服务器组成分布式训练集群,而在分布式训练集群中可以不止一个计算节点服务器,本实施例中的方案适用于分布式训练集群中的任意计算节点服务器中。在具体实现中,计算节点服务器可以为具有数据计算功能的台式机、服务器等终端设备实现。本实施例中的方法可以包括以下步骤:步骤101:监测对目标数据进行分布式训练的全局轮换更新步数。其中,目标数据即为进行分布式训练的数据,如某个领域的样本数据等,在分布式训练集群中,参数服务器存放全局训练模型,而本实施例中基于全局训练模型对目标数据进行分布式训练,目标数据进行分布式训练中会经过多次迭代训练,并且以全局轮换更新步数表征迭代训练中进行训练梯度更新的计算节点服务器。由此,本实施例中通过监测对目标数据进行分布式训练的全局轮换更新步数来判断是否轮到当前的计算节点服务器进行梯度更新,即基于本文档来自技高网...

【技术保护点】
1.一种分布式训练中的梯度更新方法,应用于计算节点服务器,所述方法包括:监测对目标数据进行分布式训练的全局轮换更新步数;基于所述全局轮换更新步数满足所述计算节点服务器的全局梯度更新条件的判断,将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;基于所述全局轮换更新步数不满足所述计算节点服务器的全局梯度更新条件的判断,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。

【技术特征摘要】
1.一种分布式训练中的梯度更新方法,应用于计算节点服务器,所述方法包括:监测对目标数据进行分布式训练的全局轮换更新步数;基于所述全局轮换更新步数满足所述计算节点服务器的全局梯度更新条件的判断,将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局训练模型基于所述本地梯度累计值进行梯度更新;基于所述全局轮换更新步数不满足所述计算节点服务器的全局梯度更新条件的判断,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中。2.根据权利要求1所述的方法,其特征在于,计算当前本地梯度并将所述当前本地梯度累加到所述本地梯度累计值中,包括:基于本地训练模型,计算当前本地梯度;将所述当前本地梯度累加到所述本地梯度累计值中。3.根据权利要求1或2所述的方法,其特征在于,还包括:基于所述当前本地梯度,对所述计算节点服务器中的本地训练模型进行梯度更新。4.根据权利要求1或2所述的方法,其特征在于,还包括:接收所述参数服务器传输的全局训练模型,所述参数服务器传输的全局训练模型为经过基于所述本地梯度累计值的梯度更新后的模型;将接收到的全局训练模型更新为所述计算节点服务器的本地训练模型。5.根据权利要求1或2所述的方法,其特征在于,所述全局梯度更新条件包括:全局轮换更新步数与计算节点服务器的本地轮换标识相对应。6.一种分布式训练中的梯度更新装置,应用于计算节点服务器,所述装置包括:步数监测单元,用于监测对目标数据进行分布式训练的全局轮换更新步数;条件判断单元,用于判断所述全局轮换更新步数是否满足所述计算节点服务器的全局梯度更新条件,如果是,触发梯度更新单元,否则,触发梯度累计单元;梯度更新单元,用于将所述计算节点服务器当前的本地梯度累计值传输给参数服务器中,由所述参数服务器对全局...

【专利技术属性】
技术研发人员:胡文晖王鹏王奇刚
申请(专利权)人:联想北京有限公司
类型:发明
国别省市:北京,11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1