一种基于知识蒸馏的高效图像分类方法及系统技术方案

技术编号:32735782 阅读:18 留言:0更新日期:2022-03-20 08:42
本发明专利技术实施例公开了一种基于知识蒸馏的高效图像分类方法及系统,利用大模型学习到的知识去指导小模型进行训练,使得小模型在保证参数量和计算量可以大幅降低的情况下,获得与大模型相当的性能,从而实现模型的压缩与加速,大大提高图像分类的效率,改进后的图像分类器可以在保证原始图像分类器性能的同时,大大减少网络模型在进行图像分类时产生的不必要的计算量和存储消耗,使得小型的图像分类器可以更加高效的对图像进行分类,同时产生较小的内存消耗,从而有效节省计算和存储资源,使CNN可以更好地应用到图像分类算法中。CNN可以更好地应用到图像分类算法中。CNN可以更好地应用到图像分类算法中。

【技术实现步骤摘要】
一种基于知识蒸馏的高效图像分类方法及系统


[0001]本专利技术实施例涉及图像处理、计算机视觉
,具体涉及一种基于知识蒸馏的高效图像分类方法及系统。

技术介绍

[0002]近年来,深度学习通过分层式结构的多层信息处理来进行非监督的特征学习和图像分类,模拟人脑学习和分析的能力,形成一个神经网络结构,从而像人脑一样对外界输入事物进行分析和理解,相对于通过浅层学习获得图像底层特征的传统图像分类方法,深度学习利用设定好的网络结构,完全从训练数据中学习图像的层级结构性特征,能够提取更加接近图像高级语义的抽象特征,因此在图像分类上的表现远远超过传统方法。
[0003]深度卷积神经网络(Convolutional Neural Network,CNN)在特征表示上具有极大的优越性,模型提取的特征随着网络深度的增加越来越抽象,越来越越能表现图像主题语义。所以,CNN通过对图像交替进行卷积核池化操作,逐渐提取图像的高层特征,再使用神经网络对特征分类,以此来实现对图像进行分类的功能,在图像分类领域表现出了极大的优势。然而,随着对图像分类性能需求的增加,基于深度卷积神经网络的图像分类算法成为高性能图像分类体系结构设计的基础。但CNN通常会产生大量的计算和存储消耗,特别是利用大量数据学习得到的大规模网络在进行图像分类时效率较低,同时会产生较大的内存消耗,占用过多的计算和存储资源,大大阻碍了CNN在图像分类算法中的应用。

技术实现思路

[0004]为此,本专利技术实施例提供一种基于知识蒸馏的高效图像分类方法及系统,以解决现有的利用大量数据学习得到的大规模网络在进行图像分类时效率较低,同时会产生较大的内存消耗,占用过多的计算和存储资源的问题。
[0005]为了实现上述目的,本专利技术实施例提供如下技术方案:
[0006]根据本专利技术实施例的第一方面,提出了一种基于知识蒸馏的高效图像分类方法,所述方法包括:
[0007]构建训练集和测试集,并对所述训练集和测试集中的图像进行类别标注;
[0008]对所述训练集合测试集中的图像进行预处理;
[0009]使用预处理后的训练集对多个学生网络模型同时进行训练,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,输出准确率较高的学生网络模型;
[0010]使用预处理后的测试集对输出的学生网络模型进行图像分类测试。
[0011]进一步地,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,具体包括:
[0012]同时定义多个小型学生网络模型,让所有学生网络模型都共享相同的网络体系结构,并通过随机初始化的方式对网络权重进行初始化;
[0013]将训练集中的每个小批量数据通过D={x
i
,y
i
}
i=1~m
表示,m表示批量数,每个样本都属于C个类别之一,y
i
是一个C维的向量,表示真实的数据标注值;
[0014]然后利用交叉熵损失函数来表示第k个学生网络模型的输出与真实标签之间的误差,以此来提升各个学生网络模型的性能,其中表示第k个学生网络模型softmax层之前的logit输出;
[0015]增加KL散度损失来表示第k个学生网络模型的软化分布输出与教师网络模型软化输出之间的差异,以此来达到教师网络模型指导学生网络模型的目的,KL散度损失函数表达式为:
[0016][0017]其中,第k个学生网络模型的软化输出与教师网络模型的软化输出分别表示为与其中超参数τ表示温度系数,用于在所有输出类别上产生较软的概率分布;
[0018]对所述交叉熵损失函数和KL散度损失函数进行加权求和,获得知识蒸馏中训练学生网络模型的损失函数其中α是用来控制两项损失之间平衡的超参数。
[0019]进一步地,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,具体还包括:
[0020]在知识蒸馏目标损失函数中增加两个新的损失项,即相似性损失L
sm
,它使不同的学生彼此互动;与多样性损失L
ds
,它使学生组学习的知识多样化;
[0021]将学生网络模型i到学生网络模型k的知识定义为他们之间的KL散度并引入差异掩码来表示学生网络模型i到学生网络模型k之间的差异,其中M(i,k)定义为两个学生网络模型的logit输出之间的欧式距离;
[0022]对于特定的学生网络模型k,将相似性损失L
sm
定义为:其中,K表示学生网络模型的数量,I(x)是一个指示符函数,即当x状态为true时函数值为1,否则为0;
[0023]多样性损失L
ds
定义为:其中D(z
k,j
,z
avg,j
)表示多样性度量,被定义为:也就是计算每个学生网络模型的logit输出与他们的平均值之间的L2范数,C表示logit的维度,z
avg
表示所有学生网络模型的logit输出的平均,表达式为:
[0024]在整个由多学生网络模型组成的动态知识蒸馏系统中,将学生网络模型k的总体损失函数定义为四个损失的加权组合:其中γ表示控制损失影响的比例因子。
[0025]进一步地,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,具体还包括:
[0026]引入自动化的控制机制来调整比例因子γ,该机制表达为:
[0027][0028]其中θ是用来控制γ变化的超参数,Δacc是每个epoch结束时训练集上所有学生网络模型准确率变化的平均值,表达式为:表示学生网络模型k在当前epoch训练结束时的准确率,对应的是上一个epoch训练结束时的准确率。
[0029]进一步地,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,具体还包括:
[0030]将所有的学生网络模型通过一个整体损失函数L
total
同时进行训练,其表达式为:
[0031]进一步地,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,输出准确率较高的学生网络模型,具体还包括:
[0032]当所有学生网络模型的准确率不再上升时,停止训练,输出准确率最高的一个学生网络模型。
[0033]根据本专利技术实施例的第二方面,提出了一种基于知识蒸馏的高效图像分类系统,所述系统包括:
[0034]训练集和测试集构建模块,用于构建训练集和测试集,并对所述训练集和测试集中的图像进行类别标注;
[0035]对所述训练集合测试集中的图像进行预处理;
[0036]模型训练模块,用于使用预处理后的训练集对多个学生网本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的高效图像分类方法,其特征在于,所述方法包括:构建训练集和测试集,并对所述训练集和测试集中的图像进行类别标注;对所述训练集合测试集中的图像进行预处理;使用预处理后的训练集对多个学生网络模型同时进行训练,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,输出准确率较高的学生网络模型;使用预处理后的测试集对输出的学生网络模型进行图像分类测试。2.根据权利要求1所述的一种基于知识蒸馏的高效图像分类方法,其特征在于,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,具体包括:同时定义多个小型学生网络模型,让所有学生网络模型都共享相同的网络体系结构,并通过随机初始化的方式对网络权重进行初始化;将训练集中的每个小批量数据通过D={x
i
,y
i
}
i=1~m
表示,m表示批量数,每个样本都属于C个类别之一,y
i
是一个C维的向量,表示真实的数据标注值;然后利用交叉熵损失函数来表示第k个学生网络模型的输出与真实标签之间的误差,以此来提升各个学生网络模型的性能,其中表示第k个学生网络模型softmax层之前的logit输出;增加KL散度损失来表示第k个学生网络模型的软化分布输出与教师网络模型软化输出之间的差异,以此来达到教师网络模型指导学生网络模型的目的,KL散度损失函数表达式为:其中,第k个学生网络模型的软化输出与教师网络模型的软化输出分别表示为与其中超参数τ表示温度系数,用于在所有输出类别上产生较软的概率分布;对所述交叉熵损失函数和KL散度损失函数进行加权求和,获得知识蒸馏中训练学生网络模型的损失函数其中α是用来控制两项损失之间平衡的超参数。3.根据权利要求2所述的一种基于知识蒸馏的高效图像分类方法,其特征在于,训练过程中利用预先训练好的教师网络模型对学生网络模型的训练进行指导,并在目标损失函数中增加相似性损失和多样性损失,具体还包括:在知识蒸馏目标损失函数中增加两个新的损失项,即相似性损失L
sm
,它使不同的学生彼此互动;与多样性损失L
ds
,它使学生组学习的知识多样化;将学生网络模型i到学生网络模型k的知识定义为他们之间的KL散度并引入差异掩码来表示学生网络模型i到学生网络模型k之间的差异,其中M(i,k)定义为两个学生网络模型的logit输出之间
的欧式距离;对于特定的学生网络模型k,将相似性损失L
sm
定义为:其中,K表示学生网络模型的数量,I(x)是一个指示符函数,即当x状态为true时函数值为1,否则...

【专利技术属性】
技术研发人员:李翔
申请(专利权)人:北京影谱科技股份有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1