一种伪标签模型的训练方法、装置、设备及存储介质制造方法及图纸

技术编号:35009320 阅读:6 留言:0更新日期:2022-09-21 15:00
本公开实施例提供一种伪标签模型的训练方法、装置、设备及存储介质,其中方法包括:获取样本图像的标签信息和边缘信息;获取所述样本图像的第一分类预测结果和第一边缘预测结果,所述第一分类预测结果和第一边缘预测结果通过待训练的伪标签模型预测得到;根据所述第一分类预测结果和所述标签信息的差别,确定第一网络损失;根据所述第一边缘预测结果和所述边缘信息的差别,确定第二网络损失;根据所述第一网络损失和所述第二损失调整所述待训练的伪标签模型的网络参数,直至达到模型训练结束条件时,获得所述伪标签模型。该方法训练得到的伪标签模型,对于无标签图像可以利用预测的边缘信息引导修正预测的标签信息,进而得到更为可靠的伪标签。更为可靠的伪标签。更为可靠的伪标签。

【技术实现步骤摘要】
一种伪标签模型的训练方法、装置、设备及存储介质


[0001]本公开涉及深度学习
,具体涉及一种伪标签模型的训练方法、装置、设备及存储介质。

技术介绍

[0002]随着深度学习的发展,语义分割任务已经得到了广泛研究,基于监督学习的语义分割算法不断刷新着Cityscapes(城市景观数据集),Pascal VOC(目标检测数据集)等基准数据集上的报告精度。然而,对于基于监督学习的语义分割任务而言,在训练时使用的训练数据集中的大量样本图像需要经过高质量的像素级标注,这意味着昂贵的标注成本和漫长的标注时间。
[0003]半监督学习(Semi

Supervised Learning,SSL)是监督学习与无监督学习相结合的一种学习方法。半监督学习在训练过程中使用少量有标签图像与大量无标签图像。其核心在于,有效利用大量的无标签图像,作为有标签图像的补充,提升训练得到的模型的精度。半监督语义分割(Semi

Supervised Semantic Segmentation)将半监督学习应用于语义分割任务,以缓解基于监督学习的语义分割任务对高质量标注的训练数据集的依赖。半监督语义分割的关键在于为训练数据集中的无标签图像的各个像素打上伪标签(Pseudo Label),伪标签的准确性及可靠性越高,则训练得到的模型的精度越高。
[0004]相关技术中,一种常见的做法是自我学习(Self

Training),即先利用已标注的有标签图像来训练一个模型,然后使用这个模型对未标注的无标签图像进行标注预测,选择高度可信的预测作为伪标签,但这样得到的伪标签中存在一定数量的不可靠的伪标签,使用这些伪标签标注的数据集进行语义分割任务的模型训练将导致模型的退化,甚至使得模型朝着错误的方向发展,制约了半监督语义分割算法的精度。

技术实现思路

[0005]有鉴于此,本公开实施例提供至少一种伪标签模型的训练、装置、设备及存储介质。
[0006]具体地,本公开实施例是通过如下技术方案实现的:
[0007]第一方面,提供一种伪标签模型的训练方法,所述方法包括:
[0008]获取样本图像的标签信息和边缘信息,所述标签信息包括所述样本图像的分割掩膜的信息,所述边缘信息包括所述分割掩膜的边缘的信息,所述分割掩膜用于标注所述样本图像中各个像素的所属类别;
[0009]获取所述样本图像的第一分类预测结果和第一边缘预测结果,所述第一分类预测结果和第一边缘预测结果通过待训练的伪标签模型预测得到;
[0010]根据所述第一分类预测结果和所述标签信息的差别,确定第一网络损失;
[0011]根据所述第一边缘预测结果和所述边缘信息的差别,确定第二网络损失;
[0012]根据所述第一网络损失和所述第二损失调整所述待训练的伪标签模型的网络参
数,直至达到模型训练结束条件时,获得所述伪标签模型。
[0013]第二方面,提供一种半监督语义分割网络的训练方法,所述语义分割网络的训练样本集中包括:多个无标签图像;所述方法包括:
[0014]获取通过上述伪标签模型的训练方法训练得到的伪标签模型;
[0015]将无标签图像输入所述伪标签模型,得到对应的第二预测结果和第二边缘预测结果;
[0016]根据所述第二边缘预测结果,对所述第二分类预测结果进行修正,得到无标签图像对应的伪标签信息;
[0017]根据所述无标签图像及其对应的伪标签信息,对所述语义分割网络进行训练。
[0018]第三方面,提供一种伪标签模型的训练装置,所述装置包括:
[0019]训练数据获取模块,用于:获取样本图像的标签信息和边缘信息,所述标签信息包括所述样本图像的分割掩膜的信息,所述边缘信息包括所述分割掩膜的边缘的信息,所述分割掩膜用于标注所述样本图像中各个像素的所属类别;获取所述样本图像的第一分类预测结果和第一边缘预测结果,所述第一分类预测结果和第一边缘预测结果通过待训练的伪标签模型预测得到;
[0020]网络损失确定模块,用于:根据所述第一分类预测结果和所述标签信息的差别,确定第一网络损失;根据所述第一边缘预测结果和所述边缘信息的差别,确定第二网络损失;
[0021]网络参数调整模块,用于:根据所述第一网络损失和所述第二损失调整所述待训练的伪标签模型的网络参数,直至达到模型训练结束条件时,获得所述伪标签模型。
[0022]第四方面,提供一种半监督语义分割网络的训练装置,所述语义分割网络的训练样本集中包括:多个无标签图像;所述装置包括:
[0023]模型获取模块,用于:获取通过上述伪标签模型的训练方法训练得到的伪标签模型;
[0024]模型预测模块,用于:将无标签图像输入所述伪标签模型,得到对应的第二预测结果和第二边缘预测结果;
[0025]标签修正模块,用于:根据所述第二边缘预测结果,对所述第二分类预测结果进行修正,得到无标签图像对应的伪标签信息;
[0026]网络训练模块,用于:根据所述无标签图像及其对应的伪标签信息,对所述语义分割网络进行训练。
[0027]第五方面,一种电子设备,所述设备包括存储器、处理器,所述存储器用于存储可在处理器上运行的计算机指令,所述处理器用于在执行所述计算机指令时实现本公开任一实施例所述的伪标签模型的训练方法或者半监督语义分割网络的训练方法。
[0028]第六方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述程序被处理器执行时实现本公开任一实施例所述的伪标签模型的训练方法或者半监督语义分割网络的训练方法。
[0029]本公开实施例的技术方案提供的伪标签模型的训练方法,使用样本图像的标签信息和边缘信息作为监督信息训练得到伪标签模型,该伪标签模型用于预测得到分类预测结果和边缘预测结果,使得对于无标签图像可以利用预测的分类预测结果引导修正预测的分类预测结果,进而得到更为可靠的伪标签,能够充分利用有限的样本图像的标签提供的信
息,提升算法精度,同时大量节省标注成本。
附图说明
[0030]为了更清楚地说明本公开一个或多个实施例或相关技术中的技术方案,下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开一个或多个实施例中记载的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0031]图1是本公开至少一个实施例示出的一种语义分割网络训练方法的示意图;
[0032]图2是本公开至少一个实施例示出的一种伪标签模型的训练方法的流程图;
[0033]图3是本公开至少一个实施例示出的一种样本图像和其对应的分割掩膜;
[0034]图4是本公开至少一个实施例示出的一种分割掩膜和其对应的边缘;
[0035]图5是本公开至少一个实施例示出的一种基于编码器

解码器网络结构的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种伪标签模型的训练方法,其特征在于,所述方法包括:获取样本图像的标签信息和边缘信息,所述标签信息包括所述样本图像的分割掩膜的信息,所述边缘信息包括所述分割掩膜的边缘的信息,所述分割掩膜用于标注所述样本图像中各个像素的所属类别;获取所述样本图像的第一分类预测结果和第一边缘预测结果,所述第一分类预测结果和第一边缘预测结果通过待训练的伪标签模型预测得到;根据所述第一分类预测结果和所述标签信息的差别,确定第一网络损失;根据所述第一边缘预测结果和所述边缘信息的差别,确定第二网络损失;根据所述第一网络损失和所述第二损失调整所述待训练的伪标签模型的网络参数,直至达到模型训练结束条件时,获得所述伪标签模型。2.根据权利要求1所述的方法,其特征在于,所述方法还包括:对所述分割掩膜按图形填充,得到所述分割掩膜的边缘的信息。3.根据权利要求1所述的方法,其特征在于,所述根据所述第一边缘预测结果和所述边缘信息的差别,确定第二网络损失,包括:对于所述样本图像的各个像素中位于所述分割掩膜的边缘的像素,根据所述像素的第一边缘预测结果和边缘信息的差别,确定第二网络损失。4.根据权利要求1所述的方法,其特征在于,所述第一分类预测结果包含预测得到的所述样本图像中各个像素的属于类别集合中每一个类别的概率;所述根据所述第一分类预测结果和所述标签信息的差别,确定第一网络损失,包括:对于所述样本图像中的每一个像素,根据所述标签信息中标注的所述像素的所属类别和所述第一分类预测结果包含的所述像素属于所属类别的概率,计算得到第一网络损失。5.根据权利要求1所述的方法,其特征在于,所述第一边缘预测结果包含预测得到的所述样本图像中各个像素的属于边缘类别或者背景类别的概率;所述根据所述第一边缘预测结果和所述边缘信息的差别,确定第二网络损失,包括:对于所述样本图像中的每一个像素,根据所述边缘信息标注的所述像素属于的边缘类别或者背景类别,以及所述第一边缘预测结果包含的所述像素属于边缘类别或者背景类别的概率,计算得到第二网络损失。6.根据权利要求1所述的方法,其特征在于,根据所述第一网络损失和所述第二损失调整所述伪标签模型的网络参数,包括:在每一轮训练中,对所述第一网络损失和第二网络损失进行加权求和,得到总网络损失;根据所述总网络损失,调整所述伪标签模型的网络参数;或者,在第一阶段训练的每一轮训练中,根据所述第一网络损失,对所述伪标签模型的网络参数进行调整,直至完成第一阶段训练,得到调整后的伪标签模型;在第二阶段训练的每一轮训练中,根据所述第二网络损失,调整第一阶段训练后的所述伪标签模型的网络参数,直至完成第二阶段的训练。7.一种半监督语义分割网络的...

【专利技术属性】
技术研发人员:王钰超费敬敬李韡吴立威
申请(专利权)人:上海商汤智能科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1