一种机器学习的均衡方法及相关装置制造方法及图纸

技术编号:33378124 阅读:16 留言:0更新日期:2022-05-11 22:46
本申请公开了一种机器学习的均衡方法及相关装置,所述方法包括从所属正确类别为主要类别的训练样本中,按照预设概率筛选部分训练样本作为目标训练样本;对原始类别序列中的类别标签进行错排以形成错排类别序列;基于目标训练样本和错排类别序列生成错排输入数据;以错排输入数据对机器学习模型进行训练。本申请通过错排方式将部分正确类别为主要类别的训练样本的原始类别序列修改为错排类别序列,这样即可以保留所有训练样本,使得机器学习可以学习到所有训练样本所携带的特征信息,同时又阻止机器学习模型对主要类别的过度学习,使得机器学习在主要类别上的权重更新变小且次要类别不受影响,使得机器学习的训练过程更加平衡,提高了模型性能。提高了模型性能。提高了模型性能。

【技术实现步骤摘要】
一种机器学习的均衡方法及相关装置


[0001]本申请涉及人工智能
,特别涉及一种机器学习的均衡方法及相关装置。

技术介绍

[0002]随着科学技术和人工智能的迅速发展,机器学习可以用来实现各自功能,例如,事件检测、文本分类、图像分类以及目标检测等。然而,在实际应用中,经常会出现呈长尾分布(long

tail)的数据集,该数据集中部分样本的类别标签所占据的数据量大,部分样本的类别标签所占据的数据量小,导致机器学习的学习不平衡,机器学习为了解决这种不平衡会大量训练数据量小的类别标签,使得机器学习模型可以针对对数据量小的类别标签对应的样本进行更精准识别,从而导致基于数据集训练的机器学习模型在数据量大的类别标签上出现过度拟合的现象,进而影响模型的模型性能。
[0003]因而现有技术还有待改进和提高。

技术实现思路

[0004]本申请要解决的技术问题在于,针对现有技术的不足,提供一种机器学习的均衡方法及相关装置。
[0005]为了解决上述技术问题,本申请实施例第一方面提供了一种机器学习的均衡方法,所述的方法包括:
[0006]一种机器学习的均衡方法,所述的方法包括:
[0007]从所属正确类别为主要类别的训练样本中,按照预设概率筛选部分训练样本作为目标训练样本;
[0008]对原始类别序列中的类别标签进行错排,以形成错排类别序列,其中,所述原始类别序列为基于所述机器学习的所有训练样本所属的类别标签生成的;
[0009]基于所述目标训练样本和所述错排类别序列生成错排输入数据;
[0010]以所述错排输入数据对机器学习模型进行训练。
[0011]在一个实现方式中,所述原始类别序列的生成过程具体包括:
[0012]获取所述机器学习的所有训练样本所属的类别标签,并随机排列获取到的所有类别标签以生成原始类别序列。
[0013]在一个实现方式中,所述主要类别为数据量大于预设数量的类别标签,其中,所述数据量为属于所述类别标签的训练样本的数量。
[0014]在一个实现方式中,对所述原始类别序列中的类别标签进行错排,以形成错排类别序列的步骤,具体包括:
[0015]对所述原始类别序列中的各类别标签按照各自对应的数据量进行降序排列后,以得到初始类别序列;
[0016]在所述初始类别序列中,从前向后选取预设个数的类别标签作为目标错排类别,其中,预设个数的类别标签不包括所述目标训练样本对应的正确类别标签;
[0017]在所述原始类别序列中对选取到的目标错排类别进行全错排,以形成错排类别序列。
[0018]在一个实现方式中,所述基于所述目标训练样本和所述错排类别序列生成错排输入数据的步骤,具体包括:
[0019]将所述目标训练样本与所述错排类别序列相连接,以生成所述错排输入数据。
[0020]在一个实现方式中,所述以所述错排输入数据对机器学习模型进行训练的步骤,具体包括:
[0021]将所述错排输入数据输入所述样本训练模块,通过所述样本训练模块确定所述错排输入数据对应的语境化表达向量;所述语境化表达向量包括所述目标训练样本中的语义词组对应的语境化表达向量和所述错排类别序列中类别标签对应的语境化表达向量;
[0022]提取所述错排类别序列中类别标签对应的语境化表达向量,并输入所述预测模块,通过所述预测模块确定所述错排输入数据对应的各个类别标签的概率;
[0023]基于各类别标签的概率确定所述错排输入数据对应的预测类别,并基于所述预测类别对所述机器学习模型的训练进行调整。
[0024]在一个实现方式中,所述的方法还包括:
[0025]对于所属正确类别为次要类别的训练样本,以及所属正确类别为主要类别且未被筛选为目标训练样本的训练样本,基于所述训练样本和所述原始类别序列生成输入数据,并以所述输入数据对所述机器学习模型进行训练。
[0026]在一个实现方式中,所述基于所述训练样本和所述原始类别序列生成输入数据,并以所述输入数据对所述机器学习模型进行训练,具体包括:
[0027]将所述训练样本与所述原始类别序列相连接,以生成所述输入数据。
[0028]在一个实现方式中,所述以所述输入数据对所述机器学习模型进行训练的步骤,具体包括:
[0029]将所述输入数据输入所述样本训练模块,通过所述样本训练模块确定所述输入数据对应的语境化表达向量;所述语境化表达向量包括所述训练样本中的语义词组对应的语境化表达向量和所述原始类别序列中类别标签对应的语境化表达向量;
[0030]提取所述原始类别序列中类别标签对应的语境化表达向量,并输入所述预测模块,通过所述预测模块确定所述输入数据对应的各个类别标签的概率;
[0031]基于各类别标签的概率确定所述输入数据对应的预测类别,并基于所述预测类别对所述机器学习模型的训练进行调整。
[0032]本申请实施例第二方面提供了一种机器学习的均衡装置,所述的装置包括:
[0033]筛选模块,从所属正确类别为主要类别的训练样本中,按照预设概率筛选部分训练样本作为目标训练样本;
[0034]形成模块,用于对原始类别序列中的类别标签进行错排,以形成错排类别序列,其中,所述原始类别序列为基于所述机器学习的所有训练样本所属的类别标签生成的;
[0035]生成模块,用于基于所述目标训练样本和所述错排类别序列生成错排输入数据;
[0036]训练模块,用于以所述错排输入数据对机器学习模型进行训练。
[0037]本申请实施例第三方面提供了一种计算机可读存储介质,所述计算机可读存储介质存储有一个或者多个程序,所述一个或者多个程序可被一个或者多个处理器执行,以实
现如上所述的机器学习的均衡方法中的步骤。
[0038]本申请实施例第四方面提供了一种终端设备,其包括:处理器、存储器及通信总线;所述存储器上存储有可被所述处理器执行的计算机可读程序;
[0039]所述通信总线实现处理器和存储器之间的连接通信;
[0040]所述处理器执行所述计算机可读程序时实现如上所述的机器学习的均衡方法中的步骤。
[0041]有益效果:与现有技术相比,本申请从所属正确类别为主要类别的训练样本中,按照预设概率筛选部分训练样本作为目标训练样本;对原始类别序列中的类别标签进行错排,以形成错排类别序列;基于所述目标训练样本和所述错排类别序列生成错排输入数据;以所述错排输入数据对机器学习模型进行训练。本申请通过错排方式将部分正确类别为主要类别的训练样本的原始类别序列修改为错排类别序列,这样即可以保留所有训练样本,使得机器学习可以学习到所有训练样本所携带的特征信息,同时又阻止机器学习过程中模型对类别序列中主要类别的过度学习,使得机器学习在主要类别上的权重更新变小且次要类别不受影响,使得机器学习的训练过程更加平衡,提高了模型性能。
附图说明
[0042]为了更清楚地说明本申本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种机器学习的均衡方法,其特征在于,所述的方法包括:从所属正确类别为主要类别的训练样本中,按照预设概率筛选部分训练样本作为目标训练样本;对原始类别序列中的类别标签进行错排,以形成错排类别序列,其中,所述原始类别序列为基于所述机器学习的所有训练样本所属的类别标签生成的;基于所述目标训练样本和所述错排类别序列生成错排输入数据;以所述错排输入数据对机器学习模型进行训练。2.根据权利要求1所述的机器学习的均衡方法,其特征在于,所述原始类别序列的生成过程具体包括:获取所述机器学习的所有训练样本所属的类别标签,并随机排列获取到的所有类别标签以生成原始类别序列。3.根据权利要求2所述的机器学习的均衡方法,其特征在于,所述主要类别为数据量大于预设数量的类别标签,其中,所述数据量为属于所述类别标签的训练样本的数量。4.根据权利要求1

3任意一项所述的机器学习的均衡方法,其特征在于,对所述原始类别序列中的类别标签进行错排,以形成错排类别序列的步骤,具体包括:对所述原始类别序列中的各类别标签按照各自对应的数据量进行降序排列后,以得到初始类别序列;在所述初始类别序列中,从前向后选取预设个数的类别标签作为目标错排类别,其中,预设个数的类别标签不包括所述目标训练样本对应的正确类别标签;在所述原始类别序列中对选取到的目标错排类别进行全错排,以形成错排类别序列。5.根据权利要求4所述的机器学习的均衡方法,其特征在于,所述基于所述目标训练样本和所述错排类别序列生成错排输入数据的步骤,具体包括:将所述目标训练样本与所述错排类别序列相连接,以生成所述错排输入数据。6.根据权利要求4所述的机器学习的均衡方法,其特征在于,所述机器学习的模型包括样本训练模块以及预测模块,所述以所述错排输入数据对机器学习模型进行训练的步骤,具体包括:将所述错排输入数据输入所述样本训练模块,通过所述样本训练模块确定所述错排输入数据对应的语境化表达向量;所述语境化表达向量包括所述目标训练样本中的语义词组对应的语境化表达向量和所述错排类别序列中类别标签对应的语境化表达向量;提取所述错排类别序列中类别标签对应的语境化表达向量,并输入所述预测模块,通过所述预测模块确定所述错排输入数据对应的各个类别标签的概率;基于各类别标签的概率确定所述错排输入数据对应的预测类别,并基于所述预测类别对所述机器学习模型的训练进行调整。7.根据权利要求1

...

【专利技术属性】
技术研发人员:杨海钦赵嘉晨
申请(专利权)人:粤港澳大湾区数字经济研究院福田
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1