一种基于知识蒸馏的类增量学习方法技术

技术编号:35187103 阅读:14 留言:0更新日期:2022-10-12 17:58
本发明专利技术公开了一种基于知识蒸馏的类增量学习方法,包括以下步骤:在类增量学习的设置下完成对数据集的分割,将数据集分为常见的长任务序列,即将每个数据集分成n个任务;部署教师网络,移除分类层,在每次添加新任务时为其添加任务相应分类数量的分类头,训练教师网络;部署学生网络并训练;在增量阶段,加载训练好的教师模型,用知识蒸馏的方法让学生网络的输出拟合教师网络的输出,从而实现教师网络的知识转移,完成对学生网络的训练;移除加载的教师模型,释放内存,测试学生网络的预测效果。本发明专利技术不需要保存样本或者扩张模型,大大减少了内存开销,实现网络模型压缩,减少计算量以及节省计算资源的目的。及节省计算资源的目的。及节省计算资源的目的。

【技术实现步骤摘要】
一种基于知识蒸馏的类增量学习方法


[0001]本专利技术涉及计算机视觉领域,涉及一种基于知识蒸馏的类增量学习方法。

技术介绍

[0002]在深度学习这一背景设定下,为了达到更好的预测效果,通常会采取以下措施来训练网络:一.使用过参数化的深度神经网络,即扩张网络模型,尽可能地加深网络层数。二.集成模型,即集成多个弱的模型,从而增强预测效果。然而这两种方法的缺点都很明显,即需要的计算量和计算资源都很大。在这一缺点的驱动下,相关研究转向模型压缩这一领域,即用一个规模较小层数较浅的网络模型,能够达到和大模型一样或者相当的预测效果。显然从头训练小模型很难实现预想中的效果,因此技术手段为训练大模型并将其知识转移给小模型,这也就是知识蒸馏的具体做法。
[0003]增量学习是指一个学习系统能够不断地从新样本中学习新的知识,并能保存大部分以前已经学习到的知识。通常用稳定性——可塑性曲线衡量一个网络模型的增量学习能力。若要求一个网络模型同时具备学习新任务以及保留旧任务性能的能力,规模越大的网络模型通常表现越好,然而在增量设置下其计算量和计算资源需求呈线性增长,因此如何减少计算量节约计算资源一直是该领域内的一大研究热点。本专利技术以知识蒸馏为基础,采用mobilenet_v2作为学生网络,resnet34作为教师网络,在参数量减少十倍的前提下能够使得学生网络获得和教师网络相当的预测效果。并且本专利技术不需要保存样本以及扩张模型,大大减少了内存开销以及避免隐私泄露等问题。
[0004]目前增量学习的设置主要分成两种:任务增量学习以及类增量学习。两种设置的训练过程并无两样,主要区别在于测试阶段任务增量学习可以获取当前图像所属的任务id,而类增量学习无法获取该信息。显然类增量学习是难度更大的学习任务,因此本专利技术主要针对类增量学习的预测效果进行提高,同时在测试时会计算任务增量设置的预测结果,其实验结果表明对于任务增量设置同样有较好的提升效果。在学生网络的训练框架上选择最简单的微调,最小化计算量。

技术实现思路

[0005]专利技术目的:本专利技术的目的在于提供一种基于知识蒸馏的类增量学习算法,实现减少计算量,节省计算资源和内存开销以及避免隐私泄露的目的。
[0006]技术方案:一种基于知识蒸馏的类增量学习算法,包括利用resnet34在cifar100,tiny

imagenet等数据集上完成类增量学习设置下的教师网络模型训练,然后用知识蒸馏的方法将教师网络的知识转移给学生网络mobilenet_v2。通过该方法实现网络模型压缩,从而达到减少计算量节省计算资源的目的。
[0007]一种基于知识蒸馏的类增量学习方法,包括以下步骤:
[0008](1)对数据集进行分割,将数据集分为长任务序列;
[0009](2)部署并训练教师网络模型,在每次训练完当前任务之后保存相应的网络模型
参数;
[0010](3)部署学生网络,通过微调方法训练学生网络模型,加载步骤(2)中训练完成的教师网络模型,用知识蒸馏的方法让学生网络模型的输出拟合教师网络模型的输出,实现教师网络模型的知识转移,完成对学生网络的训练;
[0011](4)移除步骤(3)中加载的教师网络模型,释放内存,测试学生网络的预测效果。
[0012]所述步骤(1)中对数据集进行分割的方法为,将每个数据集分成n个任务,任务间的数据各不相交,且在增量阶段,旧任务的数据无法访问,具体为:
[0013](1.1)将当前数据集的训练集、测试集分别记为X,Y,将该数据集分割为n个任务,即
[0014]X={x
i
|i=1,2,

,n},Y={y
i
|i=1,2,

,n};
[0015](1.2)在训练过程中,在第i个训练阶段,仅用训练集x
i
训练当前网络,其中1≤i≤n;
[0016](1.3)在测试过程中,在第i个测试阶段,对1,2,...,i

1阶段中的测试集进行测试,即对y1∪y2∪

∪y
i
进行测试,其中1≤i≤n;其中即各个任务间的类不相交。
[0017]所述步骤(2)中训练教师网络模型的方法为:
[0018]首先移除分类层;在每次添加新任务时,根据当前任务需要的分类数量为教师网络模型添加相应数量的分类头;在每次训练完当前任务后保存相应的教师网络模型参数。
[0019]所述步骤(3)中对学生网络模型进行训练的微调方法为:
[0020]对第一个任务,选用SGD优化器,计算交叉熵损失并进行优化;
[0021]在增量阶段,设当前阶段为i,其中1<i≤n;当前任务的交叉熵损失为:
[0022][0023]其中y
i
是真实标签,为预测值,n为类别数;
[0024]设旧任务的分类头分别为head1,head2,

,head
i
‑1,其中1<i≤n;则当前图像在以往旧任务的分类头上的logits值输出为:令学生网络模型输出的logits值拟合教师网络模型输出的logits值,使两个输出值趋同,用知识蒸馏损失函数衡量趋同程度,对该损失函数进行优化,当前任务的知识蒸馏损失函数为:
[0025][0026]其中是当前任务图像在学生网络模型的旧任务分类头上输出的logits值,是当前任务图像在教师网络模型的相应任务分类头上输出的logits值;
[0027]当前任务的总损失记为:
[0028]loss=ce_loss+λ*kd_loss
[0029]其中λ为用于平衡损失比例的超参数。
[0030]所述步骤(3)中通过知识蒸馏方法完成教师网络模型的知识转移的方法为:
[0031]加载步骤(2)中训练完成的教师网络模型,定义网络模型后,加载步骤(2)中教师
网络模型的对应参数;
[0032]学生网络模型的训练过程中,在增量阶段i(1<i≤n),当前任务的训练集为x
i
,将x
i
送入学生网络模型后在旧任务的分类头head1,head2,

,head
i
‑1上输出的logits值为将所有logits值拼接为一个tensor,记为
[0033]将x
i
送入训练完成的教师网络模型中,x
i
在对应任务的分类头上输出的logits值为:将所有logits值拼接为一个tensor,记为对logits值做如下处理增大网络输出中较小logits值的权重:
[0034][0035]其中:
[0036][0037]n为标签数量,log为对数运算,T为温度系数;
[0038]联合交叉熵损失对学生网络进行优化,完成学生网络模型训练,当前任务的总损失为:
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的类增量学习方法,其特征在于,包括以下步骤:(1)对数据集进行分割,将数据集分为长任务序列;(2)部署并训练教师网络模型,在每次训练完当前任务之后保存相应的网络模型参数;(3)部署学生网络,通过微调方法训练学生网络模型,加载步骤(2)中训练完成的教师网络模型,用知识蒸馏的方法让学生网络模型的输出拟合教师网络模型的输出,实现教师网络模型的知识转移,完成对学生网络的训练;(4)移除步骤(3)中加载的教师网络模型,释放内存,测试学生网络的预测效果。2.根据权利要求1所述的一种基于知识蒸馏的类增量学习方法,其特征在于,所述步骤(1)中对数据集进行分割的方法为,将每个数据集分成n个任务,任务间的数据各不相交,且在增量阶段,旧任务的数据无法访问,具体为:(1.1)将当前数据集的训练集、测试集分别记为X,Y,将该数据集分割为n个任务,即X={x
i
|i=1,2,

,n},Y={y
i
|i=1,2,

,n};(1.2)在训练过程中,在第i个训练阶段,仅用训练集x
i
训练当前网络,其中1≤i≤n;(1.3)在测试过程中,在第i个测试阶段,对1,2,...,i

1阶段中的测试集进行测试,即对y1∪y2∪

∪y
i
进行测试,其中1≤i≤n;其中即各个任务间的类不相交。3.根据权利要求1所述的一种基于知识蒸馏的类增量学习方法,其特征在于,所述步骤(2)中训练教师网络模型的方法为:首先移除分类层;在每次添加新任务时,根据当前任务需要的分类数量为教师网络模型添加相应数量的分类头;在每次训练完当前任务后保存相应的教师网络模型参数。4.根据权利要求1所述的一种基于知识蒸馏的类增量学习方法,其特征在于,所述步骤(3)中对学生网络模型进行训练的微调方法为:对第一个任务,选用SGD优化器,计算交叉熵损失并进行优化;在增量阶段,设当前...

【专利技术属性】
技术研发人员:黄树成陶哲朱霞
申请(专利权)人:江苏科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1