基于MoCo模型的多模块知识蒸馏方法技术

技术编号:34268073 阅读:20 留言:0更新日期:2022-07-24 15:23
本发明专利技术公开了一种基于MoCo模型的多模块知识蒸馏方法,利用中间过程中生成的特征间具有相似度这一特点,将教师和学生网络各自分成对应的多个模块,通过MoCo模型提取到教师和学生网络的每个模块生成的特征计算相似度,利用相似度达到教师网络指导学生网络的目的。本发明专利技术可以在只有少量标签的基础上,自动地对样本特征进行动态更新,此方法的内存效率更高,解决了在有限内存的情况下训练大规模数据集的问题,使教师网络指导下的学生网络有鲁棒性的同时,兼具泛化性。兼具泛化性。兼具泛化性。

【技术实现步骤摘要】
基于MoCo模型的多模块知识蒸馏方法


[0001]本专利技术属于模型轻量化技术,尤其涉及一种基于MoCo模型的多模块知识蒸馏方法。

技术介绍

[0002]近年来,机器学习和深度学习在计算机视觉、自然语言处理、预测和音频处理等方面都有了卓越的进步,对于这些复杂的任务,训练后模型的规模很大,这使得在资源受限的设备上部署它很困难。在知识蒸馏中,在大数据集上训练的较大的繁琐网络(教师模型)可以很好地将学习到的知识转移到作为一个学生模型的更小更轻的网络中。
[0003]在基于瘦长网络的提示的研究中,引入了一种两阶段的策略来训练深度网络,但是没有明显的速度提升;深度相互学习提出了教师

学生网络相互学习,并且同时更新,但是难以提取学习更细节的信息,带来的误差更大;再生网络中,提出了利用学习到的学生网络指导下一级的学生网络,但是训练时间长且冗余过程较多。

技术实现思路

[0004]本专利技术的目的在于提供一种基于MoCo模型的多模块知识蒸馏方法,解决了在有限内存的情况下训练大规模数据集的问题,达到了减少运算量提高内存效率的效果。
[0005]实现本专利技术目的的技术解决方案为:一种基于MoCo模型的多模块知识蒸馏方法,包括以下步骤:
[0006]步骤S1、在Imagenet中随机采集K幅带标签的图像,1000<K<10000,对上述K幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2K幅带标签的图像,构成教师网络训练集。
[0007]步骤S2、将教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络。
[0008]步骤S3、在Instagram中随机采集N幅无标签的图像,10000<N<100000,对上述N幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2N幅无标签的图像,构成教师

学生网络训练集。
[0009]步骤S4、构建MoCo模型:
[0010]所述MoCo模型包括预训练教师网络、学生网络、编码器和动态编码器,将预训练教师网络划分成m个模块,并将学生网络也对应划分成m个模块,2<m<100。
[0011]步骤S5、将教师

学生网络训练集输入MoCo模型,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度。用学生网络中第n+1个模块生成的相似度学习预训练教师网络第n+1个模块生成的相似度和第n模块生成的相似度,以此更新学生网络的网络参数,1≤n≤m。同时,预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络。
[0012]步骤S6、在Instagram中随机采集M幅带标签的图像,100<M<1000,对上述M幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2M幅图像,构成学生网络测试集。
[0013]步骤S7、将学生网络测试集输入MoCo模型中训练好的学生网络,输出学生网络测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。
[0014]本专利技术与现有技术相比,其显著优点在于:
[0015](1)首次将Moco模型学习到的相似度用于知识蒸馏方法中,可以在只有少量标签的基础上,自动地对样本特征进行动态更新,使内存效率更高,并且没有匹配提取到的特征的步骤,减少中间数据转换的误差,使教师网络指导下的学生网络有鲁棒性的同时,兼具泛化性。
[0016](2)利用Moco模型自身的特性,预训练教师网络和学生网络都能通过相似度对网络参数进行自我更新,学生网络不仅可以学习各模块的工作方式还可以回顾复习未被学习到的特征,通过增加更新策略的方式,提高了学生网络的准确度。
[0017](3)在Moco模型中加入了池化层,为前期的训练提供可靠的数据,加速数据收敛,并且利用移动平均值的策略,使网络更好的更新,既保留了原数据又平稳添加新的梯度。
附图说明
[0018]图1为基于MoCo模型的多模块知识蒸馏方法模型图。
具体实施方式
[0019]下面结合附图对本专利技术作进一步详细描述。
[0020]结合图1,本专利技术所述的一种基于MoCo模型的多模块知识蒸馏方法,步骤如下:
[0021]步骤S1、在Imagenet中随机采集K幅带标签的图像,1000<K<10000,对上述K幅带标签的图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w(h取值范围为0~256,w取值范围为0~256)的2K幅带标签的图像,构成带标签的教师网络训练集,转入步骤S2。
[0022]步骤S2、将带标签的教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络,转入步骤S3。
[0023]步骤S3、在Instagram中随机采集N幅无标签的图像,10000<N<100000,对上述N幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2N幅无标签的图像,构成无标签的教师

学生网络训练集,转入步骤S4。
[0024]步骤S4、构建MoCo模型:
[0025]所述MoCo模型包括预训练教师网络、学生网络、编码器和动态编码器,将预训练教师网络划分成m个模块,并将学生网络也对应划分成m个模块,2<m<100。
[0026]所述预训练教师网络和学生网络均无分支,包括但不局限于经典网络结构中的ResNet、VGGNet、Mobilenet等。预训练教师网络规模数据均大于学生网络,转入步骤S5。
[0027]步骤S5、将无标签的教师

学生网络训练集输入MoCo模型,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度;用学生网络中第n+1个模块生成的相似度学习预训练教师网络中的第n+1个模块生成的相似度和第n模块(预训练教师网络中)生成的相似度,以此更新学生网络的网络参数,1≤n≤m;同
时,预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络,具体如下:
[0028]编码器和动态编码器采用相同结构,编码器承担了生成查询特征的任务;动态编码器基于无监督学习的对比损失构建具有一致性的字典,字典是以队列的形式表现出来的:
[0029]当前的特征经过动态编码器编码后得到的匹配样本特征进入队列,最先进入的一组匹配样本特征被清理出队列。
[0030]当前有编码器生成的一个查询样本特征q和动态编码器生成的一组序列{k0,k1本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于MoCo模型的多模块知识蒸馏方法,其特征在于,步骤如下:步骤S1、在Imagenet中随机采集K幅带标签的图像,1000<K<10000,对上述K幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2K幅带标签的图像,构成教师网络训练集,转入步骤S2;步骤S2、将教师网络训练集输入教师网络,利用教师网络训练集对教师网络进行预训练,得到预训练教师网络,转入步骤S3;步骤S3、在Instagram中随机采集N幅无标签的图像,10000<N<100000,对上述N幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2N幅无标签的图像,构成教师

学生网络训练集,转入步骤S4;步骤S4、构建MoCo模型:所述MoCo模型包括预训练教师网络、学生网络、编码器和动态编码器,将预训练教师网络划分成m个模块,并将学生网络也对应划分成m个模块,2<m<100;转入步骤S5;步骤S5、将教师

学生网络训练集输入MoCo模型,提取预训练教师网络和学生网络中各模块生成的特征,并将上述特征分别输入编码器和动态编码器进行编码,对应得到查询样本特征和匹配样本特征,求出查询样本特征和匹配样本特征的相似度;用学生网络中第n+1个模块生成的相似度学习预训练教师网络第n+1个模块生成的相似度和第n模块生成的相似度,以此更新学生网络的网络参数,1≤n≤m;同时,预训练教师网络和学生网络都根据自身各模块生成的相似度各自对网络参数进行更新,最终获得训练好的学生网络,转入步骤S6;步骤S6、在Instagram中随机采集M幅带标签的图像,100<M<1000,对上述M幅图像逐张统一尺寸后进行数据增强,得到像素大小为h
×
w的2M幅图像,构成学生网络测试集,转入步骤S7;步骤S7、将学生网络测试集输入MoCo模型中训练好的学生网络,输出学生网络测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。2.根据权利要求1所述的基于MoCo模型的多模块知识蒸馏方法,其特征在于,步骤S5中,在MoCo模型中,提取预训练教师网络和学生网络中各模块生成的特征并输入编码器和动态编码器,其中,编码器和动态编码器采用相同结构,编码器承担了生成查询特征的任务;动态编码器基于无监督学习的对比损失构建具有一致性的字典,字典是以队列的形式表现出来的:当前特征经过动态编码器编码后得到的匹配样本特征进入队列,最先进入的一组匹配样本特征被清理出队列;当前有编码器生成的一个查询样本特征q和动态编码器生成的一组序列{k0,k1,k2,

},序列作为字典中的键,序列中存在一个与q匹配的键k
+
;利用点积度量相似性,提出对比损失函数L
q
:其中,τ是一个温度超参数,k
i
为字典中的键;字典中的键包括一个正样本k
+
和K个负样
本,1<K<100;当q与键k
+
相似,而与所有其他键不同时,L
q
的值趋近于0;查询样本特征q由编码器f
q
和池化层产生,即q=f
q
(x
q
)+pool
q
(x
q
),x
q
表示任意一个查询样本;键k
i
由动态编码器f
k
和池化层产生,即k
i
=f
k
(x
ki
)+pool
ki
(x
ki
),x
ki
是字典中的键;此外,提出了一种缓慢进行的动态编码器更新方式,其动态是基于编码器的移动平均值来实现的,并以此与编码器保持一致性,将f
k
的参数表示为θ
k
...

【专利技术属性】
技术研发人员:王军袁静波刘新旺李玉莲李兵
申请(专利权)人:中国矿业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1