图像分割模型的训练与蒸馏方法、电子设备和存储介质技术

技术编号:36204530 阅读:11 留言:0更新日期:2023-01-04 11:59
本发明专利技术实施例提供一种图像分割模型的训练与知识蒸馏方法、电子设备和存储介质。该训练方法包括:获取训练样本图像;将训练样本图像输入至图像分割模型中,并基于第一预设损失函数对图像分割模型进行迭代训练;其中,第一预设损失函数包括以下交叉熵之和:第一类别所对应的交叉熵:基于第一类别在预设logits调整参数下得到的预测值而计算得到交叉熵值;带有衰减倍数的第二类别所对应的交叉熵:衰减倍数在第一预设衰减参数下将(1

【技术实现步骤摘要】
图像分割模型的训练与蒸馏方法、电子设备和存储介质


[0001]本专利技术涉及图像处理
,更具体地涉及一种图像分割模型的训练方法、应用于图像分割模型的知识蒸馏方法、电子设备和存储介质。

技术介绍

[0002]近年来,图像分割一直是一个活跃的研究领域,例如该技术可以帮助修复医疗领域的漏洞,并帮助大众。在图像分割领域,通常使用图像分割模型对图像进行分割。图像分割模型可以使用训练样本图像,基于损失函数进行训练。
[0003]现有图像分割模型的损失函数可以表示为:其中,N为每次迭代的样本总数,M表示样本类别总数,c表示当前类别,c=i表示第i类别为当前类别c。此外,T
c
是指c类的真实值或标签值,例如,若像素是c类,那么T
c
就等于1,否则T
c
就等于0,P
c
是指c类的预测值。该损失函数应用在样本不均衡以及小目标或者细小目标的场景下,可能存在漏检率高,交并比(IOU)低等问题。

技术实现思路

[0004]考虑到上述问题而提出了本专利技术。本专利技术提供了一种图像分割模型的训练方法、应用于图像分割模型的知识蒸馏方法、电子设备和存储介质。
[0005]根据本专利技术的第一方面,提供了一种图像分割模型的训练方法,包括:获取训练样本图像;将训练样本图像输入至图像分割模型中,并基于第一预设损失函数对图像分割模型进行迭代训练;其中,第一预设损失函数包括以下交叉熵之和:第一类别所对应的交叉熵:基于第一类别在预设logits调整参数下得到的预测值而计算得到交叉熵值;带有衰减倍数的第二类别所对应的交叉熵:衰减倍数在第一预设衰减参数下将(1

第二类别的预测值)非线性衰减。
[0006]示例性地,第一类别为稀有类,第二类别为非稀有类,稀有类是指训练样本图像所包含的所有类别中标注数量或标注比例小于预设阈值的类别。
[0007]示例性地,第一预设衰减参数为γ;衰减倍数为(1

第二类别的预测值)的γ次方,γ为非负数。
[0008]示例性地,γ的取值范围为[1,5]。
[0009]示例性地,基于每个第一类别的logits值与预设logits调整参数m之差获得该第一类别的预测值,m为非负数。
[0010]示例性地,m的取值范围为[0.5,1.5]。
[0011]示例性地,第一预设损失函数的公式如下:
[0012][0013]其中,N表示每次迭代训练时训练样本图像的总数量;
[0014]T
c=i
表示当前类别c为第一类别i时的真实值;
[0015]表示当前类别c为第一类别i时基于预设logits调整参数m下获得的预测值;
[0016]T
c≠i
表示当前类别c为任一第二类别时的真实值;
[0017]P
c≠i
表示当前类别c为任一第二类别时的预测值;
[0018]γ表示第一预设衰减参数。
[0019]根据本专利技术的第二方面,还提供了一种图像分割模型的训练方法,包括:获取训练样本图像;将训练样本图像输入至图像分割模型中,并基于第二预设损失函数对图像分割模型进行迭代训练;其中,第二预设损失函数包括多个类别所对应的交叉熵之和;每个类别的交叉熵带有预设衰减倍数,衰减倍数在第二预设衰减参数下将(1

该类别的预测值)非线性衰减;每个类别在预设logits调整参数下得到该类别的预测值,根据预测值计算该类别的交叉熵。
[0020]示例性地,第二预设衰减参数为γ;衰减倍数为(1

对应类别所对应的预测值)的γ次方,γ为非负数。
[0021]示例性地,γ的取值范围为[1,5]。
[0022]示例性地,基于每个类别的logits值与预设logits调整参数m之差获得该类别的预测值,m为非负数。
[0023]示例性地,m的取值范围为[0.5,1.5]。
[0024]示例性地,第二预设损失函数的公式如下:
[0025][0026]其中,N表示每次迭代训练时训练样本图像的总数量;
[0027]T
c
表示当前类别c的真实值;
[0028]表示当前类别c的基于预设logits调整参数m下获得的预测值;
[0029]γ表示第二预设衰减参数;
[0030]M表示多个类别的总数目;
[0031]c=i表示第i类别为当前类别c。
[0032]根据本专利技术的第三方面,还提供了一种应用于图像分割模型的知识蒸馏方法,包括:获取训练样本图像;将训练样本图像分别输入学生分割模型和教师分割模型中进行模型训练,在训练过程中基于蒸馏损失函数来计算损失;其中,蒸馏损失函数包括基于软目标计算的软损失函数和基于硬目标计算的硬损失函数,其中,软损失函数采用上述的第一预设损失函数或上述的第二预设损失函数实现。
[0033]示例性地,硬损失函数包括:在目标类别下的真实样本像素集合和实际预测样本像素集合的多种交集一一对应的多种概率值,每种概率值与对应交集中的像素属于目标类别的真实值和预测值相关,并且每种概率值的预测值在硬损失函数中被配置为在第三预设衰减参数的作用下进行非线性衰减。
[0034]示例性地,真实样本像素集合包括真实正样本像素集合和真实负样本像素集合,实际预测样本像素集合包括实际预测正样本像素集合和实际预测负样本像素集合,多种交集包括第一交集、第二交集和第三交集,其中,第一交集为目标类别下真实正样本像素集合
和实际预测正样本像素集合的交集;第二交集为目标类别下真实正样本像素集合和实际预测负样本像素集合的交集;第三交集为目标类别下真实负样本像素集合和实际预测正样本像素集合的交集。
[0035]示例性地,第二交集和第三交集对应的概率值具有对应的权重系数,第二交集所对应的概率值的权重系数大于第三交集所对应的概率值的权重系数。
[0036]根据本专利技术的第四方面,还提供了一种电子设备,包括处理器和存储器,存储器中存储有计算机程序,处理器执行计算机程序以实现上述的图像分割模型的训练方法、或者上述的知识蒸馏方法。
[0037]根据本专利技术的第五方面,还提供了一种存储介质,存储有计算机程序/指令,所述计算机程序/指令被处理器执行时实现上述的图像分割模型的训练方法、或者上述的知识蒸馏方法。
[0038]根据本专利技术实施例的图像分割模型的训练方法、应用于图像分割模型的知识蒸馏方法、电子设备和存储介质,通过损失函数中的预设logtis调整参数和预设衰减参数,第一,有助于解决样本不均衡以及少量数据带来的图像分割模型过拟合等问题,适用于小样本学习;第二,非常适用于细小目标的分割,尤其是对于需要考虑全局视野的小目标,降低误检和漏检。
附图说明
[0039]通过结合附图对本专利技术实施例进行更详细的描述,本专利技术的上述以及其它目的、特征和优势将变得更加明显。附图用来提供对本专利技术实施例的进一本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种图像分割模型的训练方法,其特征在于,包括:获取训练样本图像;将所述训练样本图像输入至图像分割模型中,并基于第一预设损失函数对所述图像分割模型进行迭代训练;其中,所述第一预设损失函数包括以下交叉熵之和:第一类别所对应的交叉熵:基于所述第一类别在预设logits调整参数下得到的预测值而计算得到交叉熵值;带有衰减倍数的第二类别所对应的交叉熵:所述衰减倍数在第一预设衰减参数下将(1

所述第二类别的预测值)非线性衰减。2.如权利要求1所述的方法,其中,所述第一类别为稀有类,所述第二类别为非稀有类,稀有类是指所述训练样本图像所包含的所有类别中标注数量或标注比例小于预设阈值的类别。3.如权利要求1所述的方法,其中,所述第一预设衰减参数为γ;所述衰减倍数为(1

所述第二类别的预测值)的γ次方,γ为非负数。4.如权利要求3所述的方法,其中,γ的取值范围为[1,5]。5.如权利要求1所述的方法,其中,基于每个所述第一类别的logits值与预设logits调整参数m之差获得该第一类别的预测值,m为非负数。6.如权利要求5所述的方法,其中,m的取值范围为[0.5,1.5]。7.如权利要求1

6任一项所述的方法,其中,所述第一预设损失函数的公式如下:其中,N表示每次迭代训练时所述训练样本图像的总数量;T
c=i
表示当前类别c为第一类别i时的真实值;表示当前类别c为第一类别i时基于所述预设logits调整参数m下获得的预测值;T
c≠i
表示当前类别c为任一第二类别时的真实值;P
c≠i
表示当前类别c为任一第二类别时的预测值;γ表示所述第一预设衰减参数。8.一种图像分割模型的训练方法,其特征在于,包括:获取训练样本图像;将所述训练样本图像输入至图像分割模型中,并基于第二预设损失函数对所述图像分割模型进行迭代训练;其中,所述第二预设损失函数包括多个类别所对应的交叉熵之和;每个所述类别的所述交叉熵带有预设衰减倍数,所述衰减倍数在第二预设衰减参数下将(1

该类别的预测值)非线性衰减;每个所述类别在预设logits调整参数下得到该类别的预测值,根据所述预测值计算该类别的交叉熵。9.如权利要求8所述的方法,其中,所述第二预设衰减参数为γ;所述衰减倍数为(1

对应类别所对应的预测值)的γ次方,γ为非负数。
10.如权利要求9所述的方法,其中,γ的取值范围为[1,5]。11.如权利要求8所述的方法,其中,基于每个所述类别的logits值与预设logits调整参数m之差获得该类别的预测值,m为非负数。12.如权利要求11所述的方法,其中,m的取值范围为[0.5,1.5]。13.如权利要求8

【专利技术属性】
技术研发人员:陈燕娟孙新张共济
申请(专利权)人:苏州镁伽科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1