分类模型训练方法和设备技术

技术编号:26378484 阅读:13 留言:0更新日期:2020-11-19 23:47
本申请实施例公开了分类模型训练方法和设备。分类模型训练方法的一具体实施方式包括:将样本信息输入至初始模型,得到样本信息的预测类别标签;将样本信息的预测类别标签和真实类别标签输入至损失函数,得到样本信息的损失;对样本信息的损失进行动态加权,得到样本信息的加权损失;基于加权损失调整初始模型的参数,得到分类模型。该实施方式在模型训练阶段对样本信息的损失进行动态加权,以调整模型对不同样本信息的学习程度,使得模型优化的方向更符合实际需求,进而提高模型的分类准确度。

【技术实现步骤摘要】
分类模型训练方法和设备
本申请实施例涉及计算机
,具体涉及分类模型训练方法和设备。
技术介绍
在社交媒体中进行高效准确的色情图像识别对于为用户营造安全的网络空间具有十分重要的意义。目前,色情图像识别技术除了基于肤色检测等传统方法外,更主流的方案是采用基于深度学习的图像分类方法。然而,传统的深度学习模型只关注图像是否正确分类以及分类的置信度。
技术实现思路
本申请实施例提出了分类模型训练方法和设备。第一方面,本申请实施例提供了一种分类模型训练方法,包括:将样本信息输入至初始模型,得到样本信息的预测类别标签;将样本信息的预测类别标签和真实类别标签输入至损失函数,得到样本信息的损失;对样本信息的损失进行动态加权,得到样本信息的加权损失;基于加权损失调整初始模型的参数,得到分类模型。在一些实施例中,对样本信息的损失进行动态加权,得到样本信息的加权损失,包括:基于样本信息的损失,确定样本信息的损失权重;基于样本信息的损失权重对样本信息的损失加权,得到加权损失。在一些实施例中,样本信息的损失权重与样本信息的损失正相关。在一些实施例中,样本信息是样本文本、样本图像、样本语音和样本视频之一。在一些实施例中,损失函数是交叉熵损失函数。在一些实施例中,交叉熵损失函数定义为:其中,1≤i≤n,1≤j≤n,且i,j,n均为正整数,y是样本信息的真实类别标签的独热编码,样本信息属于n类,yi是样本信息属于第i个类别的真实概率的独热编码,z为初始模型的输出,p为z的归一化,代表样本信息的预测类别标签,pi为样本信息属于第i个类别的预测概率。在一些实施例中,样本信息只属于一个类别,交叉墒损失函数定义为:CEL=-lnpt;其中,样本信息只属于第t个类别,1≤t≤n,且t为正整数,yt=1。在一些实施例中,加权损失定义为:REL=-αlnpt,其中,α为损失权重,是样本信息的预测类别标签,若样本信息正确分类为第t个类别,则且正确分类标签为yt时的损失权重为λ1,若样本信息错误分类为第j个类别,则则错误分类标签为yj时的损失权重为λ2,若样本信息错误分类为第k个类别,则则错误分类标签为yk时的损失权重为λ3。在一些实施例中,该方法还包括:获取待识别信息;将待识别信息输入至分类模型,得到待识别信息的类别。在一些实施例中,样本信息是样本图像;以及该方法还包括:获取待识别图像;利用分类模型检测待识别图像中是否存在违规信息。在一些实施例中,利用分类模型检测待识别图像中是否存在违规信息,包括:对待识别图像进行预处理,得到预处理图像;将预处理图像输入至分类模型,得到预处理图像的类别;基于预处理图像的类别,确定待识别图像中是否存在违规信息。第二方面,本申请实施例提供了一种分类模型训练装置,包括:初始分类单元,被配置成将样本信息输入至初始模型,得到样本信息的预测类别标签;损失计算单元,被配置成将样本信息的预测类别标签和真实类别标签输入至损失函数,得到样本信息的损失;动态加权单元,被配置成对样本信息的损失进行动态加权,得到样本信息的加权损失;参数调整单元,被配置成基于加权损失调整初始模型的参数,得到分类模型。在一些实施例中,动态加权单元进一步被配置成:基于样本信息的损失,确定样本信息的损失权重;基于样本信息的损失权重对样本信息的损失加权,得到加权损失。在一些实施例中,样本信息的损失权重与样本信息的损失正相关。在一些实施例中,样本信息是样本文本、样本图像、样本语音和样本视频之一。在一些实施例中,损失函数是交叉熵损失函数。在一些实施例中,交叉熵损失函数定义为:其中,1≤i≤n,1≤j≤n,且i,j,n均为正整数,y是样本信息的真实类别标签的独热编码,样本信息属于n类,yi是样本信息属于第i个类别的真实概率的独热编码,z为初始模型的输出,p为z的归一化,代表样本信息的预测类别标签,pi为样本信息属于第i个类别的预测概率。在一些实施例中,样本信息只属于一个类别,交叉墒损失函数定义为:CEL=-lnpt;其中,样本信息只属于第t个类别,1≤t≤n,且t为正整数,yt=1。在一些实施例中,加权损失定义为:=-,其中,α为损失权重,是样本信息的预测类别标签,若样本信息正确分类为第t个类别,则且正确分类标签为yt时的损失权重为λ1,若样本信息错误分类为第j个类别,则则错误分类标签为yj时的损失权重为λ2,若样本信息错误分类为第k个类别,则则错误分类标签为yk时的损失权重为λ3。在一些实施例中,该装置还包括:信息获取单元,被配置成获取待识别信息;信息分类单元,被配置成将待识别信息输入至分类模型,得到待识别信息的类别。在一些实施例中,样本信息是样本图像;以及该装置还包括:图像获取单元,被配置成获取待识别图像;违规检测单元,被配置成利用分类模型检测待识别图像中是否存在违规信息。在一些实施例中,违规检测单元进一步被配置成:对待识别图像进行预处理,得到预处理图像;将预处理图像输入至分类模型,得到预处理图像的类别;基于预处理图像的类别,确定待识别图像中是否存在违规信息。第三方面,本申请实施例提供了一种计算机设备,该计算机设备包括:一个或多个处理器;存储装置,其上存储有一个或多个程序;当一个或多个程序被一个或多个处理器执行,使得一个或多个处理器实现如第一方面中任一实现方式描述的方法。第四方面,本申请实施例提供了一种计算机可读介质,其上存储有计算机程序,该计算机程序被处理器执行时实现如第一方面中任一实现方式描述的方法。本申请实施例提供的分类模型训练方法和设备,首先将样本信息输入至初始模型,得到样本信息的预测类别标签;之后将样本信息的预测类别标签和真实类别标签输入至损失函数,得到样本信息的损失;然后对样本信息的损失进行动态加权,得到样本信息的加权损失;最后基于加权损失调整初始模型的参数,得到分类模型。在模型训练阶段对样本信息的损失进行动态加权,以调整模型对不同样本信息的学习程度,不仅关注样本信息是否正确分类以及分类的置信度,还关注错误分类时不同样本信息之间的关联性,使得模型优化的方向更符合实际需求,进而提高模型的分类准确度。附图说明通过阅读参照以下附图所作的对非限制性实施例所作的详细描述,本申请的其它特征、目的和优点将会变得更明显:图1是本申请可以应用于其中的示例性系统架构;图2是根据本申请的分类模型训练方法的一个实施例的流程图;图3是根据本申请的分类模型训练方法的又一个实施例的流程图;图4是根据本申请的分类模型训练方法的另一个实施例的流程图;图5是适于用来实现本申请实施例的计算机设备的计算机系统的结构示意图。具体实施方式下面结合附图和实施例对本申请作进一步的详细说明。可以理解的是,此处所描述的具体实施例仅仅用于解释相关专利技术,而非对该专利技术的限定。另外还需要说明的是,为了便于描述,附图中仅示出了与有关专利技术相关的部分。需本文档来自技高网...

【技术保护点】
1.一种分类模型训练方法,包括:/n将样本信息输入至初始模型,得到所述样本信息的预测类别标签;/n将所述样本信息的预测类别标签和真实类别标签输入至损失函数,得到所述样本信息的损失;/n对所述样本信息的损失进行动态加权,得到所述样本信息的加权损失;/n基于所述加权损失调整所述初始模型的参数,得到分类模型。/n

【技术特征摘要】
1.一种分类模型训练方法,包括:
将样本信息输入至初始模型,得到所述样本信息的预测类别标签;
将所述样本信息的预测类别标签和真实类别标签输入至损失函数,得到所述样本信息的损失;
对所述样本信息的损失进行动态加权,得到所述样本信息的加权损失;
基于所述加权损失调整所述初始模型的参数,得到分类模型。


2.根据权利要求1所述的方法,其中,所述对所述样本信息的损失进行动态加权,得到所述样本信息的加权损失,包括:
基于所述样本信息的损失,确定所述样本信息的损失权重;
基于所述样本信息的损失权重对所述样本信息的损失加权,得到所述加权损失。


3.根据权利要求2所述的方法,其中,所述样本信息的损失权重与所述样本信息的损失正相关。


4.根据权利要求1-3之一所述的方法,其中,所述样本信息是样本文本、样本图像、样本语音和样本视频之一。


5.根据权利要求1-3之一所述的方法,其中,所述损失函数是交叉熵损失函数。


6.根据权利要求5所述的方法,其中,所述交叉熵损失函数定义为:



其中,1≤i≤n,1≤j≤n,且i,j,n均为正整数,y是所述样本信息的真实类别标签的独热编码,所述样本信息属于n类,yi是所述样本信息属于第i个类别的真实概率的独热编码,z为所述初始模型的输出,p为z的归一化,代表所述样本信息的预测类别标签,pi为所述样本信息属于第i个类别的预测概率。


7.根据权利要求6所述的方法,其中,所述样本信息只属于一个类别,交叉墒损失函数定义为:
CEL=-lnpt;
其中,所述样本信息只属于第t个类别,1≤...

【专利技术属性】
技术研发人员:侯永杰
申请(专利权)人:连尚新昌网络科技有限公司
类型:发明
国别省市:浙江;33

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1