System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind()
【技术实现步骤摘要】
本说明书涉及持续学习、深度学习领域,尤其涉及一种基于持续学习的图像分类方法及装置。
技术介绍
1、传统的机器学习虽然已经在分类任务上取得了比较好的效果,但却难以适应环境的动态改变。在某些应用场景中,图像分类任务可能会受到环境的持续变化的影响。比如,视频监控领域中的场景可能会随时间改变,或者是产品领域中的新产品不断推出等。持续学习使得机器能够及时适应这些变化,并持续提供准确的分类结果。通过持续学习,机器能够在已有的图像分类知识基础上不断积累新的知识。
2、在现有技术中,可以通过动态网络扩展的方法进行持续学习,即,可以通过动态地扩展网络结构,来适应新的任务或数据。在训练过程中,每新增一个任务,可以在网络中扩展用于分类新增的任务对应的类别的结构,并对网络进行一轮有监督训练,但是,通过这种方式训练出的网络进行图像分类时不够准确,网络在不断学习新任务的同时,可能会逐渐忘记旧任务。
3、因此,如何提高在图像分类领域中持续学习的准确性,则是一个亟待解决的问题。
技术实现思路
1、本说明书提供一种基于持续学习的图像分类方法及装置,以部分的解决现有技术存在的上述问题。
2、本说明书采用下述技术方案:
3、本说明书提供了一种基于持续学习的图像分类方法,包括:
4、设置缓冲区,所述缓冲区用于存储历史任务中的部分样本;
5、确定上一轮训练后得到的持续学习分类模型,并确定当前任务,获取当前任务对应的样本,所述样本中包含图片以及所述图片在
6、根据所述缓冲区中的样本以及所述当前任务对应的样本,对所述持续学习分类模型进行当前轮的训练,得到当前轮训练完成的持续学习分类模型,其中,一轮训练包括动态扩展阶段和模型压缩阶段,所述动态扩展阶段中:将上一轮训练后得到的持续学习分类模型作为教师模型,将持续学习分类模型中新增用于特征编码的多层自注意力模块以及所述当前任务对应的分类模块后,作为学生模型,通过蒸馏学习的方式对学生模型进行训练;所述模型压缩阶段中:将所述动态扩展阶段中的学生模型作为教师模型,将所述动态扩展阶段中的学生模型去掉所述多层自注意力模块,作为所述模型压缩阶段中的学生模型,通过蒸馏学习的方式对学生模型进行训练,得到当前轮训练完成的持续学习分类模型;
7、根据所述当前任务对应的样本,更新所述缓冲区,以进行下一轮训练,其中,每一轮训练对应一个新增的任务,每个任务对应若干分类类别,各任务之间的分类类别不存在交叉;
8、通过最终训练得到的持续学习分类模型,进行图像分类。
9、可选地,持续学习分类模型包括:图片编码模块和输出解码模块。
10、可选地,所述图片编码模块由补丁分词器和多层自注意力模块组成,所述补丁分词器的输入是图片,输出是补丁嵌入序列,所述多层自注意力模块的输入是补丁嵌入序列和可学习的位置嵌入序列相加,输出是经过编码的特征嵌入序列。
11、可选地,所述输出解码模块包括任务token生成模块、交叉注意力模块和分类模块,所述任务token生成模块用于生成每个任务对应的可学习的任务token,针对一个任务,该任务的任务token用于在所述交叉注意力模块中通过注意力机制得到该任务对应的分类嵌入,所述交叉注意力模块的输入为每个任务对应的任务token和所述图片编码模块的输出,所述交叉注意力模块的输出为每个任务对应的分类嵌入,所述分类模块用于根据每个任务对应的分类嵌入输出每个任务对应的分类概率。
12、可选地,所述动态扩展阶段中除了新增当前任务对应的分类模块,在所述任务token生成模块中还新增所述当前任务对应的任务token。
13、可选地,根据所述缓冲区中的样本以及所述当前任务对应的样本,对所述持续学习分类模型进行当前轮的训练之前,所述方法还包括:
14、针对每个样本,对该样本包含的图像进行标准化处理,以及对图片进行数据增强,所述数据增强包括:以一定概率进行随机裁剪、高斯模糊、水平翻转以及以一定概率随机调整图像的亮度、对比度、饱和度和色调。
15、可选地,在动态扩展阶段中,通过蒸馏学习的方式对学生模型进行训练,具体包括:
16、在所述动态扩展阶段中,冻结教师模型以及冻结学生模型中除本轮训练新增部分的其余模块,使第一损失函数最小化,对学生模型进行训练,其中,所述第一损失函数组成包括:学生模型的分类损失、针对所述多层自注意力模块的辅助分类损失、学生模型在历史任务上对应教师模型的分类蒸馏损失、学生模型辨别新旧任务的分类损失、在学生模型针对历史任务输出的class embedding进行正则化约束的蒸馏损失;
17、在模型压缩阶段中,通过蒸馏学习的方式对学生模型进行训练,具体包括:
18、在所述动态扩展阶段中,冻结教师模型以及冻结学生模型中的输出解码模块,使第二损失函数最小化,对学生模型进行训练,其中,第二损失函数包括:学生模型辨别新旧任务的分类损失、学生模型在所有任务上对应教师模型的分类蒸馏损失、学生模型在图片编码模块的输出对应教师模型图片编码模块输出的蒸馏损失。
19、本说明书提供了一种基于持续学习的图像分类装置,包括:
20、缓冲区模块,用于设置缓冲区,所述缓冲区用于存储历史任务中的部分样本;
21、确定模块,用于确定上一轮训练后得到的持续学习分类模型,并确定当前任务,获取当前任务对应的样本,所述样本中包含图片以及所述图片在当前任务下新增的分类类别;
22、训练模块,用于根据所述缓冲区中的样本以及所述当前任务对应的样本,对所述持续学习分类模型进行当前轮的训练,得到当前轮训练完成的持续学习分类模型,其中,一轮训练包括动态扩展阶段和模型压缩阶段,所述动态扩展阶段中:将上一轮训练后得到的持续学习分类模型作为教师模型,将持续学习分类模型中新增多层自注意力模块以及所述当前任务对应的分类模块后,作为学生模型,通过蒸馏学习的方式对学生模型进行训练;所述模型压缩阶段中:将所述动态扩展阶段中的学生模型作为教师模型,将所述动态扩展阶段中的学生模型去掉所述多层自注意力模块,作为所述模型压缩阶段中的学生模型,通过蒸馏学习的方式对学生模型进行训练,得到当前轮训练完成的持续学习分类模型;
23、更新模块,用于根据所述当前任务对应的样本,更新所述缓冲区,以进行下一轮训练,其中,每一轮训练对应一个新增的任务,每个任务对应若干分类类别,各任务之间的分类类别不存在交叉;
24、分类模块,用于通过最终训练得到的持续学习分类模型,进行图像分类。
25、本说明书提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述基于持续学习的图像分类方法。
26、本说明书提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述基于持续学习的图像分类方法。<本文档来自技高网...
【技术保护点】
1.一种基于持续学习的图像分类方法,其特征在于,包括:
2.如权利要求1所述的方法,其特征在于,持续学习分类模型包括:图片编码模块和输出解码模块。
3.如权利要求2所述的方法,其特征在于,所述图片编码模块由补丁分词器和多层自注意力模块组成,所述补丁分词器的输入是图片,输出是补丁嵌入序列,所述多层自注意力模块的输入是补丁嵌入序列和可学习的位置嵌入序列相加,输出是经过编码的特征嵌入序列。
4.如权利要求2所述的方法,其特征在于,所述输出解码模块包括任务token生成模块、交叉注意力模块和分类模块,所述任务token生成模块用于生成每个任务对应的可学习的任务token,针对一个任务,该任务的任务token用于在所述交叉注意力模块中通过注意力机制得到该任务对应的分类嵌入,所述交叉注意力模块的输入为每个任务对应的任务token和所述图片编码模块的输出,所述交叉注意力模块的输出为每个任务对应的分类嵌入,所述分类模块用于根据每个任务对应的分类嵌入输出每个任务对应的分类概率。
5.如权利要求4所述的方法,其特征在于,所述动态扩展阶段中除了新增当前任
6.如权利要求1所述的方法,其特征在于,根据所述缓冲区中的样本以及所述当前任务对应的样本,对所述持续学习分类模型进行当前轮的训练之前,所述方法还包括:
7.如权利要求1或2所述的方法,其特征在于,在动态扩展阶段中,通过蒸馏学习的方式对学生模型进行训练,具体包括:
8.一种基于持续学习的图像分类装置,其特征在于,包括:
9.一种计算机可读存储介质,其特征在于,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述权利要求1~7任一项所述的方法。
10.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现上述权利要求1~7任一项所述的方法。
...【技术特征摘要】
1.一种基于持续学习的图像分类方法,其特征在于,包括:
2.如权利要求1所述的方法,其特征在于,持续学习分类模型包括:图片编码模块和输出解码模块。
3.如权利要求2所述的方法,其特征在于,所述图片编码模块由补丁分词器和多层自注意力模块组成,所述补丁分词器的输入是图片,输出是补丁嵌入序列,所述多层自注意力模块的输入是补丁嵌入序列和可学习的位置嵌入序列相加,输出是经过编码的特征嵌入序列。
4.如权利要求2所述的方法,其特征在于,所述输出解码模块包括任务token生成模块、交叉注意力模块和分类模块,所述任务token生成模块用于生成每个任务对应的可学习的任务token,针对一个任务,该任务的任务token用于在所述交叉注意力模块中通过注意力机制得到该任务对应的分类嵌入,所述交叉注意力模块的输入为每个任务对应的任务token和所述图片编码模块的输出,所述交叉注意力模块的输出为每个任务对应的分类嵌入,所述分类模块用于根据每个任务对应的分类嵌入输出...
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。