模型训练方法、装置、计算机设备和存储介质制造方法及图纸

技术编号:33120013 阅读:20 留言:0更新日期:2022-04-17 00:17
本申请涉及一种模型训练方法、装置、计算机设备、存储介质和计算机程序产品。所述方法包括:获取已训练的教师模型对样本图像中各像素的类别预测结果;根据已训练的教师模型对样本图像中各像素的类别预测结果,得到样本图像中各像素的信息量;在各预设图像类别下,基于样本图像中各像素的信息量、已训练的教师模型对样本图像中各像素的类别预测结果和待训练的学生模型对样本图像中各像素的类别预测结果,得到待训练的学生模型的目标损失函数;根据目标损失函数,对待训练的学生模型进行迭代训练,得到训练完成的学生模型;训练完成的学生模型用于对输入的图像进行语义分割。采用本方法能够提升学生模型整体的预测准确性。方法能够提升学生模型整体的预测准确性。方法能够提升学生模型整体的预测准确性。

【技术实现步骤摘要】
模型训练方法、装置、计算机设备和存储介质


[0001]本申请涉及计算机
,特别是涉及一种模型训练方法、装置、计算机设备、存储介质和计算机程序产品。

技术介绍

[0002]知识蒸馏技术是在模型训练过程中,使用一个规模较大的模型作为老师模型进行训练,提取出图像样本中的特征信息,然后将特征信息传递给规模较小的学生模型,使得规模较小的学生模型不仅速度较快,还能借助特征信息提升模型性能。
[0003]然而,传统的知识蒸馏技术是直接将蒸馏损失函数应用在所有图像样本上,并没有考虑图像样本之间的差异性,差异性包括图像样本的类别数量和图形样本包含的信息量,使得在模型训练过程中模型会更倾向于信息量较少的多数类样本,而忽视信息量较大的少数类样本,造成学生模型在信息量较大的少数类样本上的预测准确性较低。

技术实现思路

[0004]基于此,有必要针对上述技术问题,提供一种能够提升学生模型预测准确率的模型训练方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
[0005]第一方面,本申请提供了一种模型训练方法。所述方法包括:获取已训练的教师模型对样本图像中各像素的类别预测结果;根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
[0006]在其中一个实施例中,各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数包括:在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量的权重;根据各预设图像类别下所述样本图像中各像素的信息量和所述样本图像中各像素的信息量的权重,得到所述待训练的学生模型的目标损失函数。
[0007]在其中一个实施例中,在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别
预测结果,确定所述样本图像中各像素的信息量的权重,包括:根据所述样本图像中各像素的类别预测结果,从所述各预设图像类别中确定出所述样本图像中各像素所属的图像类别;根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息散度;所述信息散度表示所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果之间的距离;在所述各预设图像类别下,根据所述样本图像中各像素的信息散度,依次确定所述样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量的权重,得到所述样本图像中各像素的信息量的权重。
[0008]在其中一个实施例中,根据所述样本图像中各像素的信息量和所述样本图像中各像素的信息量的权重,得到所述待训练的学生模型的目标损失函数,包括:在所述各预设图像类别下,分别根据样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量和所述与所述预设图像类别相同的像素的信息量的权重,确定所述样本图像在所述各预设图像类别下的总信息量;根据所述样本图像在所述各预设图像类别下的总信息量之和的平均值,得到所述待训练的学生模型的目标损失函数。
[0009]在其中一个实施例中,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量,包括:根据所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素在所述各预设图像类别下的类别信息量;根据所述各像素在所述各预设图像类别下的类别信息量,得到所述各像素的信息量。
[0010]在其中一个实施例中,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量,还包括:分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果;分别根据所述各像素的目标预测结果,得到所述样本图像中各像素的信息量。
[0011]在其中一个实施例中,分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果,包括:分别从所述各像素的类别预测结果中,筛选出类别预测概率最大的类别预测结果,作为所述各像素的目标预测结果。
[0012]在其中一个实施例中,分别从所述各像素的类别预测结果中,筛选出满足第一预设条件的类别预测结果,作为所述各像素的目标预测结果,包括:分别从所述各像素的类别预测结果中,筛选出类别预测概率最大和类别预测概率第二大的类别预测结果,作为所述各像素的目标预测结果。
[0013]在其中一个实施例中,根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型,包括:获取所述待训练的学生模型的初始损失函数;
根据所述初始损失函数和所述目标损失函数,得到总损失函数;根据所述总损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型。
[0014]在其中一个实施例中,根据所述初始损失函数和所述目标损失函数,得到总损失函数,包括:将所述初始损失函数与所述目标损失函数进行相加,得到所述总损失函数。
[0015]第二方面,本申请还提供了一种模型训练装置。所述装置包括:像素预测模块,用于获取已训练的教师模型对样本图像中各像素的类别预测结果;信息提取模块,用于根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;函数获取模块,用于在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;模型获取模块,用于根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。
[0016]第三方面,本申请还提供了一种计算机设备。所述计算机设备包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:获取已训练的教师模型对样本图像中各像素的类别预测结果;根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:获取已训练的教师模型对样本图像中各像素的类别预测结果;根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息量;在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数;根据所述目标损失函数,对所述待训练的学生模型进行迭代训练,得到训练完成的学生模型;所述训练完成的学生模型用于对输入的图像进行语义分割。2.根据权利要求1所述的方法,其特征在于,所述在各预设图像类别下,基于所述样本图像中各像素的信息量、所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述待训练的学生模型的目标损失函数,包括:在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量的权重;根据各预设图像类别下所述样本图像中各像素的信息量和所述样本图像中各像素的信息量的权重,得到所述待训练的学生模型的目标损失函数。3.根据权利要求2所述的方法,其特征在于,所述在各预设图像类别下,根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,确定所述样本图像中各像素的信息量的权重,包括:根据所述样本图像中各像素的类别预测结果,从所述各预设图像类别中确定出所述样本图像中各像素所属的图像类别;根据所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果,得到所述样本图像中各像素的信息散度;所述信息散度表示所述已训练的教师模型对所述样本图像中各像素的类别预测结果和待训练的学生模型对所述样本图像中各像素的类别预测结果之间的距离;在所述各预设图像类别下,根据所述样本图像中各像素的信息散度,依次确定所述样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量的权重,得到所述样本图像中各像素的信息量的权重。4.根据权利要求3所述的方法,其特征在于,所述根据所述样本图像中各像素的信息量和所述样本图像中各像素的信息量的权重,得到所述待训练的学生模型的目标损失函数,包括:在所述各预设图像类别下,分别根据样本图像中所属的图像类别与所述预设图像类别相同的像素的信息量和所述与所述预设图像类别相同的像素的信息量的权重,确定所述样本图像在所述各预设图像类别下的总信息量;根据所述样本图像在所述各预设图像类别下的总信息量之和的平均值,得到所述待训练的学生模型的目标损失函数。5.根据权利要求1所述的方法,其特征在于,所述根据所述已训练的教师模型对所述样
本图像中各像素的类别预测结果...

【专利技术属性】
技术研发人员:田倬韬易振彧刘枢吕江波沈小勇
申请(专利权)人:苏州思谋智能科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1