模型训练的方法、装置、设备及计算机可读存储介质制造方法及图纸

技术编号:26651299 阅读:20 留言:0更新日期:2020-12-09 00:52
本发明专利技术实施例提供一种模型训练的方法、装置、设备及计算机可读存储介质,本发明专利技术实施例的方法,在采用批量梯度下降进行模型训练的每轮迭代过程中,通过分布式批处理平台中的第一节点将批量训练数据集拆分为多个子数据集,将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;第二节点根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数,并将所述更新后的梯度参数发送给第一节点;第一节点根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数,能够基于分布式批处理平台,利用分布式实时批处理技术,提高批量梯度下降法的批处理能力,从而能够加快训练速度,缩短训练时间,提高训练效率。

【技术实现步骤摘要】
模型训练的方法、装置、设备及计算机可读存储介质
本专利技术实施例涉及计算机
,尤其涉及一种模型训练的方法、装置、设备及计算机可读存储介质。
技术介绍
在多个应用领域中,如自然语言处理、计算机视觉、语音识别、预测分析、推荐引擎等等,均需要使用机器学习模型进行数据处理。在机器学习中,在进行数据处理通常会根据输入来预测输出,预测值和真实值之间会有一定的误差,在模型训练的过程中会使用优化器(optimizer)来最小化这个误差,梯度下降法(GradientDescent)就是一种常用的优化器。批量梯度下降(BatchGradientDescent,简称BGD)是进行模型训练时的一种常用的梯度下降法,每次都使用训练集中的所有样本来更新梯度参数,可以得到全局最优解。在实现本专利技术过程中,专利技术人发现现有技术中至少存在如下问题:当样本量很大时,批量梯度下降的速度就会非常慢,这样会导致模型训练时间长,模型训练效率低,从而导致获取机器学习模型并进行数据处理的效率低。
技术实现思路
本专利技术实施例提供一种模型训练的方法、装置、设备及计算机可读存储介质,用以解决当样本量很大时,批量梯度下降的速度就会非常慢,导致模型训练时间长,效率低的问题。第一方面,本专利技术实施例提供一种模型训练的方法,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。第二方面,本专利技术实施例提供一种模型训练的方法,应用于分布式批处理平台的第二节点,所述分布式批处理平台包第一节点,以及多个第二节点,包括:获取子数据集和本轮迭代的初始梯度参数,所述子数据集为批量训练数据集的子集;根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数;将所述更新后的梯度参数发送给第一节点,以使第一节点根据各所述第二节点返回的更新后的梯度参数确定本轮的目标梯度参数;其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。第三方面,本专利技术实施例提供一种数据处理的方法,包括:获取待处理数据;通过训练好的机器学习模型对所述待处理数据进行数据处理;其中,所述机器学习模型的模型训练的每轮迭代过程中,通过分布式批处理平台的第一节点将批量训练数据集拆分为多个子数据集;将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点,以使各所述第二节点根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数;根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;最后一轮的目标梯度参数作为机器学习模型的最终模型参数,得到所述训练好的机器学习模型。第四方面,本专利技术实施例提供一种模型训练的装置,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:任务分发模块,用于在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;所述任务分发模块还用于将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;梯度参数确定模块,用于根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。第五方面,本专利技术实施例提供一种模型训练的装置,应用于分布式批处理平台的第二节点,所述分布式批处理平台包括第一节点,以及多个第二节点,包括:通信模块,用于获取子数据集和本轮迭代的初始梯度参数,所述子数据集为批量训练数据集的子集;参数计算模块,用于根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数;所述通信模块还用于将所述更新后的梯度参数发送给第一节点,以使第一节点根据各所述第二节点返回的更新后的梯度参数确定本轮的目标梯度参数;其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。第六方面,本专利技术实施例提供一种电子设备,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:处理器,存储器,以及存储在所述存储器上并可在所述处理器上运行的计算机程序;其中,所述处理器运行所述计算机程序时实现上述第一方面所述的方法。第七方面,本专利技术实施例提供一种电子设备,应用于分布式批处理平台的第二节点,所述分布式批处理平台包括第一节点,以及多个第二节点,包括:处理器,存储器,以及存储在所述存储器上并可在所述处理器上运行的计算机程序;其中,所述处理器运行所述计算机程序时实现上述第二方面所述的方法。第八方面,本专利技术实施例提供一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,所述计算机程序被处理器执行时实现上述第一方面或者第二方面所述的方法。本专利技术实施例提供的模型训练的方法、装置、设备及计算机可读存储介质,在采用批量梯度下降进行模型训练的每轮迭代过程中,通过分布式批处理平台中的第一节点将批量训练数据集拆分为多个子数据集,将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;第二节点根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数,并将所述更新后的梯度参数发送给第一节点;第一节点根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数,能够基于分布式批处理平台,利用分布式实时批处理技术,提高批量梯度下降法的批处理能力,从而能够加快训练速度,缩短训练时间,提高训练效率,从而可以缩短训练得到机器学习模型并进行数据处理的时间,提高数据处理的效率。附图说明图1为本专利技术实施例一提供的模型训练的方法流程图;图2为本专利技术实施例二提供的模型训练的方法流程图;图3为本专利技术实施例二提供的另一模型训练的方法流程图;图4为本专利技术实施例三提供的模型训练的装置的结构示意图;图5为本专利技术实施例四提供的模型训练的装置的结构示意图;图6为本专利技术实施例五提供的模型训练的装置的结构示意图;图7为本专利技术实施例七提供的数据处理的方法流程图;图8为本专利技术实施例七提供的电子设备的结构示意图;图9为本专利技术实施例八提供的电子设备的结构示意图。通过上述附图,已示出本专利技术明确的实施例,后文中将有更详细的描述。这些附图和文字描述并不是为了通过任何方式限制本专利技术构思的范围,而是通过参考特定实施例为本领域技术人员说明本专利技术的概念。具体实施方式这里将详细地对示例性实施例进行说明本文档来自技高网...

【技术保护点】
1.一种模型训练的方法,其特征在于,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:/n在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;/n将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;/n根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;/n其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。/n

【技术特征摘要】
1.一种模型训练的方法,其特征在于,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:
在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;
将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;
根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;
其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。


2.根据权利要求1所述的方法,其特征在于,
若本轮是第一轮迭代,则本轮的初始梯度参数为梯度参数的初始化值;
若本轮不是第一轮迭代,则本轮的初始梯度参数为上一轮的目标梯度参数。


3.根据权利要求2所述的方法,其特征在于,在进行第一轮迭代之前,还包括:
初始化梯度参数和迭代次数。


4.根据权利要求1-3中任一项所述的方法,其特征在于,所述根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数,包括:
计算各所述第二节点返回的更新后的梯度参数的平均值,得到所述目标梯度参数。


5.根据权利要求4所述的方法,其特征在于,所述计算各所述第二节点返回的更新后的梯度参数的平均值,得到所述目标梯度参数,包括:
接收各所述第二节点返回的样本梯度和所述梯度参数更新次数,所述样本梯度是指样本对应的更新后的梯度参数;
计算所有所述样本梯度之和,以及梯度参数的更新总次数;
根据所有所述样本梯度之和以及梯度参数的更新总次数,计算所有所述样本梯度的平均值,得到所述目标梯度参数。


6.一种模型训练的方法,其特征在于,应用于分布式批处理平台的第二节点,所述分布式批处理平台包第一节点,以及多个第二节点,包括:
获取子数据集和本轮迭代的初始梯度参数,所述子数据集为批量训练数据集的子集;
根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数;
将所述更新后的梯度参数发送给第一节点,以使第一节点根据各所述第二节点返回的更新后的梯度参数确定本轮的目标梯度参数;
其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。


7.根据权利要求6所述的方法,其特征在于,所述根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数,包括:
通过并行计算的方式,根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数。


8.根据权利要求7所述的方法,其特征在于,所述并行计算的方式,根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数,包括:
调用CUDA的核函数,根据所述子数据集中的每个样本和初始梯度参数,计算每个样本的样本梯度,所述样本梯度是指样本对应的更新后的梯度参数。


9.根据权利要求8所述的方法,其特征在于,所述调用CUDA的核函数,根据所述子数据集中的每个样本和初始梯度参数,计算每个样本的样本梯度之后,还包括:
记录梯度参数的更新次数。
<...

【专利技术属性】
技术研发人员:黄乐乐彭南博
申请(专利权)人:京东数字科技控股股份有限公司
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1