System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种图像分类模型的训练方法、装置及电子设备制造方法及图纸_技高网

一种图像分类模型的训练方法、装置及电子设备制造方法及图纸

技术编号:41275404 阅读:4 留言:0更新日期:2024-05-11 09:28
本申请实施例提供了一种图像分类模型的训练方法、装置及电子设备,涉及计算机视觉技术领域,本申请实施例包括:针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照指定比例混合,得到混合图像的混合标签。再将样本图像和混合图像分别输入图像分类网络,之后基于图像分类网络输出的样本图像所属的类别和训练标签,确定样本损失值,并基于图像分类网络输出的混合图像所属的类别和混合标签,确定混合损失值。再基于样本损失值和混合损失值,调整图像分类网络的网络参数,直至图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。能够提高图像分类的准确度。

【技术实现步骤摘要】

本申请涉及计算机视觉,特别是涉及一种图像分类模型的训练方法、装置及电子设备


技术介绍

1、图像分类技术的应用范围十分广泛,例如,可应用于人脸识别、自动驾驶、智能家居和医学影像分析等领域。使用神经网络模型能够预测图像属于每种预设类别的概率,实现了对图像的分类,该方式能提高图像分类的速度和准确度,因此神经网络的快速发展进一步推动了图像分类技术在各个领域中落地。

2、在实际落地场景中,需要使用高可信的神经网络模型。即实际应用中期望神经网络模型预测的更高的概率对应的分类结果更可能是正确的;且更低的概率对应的分类结果更可能是不准确的,意味着神经网络对此次预测结果不太确认。

3、为了保证神经网络模型对图像进行分类的准确率,并降低误报率,目前在神经网络的训练过程中,通常将样本图像输入神经网络模型,得到神经网络模型输出的样本图像属于每种预设类别的概率,之后基于神经网络模型的输出结果与样本图像的训练标签计算损失值,然后利用损失值调整神经网络模型的网络参数。但该方式对于提高神经网络模型的预测准确度的效果有限,即训练后的得到神经网络模型对图像分类的准确度不够高,存在较多的误报。


技术实现思路

1、本申请实施例的目的在于提供一种图像分类模型的训练方法、装置及电子设备,以提高图像分类的准确度。具体技术方案如下:

2、本申请实施例的第一方面,提供了一种图像分类模型的训练方法,所述方法包括:

3、获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;

4、针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签;

5、将所述样本图像和所述混合图像分别输入图像分类网络,得到所述图像分类网络输出的所述样本图像所属的类别和所述混合图像所属的类别;

6、基于所述图像分类网络输出的所述样本图像所属的类别和所述样本图像的训练标签,确定样本损失值;

7、基于所述图像分类网络输出的所述混合图像所属的类别和所述混合图像的混合标签,确定混合损失值;

8、基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,并返回所述获取多张样本图像以及每张样本图像的训练标签的步骤,直至所述图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。

9、可选的,每张原始样本的尺寸均相同;所述对该两张样本图像按照指定比例混合,得到混合图像,包括:

10、在0到1范围内采样,得到混合权重;

11、计算该两张样本图像中的第一样本图像的像素值与所述混合权重的第一乘积;

12、计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积;其中,所述指定权重为1与所述混合权重的差值;

13、计算所述第一乘积与所述第二乘积的和值,作为所述混合图像。

14、可选的,所述对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签,包括:

15、计算所述第一样本图像的训练标签与所述混合权重的第三乘积;

16、计算所述第二样本图像的训练标签与所述指定权重的第四乘积;

17、计算所述第三乘积与所述第四乘积的和值,作为所述混合图像的混合标签。

18、可选的,多次采样获得的混合权重满足贝塔分布。

19、可选的,所述获取多张样本图像以及每张样本图像的训练标签,包括:

20、确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像;

21、若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及所述下一组样本图像的训练标签;

22、若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及所述第一组样本图像的训练标签。

23、可选的,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,包括:

24、获取每张样本图像的正确分类次数以及目标预测概率;其中,所述正确分类次数为:所述图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数;所述目标预测概率为:所述图像分类网络输出的该样本图像属于目标类别的概率;

25、针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差;

26、计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值;

27、根据所述样本损失值、所述混合损失值和所述一致性损失值,确定总损失值;

28、利用所述总损失值,调整所述图像分类网络的网络参数。

29、可选的,每两张样本图像之间的一致性偏差为:

30、;

31、其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标预测概率,表示第j张样本图像的目标预测概率。

32、可选的,所述总损失值为:

33、;

34、其中,为所述总损失值,为所述样本损失值,和为预设的超参数,为所述混合损失值,为所述一致性损失值。

35、可选的,所述利用所述总损失值,调整所述图像分类网络的网络参数,包括:

36、基于锐度感知最小化优化算法,将所述总损失值的最小值对应的所述图像分类网络的网络参数,作为候选网络参数;

37、若当前迭代次数未达到指定次数,则将所述图像分类网络的网络参数修改为本次计算的候选网络参数;

38、若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将所述图像分类网络的网络参数修改为所述平均值。

39、可选的,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。

40、本申请实施例的第二方面,提供了一种图像分类模型的训练装置,所述装置包括:

41、获取模块,用于获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;

42、增强模块,用于针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合本文档来自技高网...

【技术保护点】

1.一种图像分类模型的训练方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,每张原始样本的尺寸均相同;所述对该两张样本图像按照指定比例混合,得到混合图像,包括:

3.根据权利要求2所述的方法,其特征在于,所述对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签,包括:

4.根据权利要求2所述的方法,其特征在于,多次采样获得的混合权重满足贝塔分布。

5.根据权利要求1所述的方法,其特征在于,所述获取多张样本图像以及每张样本图像的训练标签,包括:

6.根据权利要求5所述的方法,其特征在于,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,包括:

7.根据权利要求6所述的方法,其特征在于,每两张样本图像之间的一致性偏差为:

8.根据权利要求6所述的方法,其特征在于,所述总损失值为:

9.根据权利要求6所述的方法,其特征在于,所述利用所述总损失值,调整所述图像分类网络的网络参数,包括:

10.根据权利要求1-9任一项所述的方法,其特征在于,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。

11.一种图像分类模型的训练装置,其特征在于,所述装置包括:

12.根据权利要求11所述的装置,其特征在于,每张原始样本的尺寸均相同;所述增强模块,具体用于:

13.根据权利要求12所述的装置,其特征在于,所述增强模块,具体用于:

14.根据权利要求12所述的装置,其特征在于,多次采样获得的混合权重满足贝塔分布。

15.根据权利要求11所述的装置,其特征在于,所述获取模块,具体用于:

16.根据权利要求15所述的装置,其特征在于,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述调整模块,具体用于:

17.根据权利要求16所述的装置,其特征在于,每两张样本图像之间的一致性偏差为:

18.根据权利要求16所述的装置,其特征在于,所述总损失值为:

19.根据权利要求16所述的装置,其特征在于,所述调整模块,具体用于:

20.根据权利要求11-19任一项所述的装置,其特征在于,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。

21.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;

22.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-10任一项所述的方法。

...

【技术特征摘要】

1.一种图像分类模型的训练方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,每张原始样本的尺寸均相同;所述对该两张样本图像按照指定比例混合,得到混合图像,包括:

3.根据权利要求2所述的方法,其特征在于,所述对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签,包括:

4.根据权利要求2所述的方法,其特征在于,多次采样获得的混合权重满足贝塔分布。

5.根据权利要求1所述的方法,其特征在于,所述获取多张样本图像以及每张样本图像的训练标签,包括:

6.根据权利要求5所述的方法,其特征在于,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,包括:

7.根据权利要求6所述的方法,其特征在于,每两张样本图像之间的一致性偏差为:

8.根据权利要求6所述的方法,其特征在于,所述总损失值为:

9.根据权利要求6所述的方法,其特征在于,所述利用所述总损失值,调整所述图像分类网络的网络参数,包括:

10.根据权利要求1-9任一项所述的方法,其特征在于,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。

11.一种图像分类模型的训练装置,其特征在于,所述装置包括:...

【专利技术属性】
技术研发人员:沈西杨再初李昱廷林容泰黄世华
申请(专利权)人:英特灵达信息技术深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1