训练分类模型、数据分类方法、装置、设备、介质及产品制造方法及图纸

技术编号:34101745 阅读:18 留言:0更新日期:2022-07-11 23:36
本公开提供了训练分类模型方法、数据分类方法、装置、设备、介质及产品,涉及人工智能技术领域,具体为深度学习、计算机视觉技技术领域,可应用于医学影像处理场景。具体实现方案为:利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定第一训练图像对应的分类损失;利用分类模型的第二网络分支提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征;基于第一图像特征与第二图像特征,确定对比损失;基于分类损失以及对比损失,更新分类模型的参数,得到训练完成的分类模型。本公开提升了分类模型的图像分类效果。了分类模型的图像分类效果。了分类模型的图像分类效果。

【技术实现步骤摘要】
训练分类模型、数据分类方法、装置、设备、介质及产品


[0001]本公开涉及人工智能
,尤其涉及深度学习、计算机视觉
,可应用于医学影像处理场景。

技术介绍

[0002]自监督对比学习是无监督学习的一种,能够从无标注的数据中学习知识,随着自监督对比学习的发展,从特征层面取得了很好的效果。例如,在人工智能的数据分类中,使用自监督对比学习方式能够对影像等图像数据进行分类,比如对医学图像数据进行分级。
[0003]对于有监督模型,需要大量高质量标注的样本提高学习效果。对于数据标注成本高,有标注样本相对较少、标注质量差的情况,有监督模型往往泛化能力不够强,且标注本身的噪声限制了有监督分类模型的上限。

技术实现思路

[0004]本公开提供了一种用于训练分类模型方法、数据分类方法、装置、设备、介质及产品。
[0005]根据本公开的一方面,提供了一种训练分类模型的方法,包括:利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失;利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;基于所述第一图像特征与所述第二图像特征,确定对比损失;基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。
[0006]根据本公开的另一方面,提供了一种数据分类方法,包括:
[0007]确定待分类数据;将所述待分类数据输入至分类模型,得到所述分类模型的输出结果;基于所述分类模型的输出结果,确定所述待分类数据的分类结果;其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的第二图像特征;其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。
[0008]根据本公开的又一方面,提供了一种训练分类模型的装置,包括:确定模块,用于利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失,以及,基于第一图像特征与第二图像特征,确定对比损失;提取模块,利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;更新模块,用于基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。
[0009]根据本公开的又一方面,提供了一种数据分类装置,包括:
[0010]确定模块,用于确定待分类数据;分类模块,用于将所述待分类数据输入至分类模型,得到所述分类模型的输出结果,并基于所述分类模型的输出结果,确定所述待分类数据的分类结果;其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的第二图像特征;其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。
[0011]根据本公开的又一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开中的训练分类模型的方法。
[0012]根据本公开的又一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开中的数据分类方法。
[0013]根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行本公开中训练分类模型的方法。
[0014]根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行本公开中数据分类方法。
[0015]根据本公开的又一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开中的训练分类模型的方法。
[0016]根据本公开的又一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开中的数据分类方法。
[0017]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0018]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0019]图1是根据本公开一示例性实施例示出的训练分类模型的方法流程示意图;
[0020]图2是根据本公开一示例性实施例示出的分类模型结构示意图;
[0021]图3示出了一种有监督分类模型结构的示意图;
[0022]图4示出了一种无监督对比学习的分类模型结构的示意图;
[0023]图5是根据本公开提供的一种基于分类损失以及对比损失,更新分类模型参数的方法流程示意图;
[0024]图6是根据本公开的利用分类模型的第一网络分支对第一训练图像进行分类预测的分类预测结果,确定第一训练图像对应的分类损失的方法流程示意图;
[0025]图7是根据本公开的利用所述分类模型的第二网络分支分别提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征的方法流程示意图;
[0026]图8是在糖尿病视网膜病变分级数据集上,有监督分类器的骨干网络的输出特征可视化的结果示意图;
[0027]图9示出了糖尿病视网膜病变分级数据集中包括的数据集详细信息;
[0028]图10为有监督分类器分类模型结构和本公开中双流结构的分类模型结构训练集损失值(loss)随训练过程的移动平均线曲线对比示意图;
[0029]图11为有监督分类器分类模型结构和本公开中双流结构的分类模型结构验证集中Kappa值的移动平均曲线。
[0030]图12展现了有监督分类器分类模型结构和本公开中双流结构的分类模型结构在测试集上的Kappa值对比示意图;
[0031]图13是根据本公开一示例性实施例示出的一种数据分类方法流程图;
[0032]图14是根据本公开一示例性实施例示出的一种训练分类模型的装置框图;
[0033]图15是根据本公开一示例性实施例示出的一种数据分类装置框图;
[0034]图16是用来实现本公开实施例的训练分类本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种训练分类模型的方法,包括:利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失;利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;基于所述第一图像特征与所述第二图像特征,确定对比损失;基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。2.根据权利要求1所述的方法,其中,所述对比损失通过多分类版本的噪声对比估计损失函数确定。3.根据权利要求1或2所述的方法,其中,所述基于所述分类损失以及所述对比损失,更新所述分类模型的参数,包括:基于所述分类损失,利用反向传播更新所述第一网络分支的参数,并基于所述对比损失,利用反向传播更新所述第一网络分支的参数;基于更新后的所述第一网络分支的参数,通过动量更新所述第二网络分支的参数。4.根据权利要求1所述的方法,其中,所述利用分类模型的第一网络分支对所述第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失,包括:将所述第一训练图像输入至分类模型的第一网络分支的分类器,得到对所述第一训练图像进行分类预测的分类预测结果;基于所述分类预测结果与所述训练图像的标注信息,确定所述第一训练图像对应的交叉熵分类损失。5.根据权利要求1所述的方法,其中,所述利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,以及所述第二训练图像的第二图像特征,包括:利用所述第二网络分支中第一分支的投影器与预测器,提取所述第一训练图像的第一图像特征;利用所述第二网络分支中第二分支的投影器,提取所述第二训练图像的第二图像特征。6.根据权利要求1

5中任意一项所述的方法,其中,所述第一训练图像和所述第二训练图像通过对训练图像进行不同次的数据增强得到。7.一种数据分类方法,包括:确定待分类数据;将所述待分类数据输入至分类模型,得到所述分类模型的输出结果;基于所述分类模型的输出结果,确定所述待分类数据的分类结果;其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的
第二图像特征;其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。8.根据权利要求7所述的方法,其中,所述第一训练图像的分类预测结果基于所述第一网络分支的分类器得到;所述分类损失为所述第一训练图像的交叉熵分类损失,所述第一训练图像的交叉熵分类损失基于所述第一训练图像的分类预测结果以及训练图像的标注信息确定;所述第一训练图像通过对所述训练图像进行数据增强得到。9.根据权利要求7所述的方法,其中,所述第一训练图像的第一图像特征利用所述第二网络分支中第一分支的投影器与预测器提取;所述第二训练图像的第二图像特征利用所述第二网络分支中第二分支的投影器提取。10.根据权利要求7所述的方法,其中,所述待分类数据为医学影像数据。11.一种训练分类模型的装置,包括:确定模块,用于利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失,以及,基于第一图像特征与第二图像特征,确定对比损失;提取模块,利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;更新模块,用于基于所述分类损失以及所述对比损失,更新所述分类模型的参...

【专利技术属性】
技术研发人员:周文硕杨大陆杨叶辉武秉泓王晓荣王磊
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1