System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种基于多质心任务描述子的动态类增量推理方法技术_技高网

一种基于多质心任务描述子的动态类增量推理方法技术

技术编号:40869771 阅读:2 留言:0更新日期:2024-04-08 16:36
本发明专利技术公开了一种基于多质心任务描述子的动态类增量推理方法,涉及计算机技术领域,深度学习,增量学习领域,具体来说是一种利用多质心任务描述子的门控网络,实现动态类增量推理的方法。本方法在现有多分支类增量学习模型基础上新增一个门控网络,该网络由两部分组成,一是用于提取特征的网络,二是属于每个任务的多个可学习质心。训练时首先将输入图片和质心做匹配得到图片的正负质心,不断拉近相关质心推远无关质心。预测时用该门控网络做动态类增量推理,根据门控网络输出的图片属于每个任务的概率对原有多分支网络的输出做动态融合后得到预测结果。

【技术实现步骤摘要】

本专利技术涉及计算机,深度学习,增量学习领域,具体来说是一种基于多质心任务描述子的门控网络,该网络能在推理阶段通过预测输入图片任务号,做动态类增量以提升现有增量学习正确率的方法。


技术介绍

1、类增量学习是指一种让模型能够持续学习新类别,同时不遗忘过去类别的训练手法。类增量学习的训练数据由多个有时间先后顺序的任务组成,每个任务由多个类别组成,模型在新任务上训练的时候,无法访问或只能访问一部分过去任务的数据,但是在测试时候,却要求能够在所有已经见过的类别上做分类。直接在新任务上训练会导致灾难性遗忘的问题,如何在学习新知识的同时,不遗忘过去任务的知识,是类增量学习领域的一大挑战。

2、类增量学习有以下几个常用解决方法:新旧模型蒸馏,保存旧任务中最具有代表性的样本,网络扩展。基于网络扩展的方法在每个新任务过来时候都会对原有网络的基础上新增一个网络,冻结原网络,仅训练新网络,此类方法实现简单,性能高效。

3、若在测试阶段,在给出输入图片的同时所属的任务号,可以仅在该任务号对应的多个类别中挑选预测结果,最后能得到正确率会大大提升。因此本专利技术考虑在原有网络的基础上额外增加一个任务预测网络负责给出图片所属任务的的概率,以此概率为权重对原增量模型的输出做动态融合后得到最终的预测结果,同时考虑预测任务的复杂性,创新地提出了基于多质心的任务描述子。


技术实现思路

1、本专利技术的目的在于预测类增量学习中的任务号,具体为在原有类增量学习算法的基础上新增一个门控网络,为每个任务分配多个可训练质心,然后提出一种基于匈牙利匹配的算法为每个图片寻找对应质心,该方法在类增量学习数据集上达到良好的识别效果。

2、实现本专利技术目的的具体技术方案是:

3、一种基于多质心任务描述子的动态类增量推理方法,特点是该方法包括如下步骤:

4、步骤一:初始化多分支网络、门控网络;

5、步骤二:当新任务来临的时候,将多分支网络做一个扩展,分配一个特征提取器,两个分类器,一个主分类器负责预测当前任务包含类别,另一个辅助分类器不仅负责预测当前任务类别,还会将所有过去类别预测为同一个类;门控网络的骨干网络保持不变,但对每个任务分配多个质心;

6、步骤三:将不同分支主分类器的输出拼接得到最后的预测结果,用交叉熵损失对多分支网络做训练;

7、步骤四:对于训练批次中包含的每个类,取其对应多张图片特征向量的均值作为原型点,用多质心匹配算法在每个任务内匹配原型和质心,得到原型和质心的对应关系,根据每个样本对应的类别,得到每个样本对应的最优质心;

8、步骤五:训练门控网络,减少样本和最优质心的距离,增大样本与无关质心的距离,同时又借用多分支网络的输出做蒸馏;

9、步骤六:当有新任务过来时,重复步骤二到五;

10、步骤七:在完成上述训练过程后,使用该网络实现动态类增量推理,使用门控网络预测输入图片所属任务概率,最终预测结果为多分支网络的输出乘上该任务概率的结果。

11、所述的多质心匹配算法,其步骤如下:首先将当前批次(batch)的数据,计算每个类对应图片的特征的平均值,作为该类原型点;然后做逐任务的质心匹配,对于一个任务,将当前批次中存在的,属于该任务的所有类原型点和该任务的多个质心计算相似度,做匈牙利匹配,得到每个原型点的匹配质心,样本可由样本的类别间接获得对应质心;

12、p=[p1,p2,…,p|y|]    (1)

13、

14、其中代表了当前批次内包含的类的原型向量,上标代表了对应任务号,下标代表了对应的类别号,y代表了当前批次的所有类目数量,代表了质心向量,上标代表了任务号,下标代表了质心对应序号,σ代表了求解公式(2)得到的总代价最小的质心顺序。

15、所述训练门控网络,其训练过程,由两个损失函数组成,其一对比学习损失,将样本和最优质心匹配,推远无关质心;其二是蒸馏损失,将输入图片与所有质心计算得到的相似度向量经过一个全连接层后做维度匹配,同多分支网络的输出做蒸馏;

16、

17、

18、其中f代表了样本经过门控网络后得到的特征,c+,cj代表了该样本匹配的正质心和第j个质心,d′i代表了相似度向量经过全连接层后输出的第i项结果,l代表了多分支网络输出的第i项结果,τ代表了用于纠正长尾分布的权重超参,π+,πj代表了该样本匹配的正质心的的纠偏系数和第j个质心的纠偏系数,λ代表了蒸馏损失使用的温度超参。

19、所述动态类增量推理,其过程如下:

20、步骤a1:首先将输入图片经过网络提取特征,计算与所有质心的相似度:

21、

22、步骤a2:在每个任务对应的多个质心中,选取其最大相似度作为该任务的相似度:

23、

24、步骤a3:将该相似度向量经过带温度的softmax后得到归一化的概率:

25、w′(t|x)=softmax(w/η)    (7)

26、步骤a4:通过将概率向量同多分支网络的输出相乘,得到最终的预测结果:

27、p=softmax([w′1φ1(x),…,w′tφt(x)])    (8)

28、其中代表了输入图片特征和质心的相似度,上标代表了属于哪个任务,下标代表了属于该任务的第几个质心,|kt|代表了第t个任务的质心数量,η代表了软化超参,φ1(x)代表了分支的输出结果,下标代表是第几个分支,x代表输入图片。

29、本专利技术同现有技术相比,其优点在于:本专利技术的多质心符合任务数据分布,有很好的任务预测的能力。本文提出的训练方法做法新颖,适合于该多质心网络。具体体现为:

30、1)本专利技术提出的多质心描述子符合任务数据的分布,相比现有的方法,能够更好适应任务数据方差大、在语义空间上形成多个簇的特点。

31、2)本专利技术为该多质心网络提出一种基于匈牙利匹配的寻找正负样本和质心的算法,能有效并成功训练。

32、3)本专利技术的多质心网络可以辅助判断不同任务的重要性,可以根据实际需求仅启用多分支网络的部分分支,在牺牲少量正确率的情况下加快推理速度。

33、4)本专利技术在cifar100和imagenet100/1000数据集上取得的正确率超现有模型。

本文档来自技高网...

【技术保护点】

1.一种基于多质心任务描述子的动态类增量推理方法,其特征在于,该方法包括如下步骤:

2.如权利要求1所述的动态类增量推理方法,其特征在于,所述的多质心匹配算法,其步骤如下:首先将当前批次(batch)的数据,计算每个类对应图片的特征的平均值,作为该类原型点;然后做逐任务的质心匹配,对于一个任务,将当前批次中存在的,属于该任务的所有类原型点和该任务的多个质心计算相似度,做匈牙利匹配,得到每个原型点的匹配质心,样本可由样本的类别间接获得对应质心;

3.如权利要求1所述的动态类增量推理方法,其特征在于,所述训练门控网络,其训练过程,由两个损失函数组成,其一对比学习损失,将样本和最优质心匹配,推远无关质心;其二是蒸馏损失,将输入图片与所有质心计算得到的相似度向量经过一个全连接层后做维度匹配,同多分支网络的输出做蒸馏;

4.如权利要求1所述的动态类增量推理方法,其特征在于,所述动态类增量推理,其过程如下:

【技术特征摘要】

1.一种基于多质心任务描述子的动态类增量推理方法,其特征在于,该方法包括如下步骤:

2.如权利要求1所述的动态类增量推理方法,其特征在于,所述的多质心匹配算法,其步骤如下:首先将当前批次(batch)的数据,计算每个类对应图片的特征的平均值,作为该类原型点;然后做逐任务的质心匹配,对于一个任务,将当前批次中存在的,属于该任务的所有类原型点和该任务的多个质心计算相似度,做匈牙利匹配,得到每个原型点的匹配质心,...

【专利技术属性】
技术研发人员:谢源蔡腾皓张志忠邹谷初齐振一
申请(专利权)人:华东师范大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1