一种基于梯度提升决策树的模型训练方法及装置制造方法及图纸

技术编号:20589694 阅读:16 留言:0更新日期:2019-03-16 07:24
公开了一种基于梯度提升决策树的模型训练方法及装置。将一个GBDT算法流程划分为两个阶段,在前一阶段,从与目标业务场景相近的业务场景的数据域获取已标注样本依次训练若干决策树,并确定经过前一阶段训练后产生的训练残差;在后一阶段,从目标业务场景的数据域获取已标注样本,并基于所述训练残差,继续训练若干决策树。最终,应用于目标业务场景的模型实际上是由前一阶段训练出的决策树与后一阶段训练出的决策树集成得到的。

【技术实现步骤摘要】
一种基于梯度提升决策树的模型训练方法及装置
本说明书实施例涉及信息
,尤其涉及一种基于梯度提升决策树的模型训练方法及装置。
技术介绍
众所周知,当需要训练应用于某个业务场景的预测模型时,通常需要从该业务场景的数据域获取大量数据进行标注,作为已标注样本,进行模型训练。如果已标注样本的数量较少,则通常无法得到效果合格的模型。需要说明的是,某个业务场景的数据域,实际上是基于该业务场景所产生的业务数据的集合。然而,实践中,某些特殊业务场景下积累的数据较少。这导致当需要训练应用于某个特殊业务场景的模型时,无法从该特殊业务场景的数据域获取足够的已标注样本,从而无法得到效果合格的模型。
技术实现思路
为了解决某些特殊业务场景下积累的数据较少导致无法训练出效果合格的模型的问题,本说明书实施例提供一种基于梯度提升决策树的模型训练方法及装置,技术方案如下:根据本说明书实施例的第1方面,提供一种基于梯度提升决策树的模型训练方法,用于训练应用于目标业务场景的目标模型,所述方法包括:获取第一样本集合;所述第一样本集合是从源业务场景的数据域获取的已标注样本的集合;所述源业务场景是与所述目标业务场景相近的业务场景;使用所述第一样本集合,执行梯度提升决策树GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练暂停条件;根据使用所述第一样本集合训练出的决策树,确定训练残差;获取第二样本集合;所述第二样本集合是从所述目标业务场景的数据域获取的已标注样本的集合;使用所述第二样本集合,基于所述训练残差继续执行GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练停止条件;其中,所述目标模型是由已训练出的决策树集成得到的。根据本说明书实施例的第2方面,提供一种预测方法,包括:从目标业务场景的数据域获取待预测数据;根据所述待预测数据,确定所述待预测数据对应的模型输入特征;将所述模型输入特征输入到应用于所述目标业务场景的预测模型,以输出预测结果;所述预测模型是根据上述第1方面的方法得到的。根据本说明书实施例的第3方面,提供一种基于梯度提升决策树的模型训练装置,用于训练应用于目标业务场景的目标模型,所述装置包括:第一获取模块,获取第一样本集合;所述第一样本集合是从源业务场景的数据域获取的已标注样本的集合;所述源业务场景是与所述目标业务场景相近的业务场景;第一训练模块,使用所述第一样本集合,执行梯度提升决策树GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练暂停条件;计算模块,根据使用所述第一样本集合训练出的决策树,确定训练残差;第二获取模块,获取第二样本集合;所述第二样本集合是从所述目标业务场景的数据域获取的已标注样本的集合;第二训练模块,使用所述第二样本集合,基于所述训练残差继续执行GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练停止条件;其中,所述目标模型是由已训练出的决策树集成得到的。根据本说明书实施例的第4方面,提供一种预测装置,包括:获取模块,从目标业务场景的数据域获取待预测数据;确定模块,根据所述待预测数据,确定所述待预测数据对应的模型输入特征;输入模块,将所述模型输入特征输入到应用于所述目标业务场景的预测模型,以输出预测结果;所述预测模型是根据上述第1方面的方法得到的。本说明书实施例所提供的技术方案,将一个GBDT算法流程划分为两个阶段,在前一阶段,从与目标业务场景相近的业务场景的数据域获取已标注样本依次训练若干决策树,并确定经过前一阶段训练后产生的训练残差;在后一阶段,从目标业务场景的数据域获取已标注样本,并基于所述训练残差,继续训练若干决策树。最终,应用于目标业务场景的模型实际上是由前一阶段训练出的决策树与后一阶段训练出的决策树集成得到的。通过本说明书实施例,虽然目标业务场景下积累的数据不足,但是,可以借助与目标业务场景相近的业务场景的数据,训练应用于目标业务场景的模型。经过测试,可以得到效果合格的模型。应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本说明书实施例。此外,本说明书实施例中的任一实施例并不需要达到上述的全部效果。附图说明为了更清楚地说明本说明书实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本说明书实施例中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。图1是本说明书实施例提供的一种基于梯度提升决策树的模型训练方法的流程示意图;图2是本说明书实施例提供的方案架构示意图;图3是本说明书实施例提供的一种预测方法的流程示意图;图4是本说明书实施例提供的一种基于梯度提升决策树的模型训练装置的结构示意图;图5是本说明书实施例提供的一种预测装置的结构示意图;图6是用于配置本说明书实施例方法的一种设备的结构示意图。具体实施方式本专利技术借鉴了机器学习
的迁移学习思想。在面对训练应用于目标业务场景的模型的需求时,如果目标业务场景下积累的数据不足,那么可以利用与目标业务场景相近的业务场景下积累的数据进行模型训练。具体地,本专利技术将迁移学习思想与梯度提升决策树(GradientBoostingDecisionTree,GBDT)算法相结合,对GBDT算法流程进行了改进。在本说明书实施例中,针对一个GBDT算法流程,先使用与目标业务场景相近的业务场景下产生的数据进行训练,满足一定的训练暂停条件之后,暂停训练,并计算当前的训练残差;随后,使用目标业务场景下产生的数据,基于所述训练残差继续训练,直到满足一定的训练停止条件。如此,将训练得到的GBDT模型应用于目标业务场景,可以取得较好的预测效果。需要说明的是,在本文中,与目标业务场景相近的业务场景,实际上是与目标业务场景相类似或相关联的业务场景。本文将与目标业务场景相近的业务场景称为源业务场景。举例来说,假设目标业务场景是男性商品推荐场景,为了更好的根据男性用户的年龄进行商品推荐,需要训练用于预测男性用户年龄的模型。然而,由于男性商品推荐功能上线不久,积累的男性用户购买记录较少(购买记录中记载了购买者的各种特征信息以及购买者的年龄),因为无法获得足够的已标注样本进行训练。于是,可以以女性商品推荐场景为目标业务场景对应的源业务场景。由于女性商品推荐功能早已上线,已经积累了大量女性用户购买记录,因此,在本说明书实施例中,可以借助于积累的大量女性用户购买记录,使用少量的男性用户购买记录训练出效果合格的,用于预测男性用户年龄的模型。为了使本领域技术人员更好地理解本说明书实施例中的技术方案,下面将结合本说明书实施例中的附图,对本说明书实施例中的技术方案进行详细地描述,显然,所描述的实施例仅仅是本说明书的一部分实施例,而不是全部的实施例。基于本说明书中的实施例,本领域普通技术人员所获得的所有其他实施例,都应当属于保护的范围。以下结合附图,详细说明本说明书各实施例提供的技术方案。图1是本说明书实施例提供的一种基于梯度提升决策树的模型训练方法的流程示意图,包括以下步骤:S100:获取第一样本集合。本方法的目的是训练应用于目标业务场景的目标模型。在本说明书实施例中,目标业务场景对应的源业务场景的数据域中积累的数据较多,可以从源本文档来自技高网
...

【技术保护点】
1.一种基于梯度提升决策树的模型训练方法,用于训练应用于目标业务场景的目标模型,所述方法包括:获取第一样本集合;所述第一样本集合是从源业务场景的数据域获取的已标注样本的集合;所述源业务场景是与所述目标业务场景相近的业务场景;使用所述第一样本集合,执行梯度提升决策树GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练暂停条件;根据使用所述第一样本集合训练出的决策树,确定训练残差;获取第二样本集合;所述第二样本集合是从所述目标业务场景的数据域获取的已标注样本的集合;使用所述第二样本集合,基于所述训练残差继续执行GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练停止条件;其中,所述目标模型是由已训练出的决策树集成得到的。

【技术特征摘要】
1.一种基于梯度提升决策树的模型训练方法,用于训练应用于目标业务场景的目标模型,所述方法包括:获取第一样本集合;所述第一样本集合是从源业务场景的数据域获取的已标注样本的集合;所述源业务场景是与所述目标业务场景相近的业务场景;使用所述第一样本集合,执行梯度提升决策树GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练暂停条件;根据使用所述第一样本集合训练出的决策树,确定训练残差;获取第二样本集合;所述第二样本集合是从所述目标业务场景的数据域获取的已标注样本的集合;使用所述第二样本集合,基于所述训练残差继续执行GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练停止条件;其中,所述目标模型是由已训练出的决策树集成得到的。2.如权利要求1所述的方法,所述训练暂停条件,具体包括:使用所述第一样本集合训练出的决策树的数量达到第一指定数量。3.如权利要求1所述的方法,所述训练停止条件,具体包括:使用所述第二样本集合训练出的决策树的数量达到第二指定数量。4.如权利要求1所述的方法,在使用所述第二样本集合,基于所述训练残差继续执行GBDT算法流程之前,所述方法还包括:获取第三样本集合;所述第三样本集合是从其他源业务场景的数据域获取的已标注样本的集合;使用所述第三样本集合,基于所述训练残差继续执行GBDT算法流程,依次训练至少一个决策树,直至满足预设的训练再暂停条件;根据使用所述第一样本集合训练出的决策树和使用所述第三样本集合训练出的决策树,重新确定所述训练残差。5.如权利要求4所述的方法,所述训练再暂停条件,具体包括:使用所述第三样本集合训练出的决策树的数量达到第三指定数量。6.一种预测方法,包括:从目标业务场景的数据域获取待预测数据;根据所述待预测数据,确定所述待预测数据对应的模型输入特征;将所述模型输入特征输入到应用于所述目标业务场景的预测模型,以输出预测结果;所述预测模型是根据权利要求1~5任一项所述的方法得到的。7.一种基于梯度提升决策树的模型训练装置,用于训练应用于目标业务场景的目标模型,所述装置包括:第一获取模块,获取第一样本集合;所述第一样本集合是从源业务场景的数据域获取的已标注样本的集合;所述源业务场景是与所述...

【专利技术属性】
技术研发人员:陈超超周俊
申请(专利权)人:阿里巴巴集团控股有限公司
类型:发明
国别省市:开曼群岛,KY

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1