一种基于全局-局部知识蒸馏的跨域小样本图像分类方法技术

技术编号:37292213 阅读:29 留言:0更新日期:2023-04-21 03:23
本发明专利技术提供了一种基于全局

【技术实现步骤摘要】
一种基于全局

局部知识蒸馏的跨域小样本图像分类方法


[0001]本专利技术属图像处理
,具体涉及一种基于全局

局部知识蒸馏的跨域小样本图像分类方法。

技术介绍

[0002]图像处理是机器视觉走向工业应用的关键技术,而图像分类是图像处理技术的基础。在医学、遥感等多种场景下,图像数据往往难以获取,呈现典型的小样本特性。为了缓解小样本问题,一种有效的方式是利用源域数据学习可迁移的知识,并将学习到的知识泛化到目标域的小样本任务中。然而,由于源域与目标域之间存在域差异,导致源域上训练的模型难以有效地泛化到目标域中。为此,研究适用于跨域场景下的小样本图像分类技术具有重要的应用价值。文献“Snell J,Swersky K,Zemel R.Prototypical networks for fewlearning[C]//Advances in Neural Information Processing Systems.2017:4077

.”提出一种基于原型的小样本图像分类方法。它首先使用深度神经网络提取图像的特征,然后在特征空间中利用每个小样本任务中的少量标记样本构建类别的原型表示,最后根据测试样本到这些类别原型之间的距离进行类别隶属关系的分配。然而,由于深度神经网络的简单性偏好,导致该方法构建的原型往往只能捕获最具判别性的模式,例如,颜色、形状等,忽略了具有跨域泛化能力的语义信息。因此,该方法在跨域小样本图像分类任务中性能表现不佳。

技术实现思路
/>[0003]为了克服现有技术的不足,本专利技术提供一种基于全局

局部知识蒸馏的跨域小样本图像分类方法。构建了由全局分支和局部分支构成的分类模型,其中,全局分支以原始图像为输入,用于提取图像的全局特征,局部分支以原始图像的局部块为输入,用于提取该图像的局部特征;在两分支之间,通过构建全局

局部知识蒸馏损失促进全局特征关注到图像的局部区域,使得全局特征捕获丰富的语义信息,进而提升全局特征在跨域小样本任务上的泛化性能。
[0004]一种基于全局

局部知识蒸馏的跨域小样本图像分类方法,其特征在于步骤如下:
[0005]步骤1:基于现有的图像数据集构建小样本任务训练数据集,包括支持集和查询集其中,支持集包括N个类别,每个类别带有K个监督样本,查询集也包括这N个类别,每个类别带有M个未标记的样本;
[0006]步骤2:构建模型的全局分支,其处理过程如下:
[0007]首先,按照下式获得支持集的原型表示:
[0008][0009]其中,表示支持集中第n个类别的第k个样本,表示全局分支中的特征提
取网络,本专利技术中采用ResNet

10网络,C
n
表示第n个类别的原型表示,n=1,2,

,N;
[0010]然后,基于原型表示对查询集中的每个样本进行类别隶属关系预测:
[0011][0012]其中,表示查询集中的第i个查询样本,i=1,2,

,N*M,表示该样本的预测得分,matching(
·
)为两个向量之间的相似度度量函数,本专利技术中使用欧氏距离进行相似度度量;
[0013]接着,根据预测得分中的最大相似度对应的类别作为该查询样本的预测标签并根据查询样本的预测标签和真实标签计算交叉熵损失如下:
[0014][0015]其中,H(
·
)表示交叉熵损失函数,表示查询样本对应的真实标签,表示查询样本的预测标签和真实标签之间的交叉熵损失;
[0016]步骤3:构建模型的局部分支,其处理过程如下:
[0017]对于查询样本首先使用随机裁剪获取其相应的局部图像块其中r∈[1,R],表示每个查询图像对应的局部图像块的个数,表示查询样本的第r个局部图像块;
[0018]然后,使用局部分支中的特征提取网络提取得到各个局部图像块对应的局部特征其中,局部分支中的特征提取网络采用ResNet

10网络;
[0019]接着,使用步骤2计算的原型对局部特征进行类别隶属关系预测,得到各个局部图像块对应的预测得分
[0020][0021]其中,表示查询样本的第r个局部图像块的相似度得分,表示查询样本的第r个局部图像块的局部特征;
[0022]步骤4:按照下式计算模型的总损失
[0023][0024]其中,I表示小样本任务中查询样本的总个数,表示查询样本的全局

局部知识蒸馏损失,表示跨图像的局部

全局蒸馏损失,λ1表示全局

局部知识蒸馏损失项的系数,设置λ1为1,λ2表示跨图像的局部

全局蒸馏损失项的系数,设置λ2为0.15;
[0025]所述的查询样本的全局

局部知识蒸馏损失按下式计算得到:
[0026][0027]所述的跨图像的局部

全局蒸馏损失按下式计算得到:
[0028][0029]其中,表示查询集中的第j个查询样本的第r个局部图像块的预测得分,j≠i表示j为与第i个查询样本属同一类别的不同样本,j=1,2,

,N*M;
[0030]步骤5:根据步骤4计算的模型总损失,使用随机梯度下降法,端到端的训练全局分支的网络参数,并按下式进行局部分支的网络参数的更新:
[0031]θ
T


T
+(1

m)θ
S
(8)
[0032]其中,θ
T
表示局部分支中的网络参数,m表示指数移动平均更新中的动量系数,设置m为0.998,θ
S
表示全局分支中的网络参数,

表示更新操作;
[0033]步骤6:将待处理图像数据集输入到步骤5训练后得到的全局分支,预测得到其中每幅图像的隶属类别,完成图像分类。
[0034]本专利技术的有益效果是:通过训练阶段构建的全局

局部知识蒸馏框架促进全局特征关注到图像的局部信息,从而使模型能够学习到泛化性强的语义表征,提升在跨域小样本任务上的泛化性能;采用端到端的框架设计方式,一旦模型在源域(训练数据集)上训练完成之后,即可在任意目标域(待处理图像数据集)的小样本任务上进行测试,无需微调特征提取模型;本专利技术能够在跨域小样本图像分类中获得较好的分类效果。
具体实施方式
[0035]下面结合实施例对本专利技术进一步说明,本专利技术包括但不仅限于下述实施例。
[0036]本专利技术提供了一种基于全局

局部知识蒸馏的跨域小样本图像分类方法,其具体实现过程如下:
[0037]1、构建小样本任务训练数据集
[0038]跨域小样本图像分类任务要求模型在源域中进行本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种基于全局

局部知识蒸馏的跨域小样本图像分类方法,其特征在于步骤如下:步骤1:基于现有的图像数据集构建小样本任务训练数据集,包括支持集和查询集其中,支持集包括N个类别,每个类别带有K个监督样本,查询集也包括这N个类别,每个类别带有M个未标记的样本;步骤2:构建模型的全局分支,其处理过程如下:首先,按照下式获得支持集的原型表示:其中,表示支持集中第n个类别的第k个样本,表示全局分支中的特征提取网络,本发明中采用ResNet

10网络,C
n
表示第n个类别的原型表示,n=1,2,

,N;然后,基于原型表示对查询集中的每个样本进行类别隶属关系预测:其中,表示查询集中的第i个查询样本,i=1,2,

,N*M,表示该样本的预测得分,matching(
·
)为两个向量之间的相似度度量函数,本发明中使用欧氏距离进行相似度度量;接着,根据预测得分中的最大相似度对应的类别作为该查询样本的预测标签并根据查询样本的预测标签和真实标签计算交叉熵损失如下:其中,H(
·
)表示交叉熵损失函数,表示查询样本对应的真实标签,表示查询样本的预测标签和真实标签之间的交叉熵损失;步骤3:构建模型的局部分支,其处理过程如下:对于查询样本首先使用随机裁剪获取其相应的局部图像块其中r∈[1,R],R表示每个查询图像对应的局部图像块的个数,表示查询样本的第r个局部图像块;然后,使用局部分支中的特征提取网络提取得到各个局部图像块对应的局部特征其中,局部分支中的特征提取网络采用ResNet

10网络;接着,使用步骤2计算的原型对局部特征进行...

【专利技术属性】
技术研发人员:张磊魏巍周飞张艳宁
申请(专利权)人:西北工业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1