图像分类模型的知识蒸馏方法、装置和计算机设备制造方法及图纸

技术编号:27061351 阅读:26 留言:0更新日期:2021-01-15 14:41
本申请涉及一种图像分类模型的知识蒸馏方法、装置和计算机设备,所述方法通过获取用于知识蒸馏的训练数据集,将训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与训练数据集对应的图像特征集合,对于训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征,并根据训练数据集中任意一个样本图像对应的第二图像特征以及训练数据集对应的图像特征集合确定蒸馏损失,根据蒸馏损失对第二分类模型进行反向传播,更新第二分类模型的模型参数,从而使得第二分类模型能够真正学习到样本图像之间的关系特征,并达到与第一分类模型类似的预测效果,以得到性能更佳的图像分类模型。

【技术实现步骤摘要】
图像分类模型的知识蒸馏方法、装置和计算机设备
本申请涉及机器学习
,特别是涉及一种图像分类模型的知识蒸馏方法、装置和计算机设备。
技术介绍
随着机器学习技术的发展,采用模型进行图像处理或识别已经越来越普遍。通常来说,更大更深更复杂的模型有着更好的拟合效果与更好的预测能力,但同时其计算效率低、耗时大、参数量大,从而不利于移动端、芯片端等应用层的部署;而简单模型虽然拟合能力弱,但其计算效率更高、参数量更少,从而更利于部署。而知识蒸馏(knowledgedistillation)作为一种重要的模型压缩手段,可以将复杂模型(teacher,也称教师模块或第一模型)中的知识(darkknowledge)迁移到简单模型(student,也称学生模型或第二模型)中,来使得student模型的拟合能力能够逼近甚至超过teacher模型,从而用更少的时间和空间复杂度来得到类似的预测效果。然而,不同的知识蒸馏方法对于需要迁移的darkknowledge定义不同。其中,在分类任务上,传统的基于输出结果的知识蒸馏,通常是通过直接拉近teacher模型与student模型的输出值(模型softmax(逻辑回归)之后输出的logits(一个事件发生与该事件不发生的比值的对数)值或者网络中间层输出的feature(特征))之间的距离,使得student模型得到近似teacher模型的预测结果。但目前基于图像分类任务的知识蒸馏方法通常只考虑到了单个batch(批)样本层面的信息提取,不具有全局性,从而无法得到性能更佳的图像分类student模型。
技术实现思路
基于此,有必要针对上述传统基于输出结果的知识蒸馏无法得到性能更佳的student模型的问题,提供一种能够得到更佳性能的图像分类模型的知识蒸馏方法、装置和计算机设备。一种图像分类模型的知识蒸馏方法,所述方法包括:获取用于知识蒸馏的训练数据集,所述训练数据集中包括多个样本图像;将所述训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与所述训练数据集对应的图像特征集合,所述图像特征集合中包括与每一个样本图像一一对应的第一图像特征;对于所述训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征;根据所述训练数据集中任意一个样本图像对应的第二图像特征以及所述训练数据集对应的图像特征集合确定蒸馏损失;根据所述蒸馏损失对所述第二分类模型进行反向传播,更新所述第二分类模型的模型参数。在其中一个实施例中,所述根据所述训练数据集中任意一个样本图像对应的第二图像特征以及所述训练数据集对应的图像特征集合确定蒸馏损失,包括:基于所述训练数据集中的任意一个样本图像,获取所述图像特征集合中与所述样本图像对应的正例特征和负例特征;根据所述样本图像对应的正例特征、负例特征以及第二图像特征,计算所述蒸馏损失。在其中一个实施例中,所述基于所述训练数据集中的任意一个样本图像,获取所述图像特征集合中与所述样本图像对应的正例特征和负例特征,包括:基于所述训练数据集中的任意一个样本图像,获取所述图像特征集合中与所述样本图像对应的第一图像特征,确定为与所述样本图像对应的正例特征;获取所述图像特征集合中除与所述样本图像对应的正例特征之外的其他第一图像特征,确定为与所述样本图像对应的负例特征。在其中一个实施例中,所述根据所述样本图像对应的正例特征、负例特征以及第二图像特征,计算所述蒸馏损失,包括:根据所述样本图像对应的正例特征、负例特征以及第二图像特征,采用信息噪声收敛估计损失函数计算所述蒸馏损失,所述信息噪声收敛估计损失函数其中,q为所述样本图像对应的第二图像特征,t+为所述样本图像对应的正例特征,ti-为所述样本图像对应的负例特征,K为所述负例特征的个数,τ为超参数。在其中一个实施例中,所述将所述训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与所述训练数据集对应的图像特征集合,包括:将所述训练数据集中的每一个样本图像输入第一分类模型,得到与每一个样本图像一一对应的第一原始图像特征;对于每一个第一原始图像特征分别进行范数归一化处理,得到对应处理后的第一图像特征;基于每一个第一原始图像特征分别对应的第一图像特征,得到对应的图像特征集合。在其中一个实施例中,所述对于所述训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征,包括:对于所述训练数据集中任意一个样本图像,将样本图像输入第二分类模型,得到与所述样本图像对应的第二原始图像特征;对所述第二原始图像特征进行仿射变换以及范数归一化处理,得到对应处理后的第二图像特征。一种图像分类处理方法,所述方法包括:获取待处理图像;通过图像分类模型对所述待处理图像进行分类处理,得到所述待处理图像的分类结果,所述图像分类模型为通过如上所述的图像分类模型的知识蒸馏方法得到的第二分类模型。一种图像分类模型的知识蒸馏装置,所述装置包括:训练数据集获取模块,用于获取进行知识蒸馏的训练数据集,所述训练数据集中包括多个样本图像;第一图像特征提取模块,用于将所述训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与所述训练数据集对应的图像特征集合,所述图像特征集合中包括与每一个样本图像一一对应的第一图像特征;第二图像特征提取模块,用于对于所述训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征;蒸馏损失获取模块,用于根据所述训练数据集中任意一个样本图像对应的第二图像特征以及所述训练数据集对应的图像特征集合确定蒸馏损失;反向传播模块,用于根据所述蒸馏损失对所述第二分类模型进行反向传播,更新所述第二分类模型的模型参数。一种图像分类处理装置,所述装置包括:待处理图像获取模块,用于获取待处理图像;分类处理模块,用于通过图像分类模型对所述待处理图像进行分类处理,得到所述待处理图像的分类结果,所述图像分类模型为通过如上所述的图像分类模型的知识蒸馏方法得到的第二分类模型。一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现如上所述方法的步骤。一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现如上所述方法的步骤。上述图像分类模型的知识蒸馏方法、装置和计算机设备,通过获取用于知识蒸馏的训练数据集,将训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与训练数据集对应的图像特征集合,对于训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征,并根据训练数据集中任意一个样本图像对应的第二图像特征以及训练数据集对应的图像特征集合确定蒸馏损失,根据蒸馏损失对第二分类模型进行反向传播,更新第二分类模型的模型参数,从而使得第二分类模型能够真正学习到样本图像之间的关系特征,并达到与第一本文档来自技高网
...

【技术保护点】
1.一种图像分类模型的知识蒸馏方法,其特征在于,所述方法包括:/n获取用于知识蒸馏的训练数据集,所述训练数据集中包括多个样本图像;/n将所述训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与所述训练数据集对应的图像特征集合,所述图像特征集合中包括与每一个样本图像一一对应的第一图像特征;/n对于所述训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征;/n根据所述训练数据集中任意一个样本图像对应的第二图像特征以及所述训练数据集对应的图像特征集合确定蒸馏损失;/n根据所述蒸馏损失对所述第二分类模型进行反向传播,更新所述第二分类模型的模型参数。/n

【技术特征摘要】
1.一种图像分类模型的知识蒸馏方法,其特征在于,所述方法包括:
获取用于知识蒸馏的训练数据集,所述训练数据集中包括多个样本图像;
将所述训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与所述训练数据集对应的图像特征集合,所述图像特征集合中包括与每一个样本图像一一对应的第一图像特征;
对于所述训练数据集中任意一个样本图像,通过第二分类模型进行图像特征提取,得到对应的第二图像特征;
根据所述训练数据集中任意一个样本图像对应的第二图像特征以及所述训练数据集对应的图像特征集合确定蒸馏损失;
根据所述蒸馏损失对所述第二分类模型进行反向传播,更新所述第二分类模型的模型参数。


2.根据权利要求1所述的方法,其特征在于,所述根据所述训练数据集中任意一个样本图像对应的第二图像特征以及所述训练数据集对应的图像特征集合确定蒸馏损失,包括:
基于所述训练数据集中的任意一个样本图像,获取所述图像特征集合中与所述样本图像对应的正例特征和负例特征;
根据所述样本图像对应的正例特征、负例特征以及第二图像特征,计算所述蒸馏损失。


3.根据权利要求2所述的方法,其特征在于,所述基于所述训练数据集中的任意一个样本图像,获取所述图像特征集合中与所述样本图像对应的正例特征和负例特征,包括:
基于所述训练数据集中的任意一个样本图像,获取所述图像特征集合中与所述样本图像对应的第一图像特征,确定为与所述样本图像对应的正例特征;
获取所述图像特征集合中除与所述样本图像对应的正例特征之外的其他第一图像特征,确定为与所述样本图像对应的负例特征。


4.根据权利要求3所述的方法,其特征在于,所述根据所述样本图像对应的正例特征、负例特征以及第二图像特征,计算所述蒸馏损失,包括:
根据所述样本图像对应的正例特征、负例特征以及第二图像特征,采用信息噪声收敛估计损失函数计算所述蒸馏损失,所述信息噪声收敛估计损失函数其中,q为所述样本图像对应的第二图像特征,t+为所述样本图像对应的正例特征,ti-为所述样本图像对应的第i个负例特征,K为所述负例特征的个数,τ为超参数。


5.根据权利要求1至4任一项所述的方法,其特征在于,所述将所述训练数据集中的每一个样本图像输入第一分类模型进行图像特征提取,得到与所述训练数据集对应的图像特征集合,包括:
将所述训练数据集中的每一个样本图像输入第一分类模型,得到与每一个样本图像一一对应的第一原始图像特征;
对于每一个第一原始图像特征分别进...

【专利技术属性】
技术研发人员:刘琦
申请(专利权)人:上海眼控科技股份有限公司
类型:发明
国别省市:上海;31

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1