一种基于元模块融合增量学习的图像分类方法技术

技术编号:35370098 阅读:16 留言:0更新日期:2022-10-29 18:11
本发明专利技术提供了一种基于元模块融合增量学习的图像分类方法,获取待分类图片,形成待分类图片集,依次将分类图片集输入至增量分类神经网络VGG网络或ResNet,训练元模型;在元模型的分类器后添加修正器,利用保留数据训练修正器,训练门控选择层,实现对元模型的融合;确定对应的具体图像类别。本发明专利技术能够使网络在长期增量阶段时,延缓精度下降,能有效的改善动态扩展重表达方法中存在的弊端,显著改进了识别精度,解决了现有的增量学习中图像分类精度下降过快的算法问题。本发明专利技术能在保持精度较高的情况下,实现在内存规模,网络增长规模和计算速度等多个上面的优势。速度等多个上面的优势。速度等多个上面的优势。

【技术实现步骤摘要】
一种基于元模块融合增量学习的图像分类方法


[0001]本专利技术涉及图像识别
,通过有限网络模型扩张与分类器重训练,实现少量样本数据驱动的增量学习的图像识别。

技术介绍

[0002]近年来,神经网络模型已经在很多机器学习领域取得了巨大成功,如图像识别、目标检测、自然语言处理、姿态估计等。但目前神经网络模型依然有很多不足,灾难性遗忘即是一个亟待解决的重要问题。增量学习能够解决网络学习中灾难性遗忘问题
[0003]不遗忘学习方法(LwF)是首次把知识蒸馏的思路应用到增量学习中,仅利用现有新样本就可以在学习新任务的同时,对旧任务保持记忆。但由于完全没有使用旧的类别样本,随着类的不断增加,整体准确率也急剧下降。增量分类器和特征重表达学习(iCaRL)是最经典的基于样本回放的增量学习模型,它在算法层面借鉴保留了前例中的蒸馏技术。同时采用特征提取器和分类器分离方法,并在固定内存规模的情况下,通过筛选出具有代表性的旧样本和新样本组成新的训练集,实现增量学习,因此较前者在准确率上有所提升,代价便是增大了内存容量。基于空间蒸馏损失的方法(PODNet),改进了特征的蒸馏方法,并将分类器与代理向量相结合,改进了分类器的损失形式,取得了不错的效果。小样本增量学习方法(FSCIL)沿用了特征提取器和分类器分开的思路。用拓扑关系来模拟特征空间上的关系,将特征提取后的特征空间上的位置做为神经气体网络的输入,以此输入分类器来分类。该方法在解决小样本增量学习问题上取得巨大成功。
[0004]最新的动态扩展重表达方法(DER)通过模型结构扩展的方式为每一个新任务训练一个特征提取器。在每个增量任务时候对特征进行扩展,都将上一个阶段提取出的特征进行固定,并且运用新的特征提取器再对特征进行提取。这就使得模型在保持旧任务知识的同时可以获得适用于新增量任务的新知识。但由于该方法在训练网络时把每次的增量类别都整合到同一个网络中去,这就造成了当网络长期处于增量阶段时,类别间分类精度的相互干扰,网络增量越多,准确率下降越快。

技术实现思路

[0005]为了克服现有技术的不足,本专利技术提供一种基于元模块融合增量学习的图像分类方法,能够使网络在长期增量阶段时,延缓精度下降,能有效的改善动态扩展重表达方法中存在的弊端。大量实验表明,本专利技术显著改进了识别精度。基于VGG网络和ResNet,在CUB、CIFAR

100和MiniImageNet,
[0006]本专利技术解决其技术问题所采用的技术方案包括以下步骤:
[0007]步骤一、获取待分类图片,形成待分类图片集,设定每次增量学习的新添类别数量为K,增量学习的增长步数为T,最大增长步数为L
max

[0008]其中,数据集D代表总的图像数据集,N代表图像类别,D
n
代表第n类图像的数据集,样本总数为S,(X
s
,Y
s
)代表样本输入以及对应
标签,K代表每次增量学习的新添类别数量,T代表增量学习的增长步数;
[0009]步骤二、依次将步骤一中分类图片集输入至增量分类神经网络VGG网络或ResNet,训练元模型;
[0010]步骤三、在元模型的分类器后添加修正器,利用保留数据训练修正器。
[0011]在每次增量学习之后扩展修正器的输出,并重训练修正器;
[0012]步骤3.1、训练修正器;对步骤二中所有训练过的数据集进行抽样得到保留数据集D
r
,在网络的分类层后添加一层全连接层FC做为修正器C,利用保留数据训练该FC层参数,训练方式采用交叉熵损失函数训练w个epoch,学习率从λ开始;
[0013]步骤3.2、如果训练步数t小于最大增长步数L
max
,则返回步骤二进行增量学习的元模块训练,即步骤二中的步骤2.2,如果训练步数大于等于最大增长步数L
max
,则完成元模型M
i
的训练;
[0014]步骤四、训练门控选择层,实现对元模型的融合;
[0015]步骤4.1、重复步骤二,直到训练完所有数据得到多个元模块每个元模块M
i
中包含L
max
次增量学习,总的增量学习次数是T,得到的元模块数量是
[0016]步骤4.2、维持已训练网络模型特征提取层参数不变,在θ
u
对应的特征提取层后添加门控分类层G;
[0017]步骤4.3、在总的数据集中抽取部分样本组成新的保留数据集D
r
训练新添的门控分类层G,训练损失函数为p
i
=η

m
i
,其中η表示输出向量,m
i
表示所有输出向量的平均,n表示训练元门控分类层G时输入样本数量;
[0018]步骤4.4、在测试阶段,输入图像依据门控分类层G的输出结果,选择对应的元模块,经过元模块的分类层,确定对应的具体图像类别。
[0019]所述步骤二中,增量分类神经网络训练元模型的具体步骤如下:
[0020]步骤2.1、训练初始网络;选择步骤一的数据输入初始神经网络VGG或ResNet中,神经网络采用随机初始化,采用交叉熵损失函数训练w个epoch,学习率从λ开始,得到神经网络特征提取层的参数θ
F
=[θ
u

s
]和分类层参数θ
C

[0021]步骤2.2、训练增量网络;保持初始网络特征提取层的后1/2层结构不变,θ
s
是神经网络特征提取层的后1/2层的参数,选择新的增量类别扩展初始网络结构的前部分,即θ
u
对应的特征提取层,利用新增数据训练新扩展层的参数,训练方式采用交叉熵损失函数训练w个epoch,学习率从λ开始。
[0022]所述epoch的w取值为小于等于100。
[0023]所述学习率λ取值为0.01。
[0024]本专利技术的有益效果在于通过提供一种基于元模块融合增量学习的图像分类方法,解决了现有的增量学习中图像分类精度下降过快的算法问题。通过将多个元模型融合的方式实现增量学习能有效的减少参数增长速度,延缓灾难性遗忘问题,保持分类精度在可靠范围内。相比与现有的方法能够在内存规模,网络模型规模,分类精度上达到一个较合适的平衡点。
[0025]相比于精度相当的算法例如,动态扩展重表达方法,能在内存规模和网络模型上形成优势,相比与内存规模相当的算法例如,增量分类器和特征重表达学习,能在精度和计算速断上形成优势。总的来说,本专利技术能在保持精度较高的情况下,实现在内存规模,网络增长规模和计算速度等多个上面的优势。
附图说明
[0026]图1为本专利技术总体的算法实现步骤图。
[0027]图2为元模型网络融合训练过程示意图。
具体实施方式
[0028]下面结合附图和实施例对本专利技术进一步说明。...

【技术保护点】

【技术特征摘要】
1.一种基于元模块融合增量学习的图像分类方法,其特征在于包括下述步骤:步骤一、获取待分类图片,形成待分类图片集,设定每次增量学习的新添类别数量为K,增量学习的增长步数为T,最大增长步数为L
max
;其中,数据集D代表总的图像数据集,N代表图像类别,D
n
代表第n类图像的数据集,样本总数为S,(X
s
,Y
s
)代表样本输入以及对应标签,K代表每次增量学习的新添类别数量,T代表增量学习的增长步数;步骤二、依次将步骤一中分类图片集输入至增量分类神经网络VGG网络或ResNet,训练元模型;步骤三、在元模型的分类器后添加修正器,利用保留数据训练修正器。在每次增量学习之后扩展修正器的输出,并重训练修正器;步骤3.1、训练修正器;对步骤二中所有训练过的数据集进行抽样得到保留数据集D
r
,在网络的分类层后添加一层全连接层FC做为修正器C,利用保留数据训练该FC层参数,训练方式采用交叉熵损失函数训练w个epoch,学习率从λ开始;步骤3.2、如果训练步数t小于最大增长步数L
max
,则返回步骤二进行增量学习的元模块训练,即步骤二中的步骤2.2,如果训练步数大于等于最大增长步数L
max
,则完成元模型M
i
的训练;步骤四、训练门控选择层,实现对元模型的融合;步骤4.1、重复步骤二,直到训练完所有数据得到多个元模块每个元模块M
i
中包含L
max
次增量学习,总的增量学习次数是T,得到的元模块数量是步骤4.2、维持已训练网络模型特征提取层参数不变,在θ
u
对应的...

【专利技术属性】
技术研发人员:王庆杨晨姚一杨周果清王雪
申请(专利权)人:西北工业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1