训练分类模型和数据分类的方法和装置制造方法及图纸

技术编号:29676215 阅读:13 留言:0更新日期:2021-08-13 21:58
本公开的实施例公开了训练分类模型和数据分类的方法和装置。该方法的具体实施方式包括:执行以下训练步骤:从样本集中选取至少一个样本;基于概念表征网络提取每个样本的概念表征和每个类别的概念表征;根据每个样本的概念表征与其所属类别的概念表征的距离计算每个样本所属类别的预测概率;根据每个样本所属类别的预测概率和类别标签计算总损失值;若总损失值小于预定阈值,则基于概念表征网络构造分类模型。该实施方式能够从有限的标注样本中学习新类别的鲁棒、可信的知识。

【技术实现步骤摘要】
训练分类模型和数据分类的方法和装置
本公开的实施例涉及计算机
,具体涉及训练分类模型和数据分类的方法和装置。
技术介绍
深度学习由于其优秀的数据学习能力、出色的任务执行性能,已经逐渐被应用到了人们生活、工作、学习的各个行业,比如人脸识别、商品检索等等。然而深度学习由于其模型的复杂性,往往需要海量的带有标签的针对某一任务采集的标注数据,来进行训练,才能获取性能稳定且置信度高的深度学习模型。然而,现实生活场景中,往往很难获取大量的带有标签的数据:1)部分场景中,比如商品检索场景,虽然有海量的商品数据,但是大部分商品数据并不具备直接的标注,而人工标注数据价格高、费时费力;2)部分场景中,比如医疗场景,部分疾病的数据很难采集大量的样本,比如罕见病可能只能收集一个病人的数据,导致数据多样性不足,无法利用这些数据训练获取泛化性能好的深度模型。
技术实现思路
本公开的实施例提出了训练分类模型和数据分类的方法和装置。第一方面,本公开的实施例提供了一种训练分类模型的方法,包括:执行以下训练步骤:从样本集中选取至少一个样本,其中,所述样本集中的样本具有类别标签;基于概念表征网络提取每个样本的概念表征和每个类别的概念表征;根据每个样本的概念表征与其所属类别的概念表征的距离计算每个样本所属类别的预测概率;根据每个样本所属类别的预测概率和类别标签计算总损失值;若总损失值小于预定阈值,则基于概念表征网络构造分类模型。在一些实施例中,该方法还包括:若总损失值不小于预定阈值,则调整概念表征网络的相关参数,继续执行训练步骤。在一些实施例中,基于概念表征网络提取每个样本的概念表征和每个类别的概念表征,包括:基于概念表征网络提取每个样本的概念表征;将类别标签相同的样本的概念表征聚类,得到每个类别的概念表征。在一些实施例中,概念表征网络包括特征提取网络、区域自注意力机制网络和概念聚合池化网络;以及基于概念表征网络提取每个样本的概念表征和每个类别的概念表征,包括:将至少一个样本分别输入特征提取网络,得到每个样本的区域特征;将每个样本的区域特征分别输入区域自注意力机制网络,得到每个样本的增强区域特征;将每个样本的增强区域特征分别输入概念聚合池化网络,得到每个样本的概念表征;将类别标签相同的样本的概念表征聚类,得到每个类别的概念表征。在一些实施例中,该方法还包括:根据样本集应用的领域的计算量选择网络层数与计算量正相关的特征提取网络。在一些实施例中,将每个样本的区域特征分别输入区域自注意力机制网络,得到每个样本的增强区域特征,包括:将每个样本的区域特征的位置信息分别进行编码,得到每个样本的位置编码;将每个样本的区域特征分别计算全局平均特征,得到每个样本的全局上下文信息;将每个样本的区域特征、位置编码和全局上下文信息构成每个样本的区域信息;将每个样本的区域信息分别输入区域自注意力机制网络,得到每个样本的增强区域特征。在一些实施例中,将每个样本的增强区域特征分别输入概念聚合池化网络,得到每个样本的概念表征,包括:将每个样本的增强区域特征分别输入注意力池化网络,得到每个样本的第一概念表征;将每个样本的增强区域特征分别进行平均池化,得到每个样本的第二概念表征;将每个样本的第一概念表征和第二概念表征的加权和确定为每个样本的概念表征。在一些实施例中,类别标签为平滑后的标签。第二方面,本公开的实施例提供了一种数据分类方法,包括:获取待分类的目标数据和至少一种类别的样本数据集;将目标数据和样本数据集输入采用如第一方面中任一项的方法生成的分类模型,输出目标数据所属类别的预测概率。第三方面,本公开的实施例提供了一种训练分类模型的装置,包括:选取单元,被配置成从样本集中选取至少一个样本,其中,所述样本集中的样本具有类别标签;提取单元,被配置成基于概念表征网络提取每个样本的概念表征和每个类别的概念表征;预测单元,被配置成根据每个样本的概念表征与其所属类别的距离计算每个样本所属类别的预测概率;计算单元,被配置成根据每个样本所属类别的预测概率和类别标签计算总损失值;循环单元,被配置成若所述总损失值小于预定阈值,则基于所述概念表征网络构造分类模型。第四方面,本公开的实施例提供了一种数据分类装置,包括:获取单元,被配置成获取待分类的目标数据和至少一种类别的样本数据集;分类单元,被配置成将目标数据和样本数据集输入采用如第一方面中任一项的方法生成的分类模型,输出目标数据所属类别的预测概率。第五方面,本公开的实施例提供了一种用于输出信息的电子设备,包括:一个或多个处理器;存储装置,其上存储有一个或多个计算机程序,当一个或多个计算机程序被一个或多个处理器执行,使得一个或多个处理器实现如第一方面和第二方面中任一项的方法。第六方面,本公开的实施例提供了一种计算机可读介质,其上存储有计算机程序,其中,计算机程序被处理器执行时实现如第一方面和第二方面中任一项的方法。本公开的实施例提供的训练分类模型和数据分类的方法和装置,通过对样本的概念表征以及类别的概念表征的学习,增加了与概念有关信息的权重,并且消除背景、噪声、与样本无关的信息对概念表征的影响,可解决小样本条件下,数据内容包含概念无关信息时,如何进一步鲁棒的获取单个数据的概念表征的问题。并通过汇总单个数据的概念表征得到整个类别的数据的概念表征。以图像为例,即将图像的局部区域视为图像的信息基本单位,利用自适应的方法,识别并汇总具有概念相关信息的区域信息。附图说明通过阅读参照以下附图所作的对非限制性实施例所作的详细描述,本公开的其它特征、目的和优点将会变得更明显:图1是本公开可以应用于其中的示例性系统架构图;图2是根据本公开训练分类模型的方法的一个实施例的流程图;图3是根据本公开训练分类模型的方法的一个应用场景的示意图;图4是根据本公开训练分类模型的装置的一个实施例的结构示意图;图5是根据本公开数据分类方法的一个实施例的流程图;图6是根据本公开数据分类装置的一个实施例的结构示意图;图7是适于用来实现本公开实施例的电子设备的计算机系统的结构示意图。具体实施方式下面结合附图和实施例对本公开作进一步的详细说明。可以理解的是,此处所描述的具体实施例仅仅用于解释相关专利技术,而非对该专利技术的限定。另外还需要说明的是,为了便于描述,附图中仅示出了与有关专利技术相关的部分。需要说明的是,在不冲突的情况下,本公开中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本公开。图1示出了可以应用本公开实施例的训练分类模型的方法、训练分类模型的装置、数据分类的方法或数据分类的装置的示例性系统架构100。如图1所示,系统架构100可以包括终端101、102,网络103、数据库服务器104和服务器105。网络103用以在终端101、102,数据库服务器104与服务器105之间提供通信链路的介质。网络103可以包括各种连接类型,例如有线、无线本文档来自技高网...

【技术保护点】
1.一种训练分类模型的方法,包括:执行以下训练步骤:/n从样本集中选取至少一个样本,其中,所述样本集中的样本具有类别标签;/n基于概念表征网络提取每个样本的概念表征和每个类别的概念表征;/n根据每个样本的概念表征与其所属类别的概念表征的距离计算每个样本所属类别的预测概率;/n根据每个样本所属类别的预测概率和类别标签计算总损失值;/n若所述总损失值小于预定阈值,则基于所述概念表征网络构造分类模型。/n

【技术特征摘要】
1.一种训练分类模型的方法,包括:执行以下训练步骤:
从样本集中选取至少一个样本,其中,所述样本集中的样本具有类别标签;
基于概念表征网络提取每个样本的概念表征和每个类别的概念表征;
根据每个样本的概念表征与其所属类别的概念表征的距离计算每个样本所属类别的预测概率;
根据每个样本所属类别的预测概率和类别标签计算总损失值;
若所述总损失值小于预定阈值,则基于所述概念表征网络构造分类模型。


2.根据权利要求1所述的方法,其中,所述方法还包括:
若所述总损失值不小于预定阈值,则调整所述概念表征网络的相关参数,继续执行所述训练步骤。


3.根据权利要求1所述的方法,其中,所述基于概念表征网络提取每个样本的概念表征和每个类别的概念表征,包括:
基于概念表征网络提取每个样本的概念表征;
将类别标签相同的样本的概念表征聚类,得到每个类别的概念表征。


4.根据权利要求3所述的方法,其中,所述概念表征网络包括特征提取网络、区域自注意力机制网络和概念聚合池化网络;以及
所述基于概念表征网络提取每个样本的概念表征,包括:
将所述至少一个样本分别输入特征提取网络,得到每个样本的区域特征;
将每个样本的区域特征分别输入区域自注意力机制网络,得到每个样本的增强区域特征;
将每个样本的增强区域特征分别输入概念聚合池化网络,得到每个样本的概念表征。


5.根据权利要求4所述的方法,其中,所述方法还包括:
根据所述样本集应用的领域的计算量选择网络层数与计算量正相关的特征提取网络。


6.根据权利要求4所述的方法,其中,所述将每个样本的区域特征分别输入区域自注意力机制网络,得到每个样本的增强区域特征,包括:
将每个样本的区域特征的位置信息分别进行编码,得到每个样本的位置编码;
将每个样本的区域特征分别计算全局平均特征,得到每个样本的全局上下文信息;
将每个样本的区域特征、位置编码和全局上下文信息构成每个样本的区域信息;
将每个样本的区域信息分别输入区域自注意力机制网络,得到每个样本的增强区域特征。

【专利技术属性】
技术研发人员:詹忆冰韩梦雅
申请(专利权)人:京东数科海益信息科技有限公司
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1