System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 分类模型的训练和数据分类方法、装置、设备及存储介质制造方法及图纸_技高网

分类模型的训练和数据分类方法、装置、设备及存储介质制造方法及图纸

技术编号:40979189 阅读:5 留言:0更新日期:2024-04-18 21:26
本公开提供一种分类模型的训练和数据分类方法、装置、设备及存储介质,涉及人工智能技术领域,具体涉及深度学习、数据分类、模型优化等技术领域,可应用于图像分类、文本分类等场景下。具体实现方案包括:获取源域样本数据、以及源域样本数据对应的源域样本标签;通过分类网络确定源域样本数据对应的源域特征信息和源域预测标签;根据源域样本标签和源域预测标签确定第一损失;根据源域特征信息、源域样本数据所属类的第一质心、以及源域样本数据所属类之外的其他类的第二质心,确定第二损失;根据第一损失和第二损失确定第三损失;根据第三损失更新分类网络,得到分类模型。本公开可以降低分类模型的域差异,提高分类模型的性能。

【技术实现步骤摘要】

本公开涉及人工智能,具体涉及深度学习、数据分类、模型优化等,可应用于图像分类、文本分类等场景下,尤其涉及一种分类模型的训练和数据分类方法、装置、设备及存储介质


技术介绍

1、对于分类模型而言,用于训练分类模型的样本数据可以称为源域数据,使用分类模型进行分类的待分类数据可以称为目标域数据。而源域数据和目标域数据之间的域差异,可能会导致将采用源域数据训练的分类模型应用于目标域数据的分类时,出现显著的性能下降问题。

2、目前,可以通过采用多个源域数据集训练分类模型的方式,来适当降低分类模型对不同域数据之间的敏感度,以减少分类模型的性能下降。

3、但这种采用多个源域数据集训练分类模型的方式中,多个源域数据集占用了大量内存,训练成本较高,且分类模型的性能下降依然很明显。


技术实现思路

1、本公开提供了一种分类模型的训练和数据分类方法、装置、设备及存储介质,能够明显降低分类模型的域差异,减少分类模型的性能下降,提高分类模型在对目标域数据进行分类时的性能。

2、根据本公开的第一方面,提供了一种分类模型的训练方法,所述方法包括:获取源域样本数据、以及源域样本数据对应的源域样本标签;通过预设的分类网络确定源域样本数据对应的源域特征信息和源域预测标签;根据源域样本标签和源域预测标签,确定第一损失;根据源域特征信息、源域样本数据所属类的第一质心、以及源域样本数据所属类之外的其他类的第二质心,确定第二损失;根据第一损失和第二损失,确定第三损失;根据第三损失更新分类网络,得到分类模型;分类模型用于对目标域数据进行分类。

3、根据本公开的第二方面,提供了一种分类模型的训练装置,所述装置包括:获取单元,用于获取源域样本数据、以及源域样本数据对应的源域样本标签。训练单元,用于通过预设的分类网络确定源域样本数据对应的源域特征信息和源域预测标签;根据源域样本标签和源域预测标签,确定第一损失;根据源域特征信息、源域样本数据所属类的第一质心、以及源域样本数据所属类之外的其他类的第二质心,确定第二损失;根据第一损失和第二损失,确定第三损失;根据第三损失更新分类网络,得到分类模型;分类模型用于对目标域数据进行分类。

4、根据本公开的第三方面,提供了一种数据分类方法,所述方法包括:获取待分类的目标域数据;通过分类模型,对目标域数据进行分类,得到目标域数据的类别。其中,分类模型是采用源域样本数据、以及源域样本数据对应的源域样本标签,对预设的分类网络进行训练得到的,在训练分类网络时采用的损失包括:源域样本标签和源域预测标签之间的第一损失、以及源域特征信息与源域样本数据所属类的第一质心和所属类之外的其他类的第二质心之间的第二损失,源域预测标签和源域特征信息是分类网络根据源域样本数据所确定的。

5、根据本公开的第四方面,提供了一种数据分类装置,装置包括:获取单元,用于获取待分类的目标域数据;分类单元,用于通过分类模型,对目标域数据进行分类,得到目标域数据的类别。其中,分类模型是采用源域样本数据、以及源域样本数据对应的源域样本标签,对预设的分类网络进行训练得到的,在训练分类网络时采用的损失包括:源域样本标签和源域预测标签之间的第一损失、以及源域特征信息与源域样本数据所属类的第一质心和所属类之外的其他类的第二质心之间的第二损失,源域预测标签和源域特征信息是分类网络根据源域样本数据所确定的。

6、根据本公开的第五方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行如第一方面或第三方面所述的方法。

7、根据本公开的第六方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使计算机执行根据第一方面或第三方面所述的方法。

8、根据本公开的第七方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据第一方面或第三方面所述的方法。

9、应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。

本文档来自技高网...

【技术保护点】

1.一种分类模型的训练方法,所述方法包括:

2.根据权利要求1所述的方法,所述根据所述源域特征信息、所述源域样本数据所属类的第一质心、以及所述源域样本数据所属类之外的其他类的第二质心,确定第二损失,包括:

3.根据权利要求1或2所述的方法,所述根据所述源域特征信息、所述源域样本数据所属类的第一质心、以及所述源域样本数据所属类之外的其他类的第二质心,确定第二损失之前,所述方法还包括:

4.根据权利要求3所述的方法,当所述N等于2时,所述平滑系数的值为预设的初始值;

5.根据权利要求1-4任一项所述的方法,所述根据所述第一损失和所述第二损失,确定第三损失,包括:

6.根据权利要求5所述的方法,所述第一权重和所述第二权重均为大于0、且小于1的数,且所述第一权重和所述第二权重的和等于1。

7.根据权利要求5所述的方法,所述第一权重等于1,所述第二权重为大于0、且小于1的数。

8.根据权利要求1-7任一项所述的方法,所述根据所述源域特征信息、所述源域样本数据所属类的第一质心、以及所述源域样本数据所属类之外的其他类的第二质心,确定第二损失之前,所述方法还包括:

9.根据权利要求1-8任一项所述的方法,所述根据所述第三损失更新所述分类网络,得到分类模型之后,所述方法还包括:

10.一种数据分类方法,所述方法包括:

11.根据权利要求10所述的方法,所述第二损失为第一对数值的负数,所述第一对数值为第三比值以自然常数e或10为底时的对数值,所述第三比值为第一指数值相对于第一指数值和第二指数值之和的比值,所述第一指数值为所述源域特征信息和所述第一质心的乘积相对于预设的超参数的第一比值以自然常数e为底时的指数值,所述第二指数值为所述源域特征信息和所述第二质心的乘积相对于所述超参数的第二比值以自然常数e为底时的指数值。

12.根据权利要求10或11所述的方法,当所述源域样本数据为第1轮参与所述分类模型的训练时,所述第一质心为从所述源域样本数据所属类包含的所有数据中的一组数据的特征信息,所述第二质心为所述源域样本数据所属类之外的其他类包含的所有数据中的一组数据的特征信息;

13.根据权利要求12所述的方法,当所述N等于2时,所述平滑系数的值为预设的初始值;

14.根据权利要求10-13任一项所述的方法,在训练所述分类网络时,所述第一损失所占的权重为第一权重,所述第二损失所占的权重为第三权重。

15.根据权利要求14所述的方法,所述第一权重和所述第二权重均为大于0、且小于1的数,且所述第一权重和所述第二权重的和等于1。

16.根据权利要求14所述的方法,所述第一权重等于1,所述第二权重为大于0、且小于1的数。

17.根据权利要求10-16任一项所述的方法,所述分类模型采用目标域样本数据、以及所述目标域样本数据对应的目标域样本标签进行了更新;

18.一种分类模型的训练装置,所述装置包括:

19.根据权利要求18所述的装置,所述训练单元,具体用于:

20.根据权利要求18或19所述的装置,所述训练单元,还用于:

21.根据权利要求20所述的装置,当所述N等于2时,所述平滑系数的值为预设的初始值;

22.根据权利要求18-21任一项所述的装置,所述训练单元,具体用于:

23.根据权利要求18-22任一项所述的装置,所述训练单元,还用于:

24.根据权利要求18-23任一项所述的装置,所述获取单元,还用于获取目标域样本数据、以及所述目标域样本数据对应的目标域样本标签;

25.一种数据分类装置,所述装置包括:

26.根据权利要求25所述的装置,所述第二损失为第一对数值的负数,所述第一对数值为第三比值以自然常数e或10为底时的对数值,所述第三比值为第一指数值相对于第一指数值和第二指数值之和的比值,所述第一指数值为所述源域特征信息和所述第一质心的乘积相对于预设的超参数的第一比值以自然常数e为底时的指数值,所述第二指数值为所述源域特征信息和所述第二质心的乘积相对于所述超参数的第二比值以自然常数e为底时的指数值。

27.根据权利要求25或26所述的装置,当所述源域样本数据为第1轮参与所述分类模型的训练时,所述第一质心为从所述源域样本数据所属类包含的所有数据中的一组数据的特征信息,所述第二质心为所述源域样本数据所属类之外的其他类包含的所有数据中的一组数据的特征信息;

28.根据权利要求27所述的装置,当所述N等于2时,所述平滑系数的值为预设的初始值;...

【技术特征摘要】

1.一种分类模型的训练方法,所述方法包括:

2.根据权利要求1所述的方法,所述根据所述源域特征信息、所述源域样本数据所属类的第一质心、以及所述源域样本数据所属类之外的其他类的第二质心,确定第二损失,包括:

3.根据权利要求1或2所述的方法,所述根据所述源域特征信息、所述源域样本数据所属类的第一质心、以及所述源域样本数据所属类之外的其他类的第二质心,确定第二损失之前,所述方法还包括:

4.根据权利要求3所述的方法,当所述n等于2时,所述平滑系数的值为预设的初始值;

5.根据权利要求1-4任一项所述的方法,所述根据所述第一损失和所述第二损失,确定第三损失,包括:

6.根据权利要求5所述的方法,所述第一权重和所述第二权重均为大于0、且小于1的数,且所述第一权重和所述第二权重的和等于1。

7.根据权利要求5所述的方法,所述第一权重等于1,所述第二权重为大于0、且小于1的数。

8.根据权利要求1-7任一项所述的方法,所述根据所述源域特征信息、所述源域样本数据所属类的第一质心、以及所述源域样本数据所属类之外的其他类的第二质心,确定第二损失之前,所述方法还包括:

9.根据权利要求1-8任一项所述的方法,所述根据所述第三损失更新所述分类网络,得到分类模型之后,所述方法还包括:

10.一种数据分类方法,所述方法包括:

11.根据权利要求10所述的方法,所述第二损失为第一对数值的负数,所述第一对数值为第三比值以自然常数e或10为底时的对数值,所述第三比值为第一指数值相对于第一指数值和第二指数值之和的比值,所述第一指数值为所述源域特征信息和所述第一质心的乘积相对于预设的超参数的第一比值以自然常数e为底时的指数值,所述第二指数值为所述源域特征信息和所述第二质心的乘积相对于所述超参数的第二比值以自然常数e为底时的指数值。

12.根据权利要求10或11所述的方法,当所述源域样本数据为第1轮参与所述分类模型的训练时,所述第一质心为从所述源域样本数据所属类包含的所有数据中的一组数据的特征信息,所述第二质心为所述源域样本数据所属类之外的其他类包含的所有数据中的一组数据的特征信息;

13.根据权利要求12所述的方法,当所述n等于2时,所述平滑系数的值为预设的初始值;

14.根据权利要求10-13任一项所述的方法,在训练所述分类网络时,所述第一损失所占的权重为第一权重,所述第二损失所占的权重为第三权重。

15.根据权利要求14所述的方法,所述第一权重...

【专利技术属性】
技术研发人员:周文硕杨大陆杨叶辉王晓荣王磊
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1