分类任务模型的训练方法、装置、设备及存储介质制造方法及图纸

技术编号:21800111 阅读:28 留言:0更新日期:2019-08-07 10:45
本申请公开了一种分类任务模型的训练方法、装置、设备及存储介质,涉及机器学习技术领域,所述方法包括:采用第一数据集训练初始的特征提取器,该第一数据集是类别不均衡数据集;构建生成对抗网络,该生成对抗网络包括特征提取器、特征生成器和域分类器;采用第二类别样本对生成对抗网络进行训练,得到完成训练的特征生成器;构建分类任务模型,该分类任务模型包括完成训练的特征生成器、特征提取器和分类器;采用第一数据集对分类任务模型进行训练;其中,完成训练的特征生成器用于对第二类别样本在特征空间进行扩增。本申请通过特征生成器对少数类别样本在特征空间进行扩增,提高最终训练得到的分类任务模型的精度。

Training Method, Device, Equipment and Storage Medium of Classified Task Model

【技术实现步骤摘要】
分类任务模型的训练方法、装置、设备及存储介质
本申请实施例涉及机器学习
,特别涉及一种分类任务模型的训练方法、装置、设备及存储介质。
技术介绍
机器学习对于处理分类任务具有较好的性能表现,例如基于深度神经网络构建分类任务模型,并通过适当的训练样本对该模型进行训练,完成训练的分类任务模型即可用于处理分类任务,如图像识别、语音识别等分类任务。在训练分类任务模型时,训练数据集中包含的训练样本的类别可能并不均衡,例如正样本的数量远少于负样本的数量,这样的训练数据集可以称为类别不均衡数据集。如果采用类别不均衡数据集对分类任务模型进行训练,会导致最终得到的分类任务模型的性能表现不佳。在相关技术中,提出了通过样本上采样来使得类别不均衡数据集中不同类别的训练样本数量保持均衡。所谓样本上采样,就是以数量多的一方的样本数量为基准,把数量少的一方的样本数量进行扩增,生成和数量多的一方相同数量的样本。例如,当正样本的数量小于负样本的数量时,可以复制一些正样本,使得正样本的数量和负样本的数量相同。经样本上采样得到的训练样本训练出的分类任务模型,存在过拟合的情况,即该分类任务模型的训练误差远小于其在测试数本文档来自技高网...

【技术保护点】
1.一种分类任务模型的训练方法,其特征在于,所述方法包括:采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分;采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器;构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类...

【技术特征摘要】
1.一种分类任务模型的训练方法,其特征在于,所述方法包括:采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分;采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器;构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类器;采用所述第一数据集对所述分类任务模型进行训练;其中,所述完成训练的特征生成器用于对所述第二类别样本在特征空间进行扩增。2.根据权利要求1所述的方法,其特征在于,所述采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器,包括:在所述生成对抗网络的每一轮训练过程中,为所述特征提取器的输入赋予第一标签,为所述特征生成器的输入赋予第二标签;计算所述域分类器的第一损失函数值;根据所述第一损失函数值对所述域分类器的参数进行更新;屏蔽所述特征提取器的输入,为所述特征生成器的输入赋予所述第一标签;计算所述域分类器的第二损失函数值;根据所述第二损失函数值对所述特征生成器的参数进行更新。3.根据权利要求1所述的方法,其特征在于,所述特征生成器的输入包括先验数据与噪声数据的叠加;其中,所述先验数据从所述第一数据集的所述第二类别样本中提取,或者,所述先验数据从第二数据集中与所述第二类别样本同类别的样本中提取。4.根据权利要求1至3任一项所述的方法,其特征在于,所述分类任务模型还包括:数据清洗单元;所述数据清洗单元,用于对所述特征生成器和所述特征提取器输出的异常特征向量进行过滤。5.根据权利要求4所述的方法,其特征在于,所述数据清洗单元,用于:从所述特征生成器和所述特征提取器输出的特征向量中,筛选出符合预设条件的特征向量对,所述符合预设条件的特征向量对是指标签不同且相似度最大的两个特征向量;将所述符合预设条件的特征向量对作为所述异常特征向量进行过滤。6.根据权利要求1至3任一项所述的方法,其特征在于,所述采用第一数据集训练初始的特征提取器,包括:构建初始的分类任务模型,所述初始的分类任务模型包括所述初始的特征提取器和初始的分类器;采用所述第一数据集对所述初始的分类任务模型进行训练,得到初始训练后的特征提取器,所述初始训练后的特征提取器被用于所述生成对抗网络中。7.一种分类任务模型的训练装置,其特征在于,所述装置包括:第一训练模块,用于采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;第一构...

【专利技术属性】
技术研发人员:沈荣波周可田宽颜克洲江铖
申请(专利权)人:深圳市腾讯计算机系统有限公司华中科技大学
类型:发明
国别省市:广东,44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1