一种分类模型训练方法及装置、存储介质制造方法及图纸

技术编号:24499292 阅读:32 留言:0更新日期:2020-06-13 04:21
本公开是关于一种分类模型训练方法及装置、存储介质。分类模型训练方法,包括:利用已知离散标签的第一数据集进行分类模型的初始训练,得到第一分类模型;将第一数据集内的样本数据输入到第一分类模型,得到连续标签;利用样本数据和连续标签构成的第二数据集,继续训练第一分类模型得到第二分类模型;利用第二分类模型对第一数据集包含的样本数据进行分类,得到第二分类模型输出的连续标签;将第二分类模型输出的连续标签满足存疑条件的样本数据输出,以获得对样本数据重新标注的离散标签,以更新第二数据集;利用更新后的第二数据集继续训练第二分类模型,直到满足训练停止条件。

A training method, device and storage medium of classification model

【技术实现步骤摘要】
一种分类模型训练方法及装置、存储介质
本公开涉及信息
,尤其涉及一种分类模型训练方法及装置、存储介质。
技术介绍
神经网络等可以构成分类模型。分类模型可以用于对图像和文本进行分类。但是这种分类模型在使用之前,需要使用标注好的样本数据进行标注。且由于分类模型训练涉及大量样本数据的标注,具有标注工作量大,且若标注工作量大伴随有标注错误时,可能会导致训练得到的分类模型的精确度差的现象。
技术实现思路
本公开提供一种分类模型训练方法及装置、存储介质。本公开实施例第一方面提供一种分类模型训练方法,包括:利用已知离散标签的第一数据集进行分类模型的初始训练,得到第一分类模型;其中,所述离散标签,用于指示所述第一数据集内样本数据的类别;将所述第一数据集内的样本数据输入到所述第一分类模型,得到连续标签,其中,所述连续标签,用于指示所述样本数据为对应类别的概率;利用所述样本数据和所述连续标签构成的第二数据集,继续训练所述第一分类模型得到第二分类模型;利用所述第二分类模型对所述第一数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签;将所述第二分类模型输出的连续标签满足存疑条件的所述样本数据输出,以获得对所述样本数据重新标注的离散标签,以更新所述第二数据集;利用更新后的所述第二数据集继续训练所述第二分类模型,直到满足训练停止条件。基于上述方案,所述方法还包括:获取未知标签的第三数据集;所述利用所述第二分类模型对所述第一数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签,包括:利用所述第二分类模型对所述第一数据集与所述第三数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签。基于上述方案,所述将所述第二分类模型输出的连续标签满足存疑条件的所述样本数据输出,包括:将所述第二分类模型输出的连续标签位于存疑标签值区域的所述样本数据输出。基于上述方案,所述将所述第一数据集内的样本数据输入到经过初始训练的所述第一分类模型,得到连续标签,包括:将所述第一数据集内的样本数据输入到经过初始训练的所述第一分类模型,得到所述第一分类模型输出的预测标签;根据所述离散标签和所述预测标签,得到所述连续标签。基于上述方案,所述根据所述离散标签和所述预测标签,得到所述连续标签,包括:当所述离散标签为:y1=1且y0=0时,按照如下得到所述连续标签;当所述离散标签为y0=1且y1=0时,按照如下公式得到所述连续标签;其中,所述y0为所述离散标签中标注对应样本数据为第一类别的标签值;所述y1为所述离散标签中标注对应样本数据为第二类别的标签值;所述Y0为所述连续标签中标注对应样本数据为第一类别的标签值;所述Y1为所述连续标签中标注对应样本数据为第二类别的标签值;所述p0为所述预测标签中标注对应样本数据为所述第一类别的预测值,所述p1为所述预测标签中标注对应样本数据为所述第二类别的预测值;所述k0为将离散标签中指示对应样本数据为第一类别的连续化为所述连续标签的学习率;所述k1为将离散标签中指示对应样本数据为第二类别的连续化为所述连续标签的学习率;所述λ为预设值。基于上述方案,所述训练停止条件,包括以下至少之一:所述第二分类模型输出的满足存疑条件的连续标签数目小于第一阈值;所述第二分类模型输出的满足存疑条件的连续标签转换为所述离散标签之后的误标率,小于第二阈值。本公开实施例第二方面提供一种分类模型训练装置,包括:第一训练模块,用于利用已知离散标签的第一数据集进行分类模型的初始训练,得到第一分类模型;其中,所述离散标签,用于指示所述第一数据集内样本数据的类别;第一得到模块,用于将所述第一数据集内的样本数据输入到所述第一分类模型,得到连续标签,其中,所述连续标签,用于指示所述样本数据为对应类别的概率;第二训练模块,用于利用所述样本数据和所述连续标签构成的第二数据集,继续训练所述第一分类模型得到第二分类模型;第二得到模块,用于利用所述第二分类模型对所述第一数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签;更新模块,用于将所述第二分类模型输出的连续标签满足存疑条件的所述样本数据输出,以获得对所述样本数据重新标注的离散标签,以更新所述第二数据集;第三训练模块,用于利用更新后的所述第二数据集继续训练所述第二分类模型,直到满足训练停止条件。基于上述方案,所述装置还包括:获取模块,用于获取未知标签的第三数据集;所述第二得到模块,用于利用所述第二分类模型对所述第一数据集与所述第三数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签。基于上述方案,所述更新模块,具体用于将所述第二分类模型输出的连续标签位于存疑标签值区域的所述样本数据输出。基于上述方案,所述第一得到模块,用于将所述第一数据集内的样本数据输入到经过初始训练的所述第一分类模型,得到所述第一分类模型输出的预测标签;根据所述离散标签和所述预测标签,得到所述连续标签。基于上述方案,所述第一得到模块,具体用于当所述离散标签为:y1=1且y0=0时,按照如下得到所述连续标签;当所述离散标签为y0=1且y1=0时,按照如下公式得到所述连续标签;其中,所述y0为所述离散标签中标注对应样本数据为第一类别的标签值;所述y1为所述离散标签中标注对应样本数据为第二类别的标签值;所述Y0为所述连续标签中标注对应样本数据为第一类别的标签值;所述Y1为所述连续标签中标注对应样本数据为第二类别的标签值;所述p0为所述预测标签中标注对应样本数据为所述第一类别的预测值,所述p1为所述预测标签中标注对应样本数据为所述第二类别的预测值;所述k0为将离散标签中指示对应样本数据为第一类别的连续化为所述连续标签的学习率;所述k1为将离散标签中指示对应样本数据为第二类别的连续化为所述连续标签的学习率;所述λ为预设值。基于上述方案,所述训练停止条件,包括以下至少之一:所述第二分类模型输出的满足存疑条件的连续标签数目小于第一阈值;所述第二分类模型输出的满足存疑条件的连续标签转换为所述离散标签之后的误标率,小于第二阈值。本公开实施例第三方面提供一种分类模型训练装置,包括处理器、存储器及存储在存储器上并能够有所述处理器运行的可执行程序,其特征在于,所述处理器运行所述可执行程序时执行如前述任意技术方案提供的分类模型训练方法的步骤。本公开实施例第四方面提供一种存储介质,其上存储由可执行程序,其特征在于,所述可执行程序被处理器执行时实现如前述任意技术方案提供的分类模型训练方法的步骤。本公开的实施例提供的技术方案可以包括以下有益效果:分类模型训练的过程中,首先要少量的标注有离散标签的第一训练集对能够输出连续标签的分类本文档来自技高网...

【技术保护点】
1.一种分类模型训练方法,其特征在于,包括:/n利用已知离散标签的第一数据集进行分类模型的初始训练,得到第一分类模型;其中,所述离散标签,用于指示所述第一数据集内样本数据的类别;/n将所述第一数据集内的样本数据输入到所述第一分类模型,得到连续标签,其中,所述连续标签,用于指示所述样本数据为对应类别的概率;/n利用所述样本数据和所述连续标签构成的第二数据集,继续训练所述第一分类模型得到第二分类模型;/n利用所述第二分类模型对所述第一数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签;/n将所述第二分类模型输出的连续标签满足存疑条件的所述样本数据输出,以获得对所述样本数据重新标注的离散标签,以更新所述第二数据集;/n利用更新后的所述第二数据集继续训练所述第二分类模型,直到满足训练停止条件。/n

【技术特征摘要】
1.一种分类模型训练方法,其特征在于,包括:
利用已知离散标签的第一数据集进行分类模型的初始训练,得到第一分类模型;其中,所述离散标签,用于指示所述第一数据集内样本数据的类别;
将所述第一数据集内的样本数据输入到所述第一分类模型,得到连续标签,其中,所述连续标签,用于指示所述样本数据为对应类别的概率;
利用所述样本数据和所述连续标签构成的第二数据集,继续训练所述第一分类模型得到第二分类模型;
利用所述第二分类模型对所述第一数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签;
将所述第二分类模型输出的连续标签满足存疑条件的所述样本数据输出,以获得对所述样本数据重新标注的离散标签,以更新所述第二数据集;
利用更新后的所述第二数据集继续训练所述第二分类模型,直到满足训练停止条件。


2.根据权利要求1所述的方法,其特征在于,所述方法还包括:
获取未知标签的第三数据集;
所述利用所述第二分类模型对所述第一数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签,包括:
利用所述第二分类模型对所述第一数据集与所述第三数据集包含的样本数据进行分类,得到所述第二分类模型输出的连续标签。


3.根据权利要求1或2所述的方法,其特征在于,所述将所述第二分类模型输出的连续标签满足存疑条件的所述样本数据输出,包括:
将所述第二分类模型输出的连续标签位于存疑标签值区域的所述样本数据输出。


4.根据权利要求1或2所述的方法,其特征在于,所述将所述第一数据集内的样本数据输入到经过初始训练的所述第一分类模型,得到连续标签,包括:
将所述第一数据集内的样本数据输入到经过初始训练的所述第一分类模型,得到所述第一分类模型输出的预测标签;
根据所述离散标签和所述预测标签,得到所述连续标签。


5.根据权利要求4所述的方法,其特征在于,所述根据所述离散标签和所述预测标签,得到所述连续标签,包括:
当所述离散标签为:y1=1且y0=0时,按照如下得到所述连续标签;



Y0=1-Y1
当所述离散标签为y0=1且y1=0时,按照如下公式得到所述连续标签;



Y1=1-Y0
其中,所述y0为所述离散标签中标注对应样本数据为第一类别的标签值;所述y1为所述离散标签中标注对应样本数据为第二类别的标签值;
所述Y0为所述连续标签中标注对应样本数据为第一类别的标签值;所述Y1为所述连续标签中标注对应样本数据为第二类别的标签值;
所述p0为所述预测标签中标注对应样本数据为所述第一类别的预测值,所述p1为所述预测标签中标注对应样本数据为所述第二类别的预测值;所述k0为将离散标签中指示对应样本数据为第一类别的连续化为所述连续标签的学习率;所述k1为将离散标签中指示对应样本数据为第二类别的连续化为所述连续标签的学习率;
所述λ为预设值。


6.根据权利要求1所述的方法,其特征在于,所述训练停止条件,包括以下至少之一:
所述第二分类模型输出的满足存疑条件的连续标签数目小于第一阈值;
所述第二分类模型输出的满足存疑条件的连续标签转换为所述离散标签之后的误标率,小于第二阈值。


7.一种分类模型训练装置,其特征在于,包括:
第一训练模块,用于利用已知离散标签的第一数据集进...

【专利技术属性】
技术研发人员:徐泽宇邓雄文
申请(专利权)人:北京松果电子有限公司
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1