【技术实现步骤摘要】
一种基于多教师知识蒸馏的跨域小样本识别方法
[0001]本专利技术属于计算机
,具体涉及跨域小样本识别方法。
技术介绍
[0002]小样本学习旨在将知识从源数据集转移到只有一个或几个标记示例的新目标数据集。通常,小样本学习假设源数据集和目标数据集的图像属于同一数据域。然而,这样一个理想的假设在现实世界的多媒体应用中可能并不容易满足。例如,如工作[1]中所揭示的,在主要由大量多样的自然图像组成的数据集上训练的模型仍然无法识别新颖的细粒度鸟类。为此,跨域小样本识别旨在解决小样本识别中源域与目标域领域不一致的问题。
[0003]近年来,跨域小样本已在许多先前的方法中得到广泛研究[2、3、4、5、6]。他们中的大多数[3,5,6]仅使用源域图像进行训练,并主要致力于提高模型的泛化能力。尽管已经取得了一些成就,但由于源数据集和目标数据集之间存在巨大的领域差距,在性能上仍很难取得重大突破。因此,一些工作[2,4]放宽了该最基本但最严格的设置,允许在训练阶段使用目标数据。其中,STARTUP[4]中使用了大量的未标记目标数据
【技术保护点】
【技术特征摘要】
1.一种基于多教师知识蒸馏的跨域小样本识别方法,其特征在于,具体步骤如下:(1)搭建三个模块:特征提取网络、小样本分类器、动态领域拆分模型;(1.2)采用ResNet
‑
10,作为特征提取网络模型;使用该特征提取网络模型,给定源域或者目标域数据,抽取得到对应的源域特征F
S
、目标域特征F
T
;(1.2)采用GNN,作为小样本分类器;使用该小样本分类器,给定任意一个元学习任务{S,Q},得到Q的概率分布P;(1.3)动态领域拆分模型,主要功能在于动态地将网络的特定层拆解为源域相关的部分和目标域相关的部分;具体地,就是定义一个领域门矩阵M,该矩阵M的维度与需要拆解的卷积核的数量一致;相对应地,M中第i个元素的值M
i
表示的就是第i个卷积核分配给源域的概率,而1
‑
M
i
则表示这个卷积核被分配给目标域的概率;进一步引入Gumbel softmax,实现将浮点M二值化,当M
i
输出为1时表示源域通路激活,而目标域通路关闭;反之,当M
i
输出为0时表示源域通路关闭,而目标域通路激活;使用该动态领域拆分模型,给定网络某层的输出特征源域输出F
S
、目标域输出F
T
,通过矩阵M的数值确定最终的源域输出为F
S
·
M,目标域输出为F
T
·
(1
‑
M);设置矩阵M为可学习参数,随着网络训练共同更新;(2)基于三个模块形成三个网络模型:源域教师模型、目标域教师模型、领域可拆分的学生模型;(2.1)连接特征提取网络以及小样本分类器,构成源域教师模型;(2.2)连接特征提取网络以及小样本分类器,构成目标域教师模型;(2.3)连接特征提取网络以及小样本分类器,对特征提取网络中的特定几层插入动态领域拆分模型,构成领域可拆分的学生模型;(3)只使用源域数据训练源域教师模型,训练方法为:从源域数据集中随机采样一个元学习单元作为网络输入,依次通过特征提取网络、小样本分类器,得到模型对于此次查询集中各个图片类别的结果预测概率分布,然后通过跟正确类别之间的距离得到训练损失函数;(4)只使用目标域域数据训练目标域教师模型,训练方法为:从目标域数据集中随机采样一个元学习单元作为网络输入,依次通过特征提取网络、小样本分类器,得到模型对于此次查询集中各个图片类别的结果预测概率分布,然后通过跟正确类别之间的距离得到训练损失函数;(5)使用同时来自源域和目标域的数据训练领域可拆分学生模型,训练方法为:从源域和目标域数据集中各自采样一个元学习单元,将两个学习单元都通过标准通路和领域通路这两个通路分别得到概率预测结果;其中:所述标准通路是依次通过特征提取网络、小样本分类器,不对特征提取网络进行领域拆分,即不管是来自哪个领域的数据,所有的卷积核都会被激活;所述领域通路是依次通过特征提取网络、动态领域拆分模型、小样本分类器,对特征提取网络进行拆分,即只有对应领域门M
i
的输出值为1时,该卷积核才对当前领域的数据激活;然后执行两个子任务:(5.1)小样本元学习任务:通过跟正确类别之间的距离得到训练损失函数;
(5.2)知识蒸馏任务:将学生模型的预测概率分布与对应教师模型概率分布进行比较,得到训练损失;(6)在目标域未知类别测试数据上对领域可拆分的学生模型以小样本元分类任务进行性能测试,将数据通过标准通路和普通通路得到两个不同的概率预测结果,将两个预测概率取平均作为最终的预测结果,该概率分布中得分最高的类别即为此次的预测类别;重复该步骤若干次,得到最终的模型准确率。2.根据权利要求1所述基于多教师知识蒸馏的跨域小样本识别方法,其特征在于,步骤(5)中所述小样本元分类学习任务,首先从数据集中随机采样一个元学习单元{支持集,查询集}作为网络输入,依次通过特征提取网络、动态领域拆解网络、小样本分类器,得到模型对于此次查询集中各个图片类别的结果预测概率分布;该预测概率分布将用于与正确的查询集类别进行比较,得到训练损失;在本发明中,交叉熵损失用于计算小样本元分类损失。3.根据权利要求1所述基于多教师知识蒸馏的跨域小样本识别方法,其特征在于,步骤(5)中所述知识蒸馏任务,给定待输入到学生模型中特定领域的数据,例如源数据,将学生模型的预测概率分布与源教师模型对...
【专利技术属性】
技术研发人员:姜育刚,傅宇倩,谢宇,付彦伟,陈静静,
申请(专利权)人:复旦大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。