面向增量学习的图像分类方法以及相关设备技术

技术编号:38501790 阅读:14 留言:0更新日期:2023-08-15 17:09
本申请涉及人工智能领域以及数字医疗领域,提供了一种面向增量学习的图像分类方法以及相关设备,该方法通过确定第一本地图像样本集的分类任务类型,并根据目标分类任务类型获取对应的多个适配器以及预测输出模块,之后将多个适配器、预测输出模块和预训练模型进行结合以得到初始图像分类模型,之后基于初始图像分类模型进行联邦学习得到中间图像分类模型,并将中间图像分类模型作为教师模型对初始图像分类模型进行蒸馏处理,得到目标图像分类模型,最后利用目标图像分类模型进行分类预测,既能在联邦学习框架下充分学习新类数据,又能够缓解在学习新类数据时造成的灾难性遗忘,提高联邦学习架构下图像分类的增量学习能力。高联邦学习架构下图像分类的增量学习能力。高联邦学习架构下图像分类的增量学习能力。

【技术实现步骤摘要】
面向增量学习的图像分类方法以及相关设备


[0001]本申请涉及人工智能
和数字医疗领域,尤其涉及一种面向增量学习的图像分类方法以及相关设备。

技术介绍

[0002]随着人工智能的发展,越来越多深度学习技术被应用在各行各业中。对于计算机视觉领域中的图像分类技术,其在实际应用中对数据隐私的保护需求使得原始数据无法共享,多个数据源之间形成“数据孤岛”。基于联邦学习的图像分类因此被提出,通过联邦学习方式训练分类模型,使得每个联邦学习的参与者能够从其他参与者的图像数据中获益,同时能够确保每个参与者的图像数据不离开本地,在保证了各方数据隐私的前提下成功解决了数据孤岛的问题。
[0003]例如在数字医疗领域中,利用深度学习技术进行医学影像处理是医疗辅助诊断中至关重要的一步,通过联邦学习方法训练基于深度学习的分类模型能够在保护医疗数据隐私的前提下解决数据孤岛问题,提高医学影像分类的精准性。
[0004]此外,图像分类在实际应用中亦需要模型具备的增量学习能力,也就是在不忘记已经学习到的知识的基础上,不断地通过新类数据学习新的知识。然而,目前大多数基于联邦学习的图像分类方法在增量学习的过程中都会出现灾难性遗忘的问题,即在学习新类别的信息后,全局模型在旧类别的表现大幅度降低。
[0005]因此,如何提高联邦学习架构下图像分类的增量学习能力成为亟待解决的技术问题。

技术实现思路

[0006]本申请实施例的主要目的在于提出一种面向增量学习的图像分类方法、装置、电子设备及计算机可读存储介质,能够提高联邦学习架构下图像分类的增量学习能力,缓解灾难性遗忘现象。
[0007]为实现上述目的,本申请实施例的第一方面提出了一种面向增量学习的图像分类方法,所述方法包括:
[0008]获取第一本地图像样本集;
[0009]根据所述第一本地图像样本集确定目标分类任务类型;
[0010]获取所述目标分类任务类型对应的多个适配器以及预测输出模块;
[0011]将多个所述适配器、所述预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;
[0012]基于所述第一本地图像样本集对所述初始图像分类模型进行训练,得到中间图像分类模型;
[0013]将所述中间图像分类模型的本地模型参数上传至所述服务器,以使所述服务器对所述本地模型参数进行整合处理,得到全局模型参数;
[0014]从所述服务器获取所述全局模型参数,并根据所述全局模型参数更新所述中间图像分类模型;
[0015]将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型;
[0016]获取待分类图像;
[0017]将所述待分类图像输入至所述目标图像分类模型中,以通过所述目标图像分类模型得到所述待分类图像对应的分类预测结果。
[0018]根据本专利技术一些实施例提供的面向增量学习的图像分类方法,所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型,包括:
[0019]获取第二本地图像样本集,所述第二本地图像样本集与所述第一本地图像样本集为相同分类任务类型的旧图像样本;
[0020]对所述第一本地图像样本集和所述第二本地图像样本集进行融合处理,得到训练样本集;
[0021]将所述训练样本集分别输入至所述初始图像分类模型,以通过所述初始图像分类模型得到所述训练样本集中图像样本的第一分类预测结果;
[0022]将所述训练样本集分别输入至更新后的所述中间图像分类模型,以通过所述中间图像分类模型得到所述训练样本集中图像样本的第二分类预测结果;
[0023]根据所述第一分类预测结果和所述第二分类预测结果确定损失值;
[0024]基于所述损失值对所述初始图像分类模型中的多个所述适配器和所述预测输出模块的模型参数进行更新,得到目标图像分类模型。
[0025]根据本专利技术一些实施例提供的面向增量学习的图像分类方法,在所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块之前,所述方法还包括:
[0026]构建网络模块池,所述网络模块池包括多个网络模块组,所述网络模块组用于与预训练模型进行结合以得到图像分类模型;
[0027]其中,每个所述网络模块组分别对应一个分类任务类型,每个所述网络模块组包括多个适配器以及预测输出模块;
[0028]所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块,包括:
[0029]根据所述分类任务类型与所述网络模块组的对应关系,从所述网络模块池中获取所述目标分类任务类型对应的多个适配器以及预测输出模块。
[0030]根据本专利技术一些实施例提供的面向增量学习的图像分类方法,在所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型之后,所述方法还包括:
[0031]获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数;
[0032]根据所述第一模型参数更新所述网络模块池中对应的网络模块组。
[0033]根据本专利技术一些实施例提供的面向增量学习的图像分类方法,在获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数之后,所述方法还包括:
[0034]将所述第一模型参数上传至所述服务器,以使所述服务器对所述第一模型参数进行整合处理,得到第二模型参数;
[0035]所述根据所述第一模型参数更新所述网络模块池中对应的网络模块组,包括:
[0036]从所述服务器获取第二模型参数,根据所述第二模型参数更新所述网络模块池中对应的网络模块组。
[0037]根据本专利技术一些实施例提供的面向增量学习的图像分类方法,所述根据所述第一本地图像样本集确定目标分类任务类型,包括:
[0038]获取训练好的第一图像分类模型;
[0039]从所述第一本地图像样本集中选择预设数量的图像样本作为测试样本集;
[0040]将所述测试样本集输入至所述第一图像分类模型,以通过所述第一图像分类模型得到所述测试样本集对应的分类预测结果;
[0041]根据所述分类预测结果,确定所述第一本地图像样本集对应的分类任务类型。
[0042]根据本专利技术一些实施例提供的面向增量学习的图像分类方法,所述全局模型参数通过以下公式得到:
[0043][0044]其中,所述w
t
为全局模型参数,所述n为参加联邦学习的客户端数量,所述n
k
为第k个客户端上传的本地图像样本数量,所述为第k个客户端上传的所述本地模型参数。
[0045]为实现上述目的,本申请实施例的第二方面提出了一种面向增量学习的图像分类装置,所述装置包括:
[0046]第一获取模块,用于获取第一本地图像样本集;
[0047]分类模块,用于根据所述第一本地图像样本集确定目本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种面向增量学习的图像分类方法,其特征在于,所述方法应用于参与联邦学习的客户端,所述客户端与服务器通信连接,其特征在于,所述方法包括:获取第一本地图像样本集;根据所述第一本地图像样本集确定目标分类任务类型;获取所述目标分类任务类型对应的多个适配器以及预测输出模块;将多个所述适配器、所述预测输出模块和预训练模型进行结合处理,得到初始图像分类模型;基于所述第一本地图像样本集对所述初始图像分类模型进行训练,得到中间图像分类模型;将所述中间图像分类模型的本地模型参数上传至所述服务器,以使所述服务器对所述本地模型参数进行整合处理,得到全局模型参数;从所述服务器获取所述全局模型参数,并根据所述全局模型参数更新所述中间图像分类模型;将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型;获取待分类图像;将所述待分类图像输入至所述目标图像分类模型中,以通过所述目标图像分类模型得到所述待分类图像对应的分类预测结果。2.根据权利要求1所述的方法,其特征在于,所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型,包括:获取第二本地图像样本集,所述第二本地图像样本集与所述第一本地图像样本集为相同分类任务类型的旧图像样本;对所述第一本地图像样本集和所述第二本地图像样本集进行融合处理,得到训练样本集;将所述训练样本集分别输入至所述初始图像分类模型,以通过所述初始图像分类模型得到所述训练样本集中图像样本的第一分类预测结果;将所述训练样本集分别输入至更新后的所述中间图像分类模型,以通过所述中间图像分类模型得到所述训练样本集中图像样本的第二分类预测结果;根据所述第一分类预测结果和所述第二分类预测结果确定损失值;基于所述损失值对所述初始图像分类模型中的多个所述适配器和所述预测输出模块的模型参数进行更新,得到目标图像分类模型。3.根据权利要求1所述的方法,其特征在于,在所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块之前,所述方法还包括:构建网络模块池,所述网络模块池包括多个网络模块组,所述网络模块组用于与预训练模型进行结合以得到图像分类模型;其中,每个所述网络模块组分别对应一个分类任务类型,每个所述网络模块组包括多个适配器以及预测输出模块;所述获取所述目标分类任务类型对应的多个适配器以及预测输出模块,包括:根据所述分类任务类型与所述网络模块组的对应关系,从所述网络模块池中获取所述
目标分类任务类型对应的多个适配器以及预测输出模块。4.根据权利要求3所述的方法,其特征在于,在所述将更新后的所述中间图像分类模型作为教师模型对所述初始图像分类模型进行蒸馏处理,得到目标图像分类模型之后,所述方法还包括:获取所述目标图像分类模型中多个所述适配器和所述预测输出模块的第一模型参数;根据所述第一模型参数更新所述网络模块池中对应的网络模块组。5.根...

【专利技术属性】
技术研发人员:瞿晓阳王健宗王亮
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1