【技术实现步骤摘要】
模型训练的方法、装置、设备及计算机可读存储介质
本专利技术实施例涉及计算机
,尤其涉及一种模型训练的方法、装置、设备及计算机可读存储介质。
技术介绍
在多个应用领域中,如自然语言处理、计算机视觉、语音识别、预测分析、推荐引擎等等,均需要使用机器学习模型进行数据处理。在机器学习中,在进行数据处理通常会根据输入来预测输出,预测值和真实值之间会有一定的误差,在模型训练的过程中会使用优化器(optimizer)来最小化这个误差,梯度下降法(GradientDescent)就是一种常用的优化器。批量梯度下降(BatchGradientDescent,简称BGD)是进行模型训练时的一种常用的梯度下降法,每次都使用训练集中的所有样本来更新梯度参数,可以得到全局最优解。在实现本专利技术过程中,专利技术人发现现有技术中至少存在如下问题:当样本量很大时,批量梯度下降的速度就会非常慢,这样会导致模型训练时间长,模型训练效率低,从而导致获取机器学习模型并进行数据处理的效率低。
技术实现思路
本专利技术实施例提供一种模型训练的方法、装置、设备及计算机可读存储介质,用以解决当样本量很大时,批量梯度下降的速度就会非常慢,导致模型训练时间长,效率低的问题。第一方面,本专利技术实施例提供一种模型训练的方法,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;将所述子数据集分发给各第二节 ...
【技术保护点】
1.一种模型训练的方法,其特征在于,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:/n在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;/n将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;/n根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;/n其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。/n
【技术特征摘要】
1.一种模型训练的方法,其特征在于,应用于分布式批处理平台的第一节点,所述分布式批处理平台包括所述第一节点,以及多个第二节点,包括:
在进行每轮迭代过程中,将批量训练数据集拆分为多个子数据集;
将所述子数据集分发给各第二节点,并将本轮的初始梯度参数发送给各所述第二节点;
根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数;
其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。
2.根据权利要求1所述的方法,其特征在于,
若本轮是第一轮迭代,则本轮的初始梯度参数为梯度参数的初始化值;
若本轮不是第一轮迭代,则本轮的初始梯度参数为上一轮的目标梯度参数。
3.根据权利要求2所述的方法,其特征在于,在进行第一轮迭代之前,还包括:
初始化梯度参数和迭代次数。
4.根据权利要求1-3中任一项所述的方法,其特征在于,所述根据各所述第二节点返回的更新后的梯度参数,确定本轮的目标梯度参数,包括:
计算各所述第二节点返回的更新后的梯度参数的平均值,得到所述目标梯度参数。
5.根据权利要求4所述的方法,其特征在于,所述计算各所述第二节点返回的更新后的梯度参数的平均值,得到所述目标梯度参数,包括:
接收各所述第二节点返回的样本梯度和所述梯度参数更新次数,所述样本梯度是指样本对应的更新后的梯度参数;
计算所有所述样本梯度之和,以及梯度参数的更新总次数;
根据所有所述样本梯度之和以及梯度参数的更新总次数,计算所有所述样本梯度的平均值,得到所述目标梯度参数。
6.一种模型训练的方法,其特征在于,应用于分布式批处理平台的第二节点,所述分布式批处理平台包第一节点,以及多个第二节点,包括:
获取子数据集和本轮迭代的初始梯度参数,所述子数据集为批量训练数据集的子集;
根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数;
将所述更新后的梯度参数发送给第一节点,以使第一节点根据各所述第二节点返回的更新后的梯度参数确定本轮的目标梯度参数;
其中,最后一轮的目标梯度参数作为机器学习模型的最终模型参数,所述机器学习模型用于进行数据处理。
7.根据权利要求6所述的方法,其特征在于,所述根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数,包括:
通过并行计算的方式,根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数。
8.根据权利要求7所述的方法,其特征在于,所述并行计算的方式,根据所述子数据集和初始梯度参数,计算得到更新后的梯度参数,包括:
调用CUDA的核函数,根据所述子数据集中的每个样本和初始梯度参数,计算每个样本的样本梯度,所述样本梯度是指样本对应的更新后的梯度参数。
9.根据权利要求8所述的方法,其特征在于,所述调用CUDA的核函数,根据所述子数据集中的每个样本和初始梯度参数,计算每个样本的样本梯度之后,还包括:
记录梯度参数的更新次数。
<...
【专利技术属性】
技术研发人员:黄乐乐,彭南博,
申请(专利权)人:京东数字科技控股股份有限公司,
类型:发明
国别省市:北京;11
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。