System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 基于联邦知识蒸馏和集成学习的图像分类方法技术_技高网

基于联邦知识蒸馏和集成学习的图像分类方法技术

技术编号:40321923 阅读:5 留言:0更新日期:2024-02-09 14:17
本发明专利技术公开一种基于联邦知识蒸馏和集成学习的图像分类方法,其步骤为:服务器生成训练数据集和辅助数据集,构建联邦学习全局模型并进行初始化,将其下发至选择的客户端。客户端基于有监督损失和一致性约束损失训练本地模型,完成后将模型参数上传至服务器。服务器对接收到的模型进行加权聚合,利用辅助数据集进行基于集成学习的模型分段知识蒸馏过程,将客户端模型的知识融合到全局模型。本发明专利技术提升了全局模型的图像分类泛化性能,增强了客户端模型的分类精度,提高了系统对异质数据的鲁棒性。

【技术实现步骤摘要】

本专利技术属于图像处理,更进一步涉及数据分类中的一种基于联邦知识蒸馏和集成学习的图像分类方法。本专利技术可通过物联网中多个设备的协同训练,将训练好的模型用于图像分类任务。


技术介绍

1、图像分类是计算机视觉中的一个重要任务,在实际应用中有广泛的应用,如自动驾驶、医疗诊断、人脸识别等。随着物联网的快速发展,大量的设备和传感器能够收集数据,然而,由于隐私安全的限制,这些数据无法集中存储和处理。联邦学习允许多个设备协同训练全局共享的共识模型,同时保护数据的隐私。在此模型的基础上,把未知的图像输入进去,得到该测试样本的预测类别。然而,在物联网中,每个设备节点的数据分布存在较大差异,这种分布异质性会导致设备训练出的模型漂移、难以收敛和精度降低的问题。一些研究提出在本地模型训练时使用知识蒸馏技术来约束本地模型向全局模型学习来解决此问题。然而,这种方法仍然存在不足:当设备拥有的数据较少时,难以从中学习到有效的信息;可能需要设备间共享的额外代理数据集,这违背了隐私保护要求,增加了通信开销,在实际应用中可能会带来困难。

2、安徽师范大学在其申请的专利文献“基于动态自适应知识蒸馏的联邦学习模型聚合方法”(专利申请号:cn202310682277.6,专利公开号:cn116681144a,公布日期2023.09.01)中提出了一种基于动态自适应知识蒸馏的联邦学习模型聚合方法。该方法主要包括如下步骤:(1)服务器初始化全局模型并将其发送至参与本轮训练的客户端;(2)客户端接收到全局模型后,确定本轮知识蒸馏中对收到的全局模型学习的比例,自适应调整学习本地数据集和全局模型的比例,并动态调整教师模型的输出,使其处于最适合学习的分布状态,训练生成本地模型,并上传给服务器;(3)对接收到的本地模型进行聚合生成新的全局模型从而完成本轮训练过程。该方法存在的不足之处是,在面对客户端数据较少和分布不平衡的情况下执行联邦学习任务时,客户端模型优化程度有限,难以在全局模型中学习到有效信息,导致模型性能较低。

3、中国人民解放军总医院在其申请的专利文献“针对数据异质性的个性化联邦学习方法、系统及存储介质”(专利申请号:cn202311035140.8,专利公布号:cn116933866a,公布日期2023.10.24)中提出了一种针对数据异质性的个性化联邦学习方法、系统及存储介质。该方法包括如下步骤:(1)服务端将初始学习模型发送至客户端;(2)服务端根据所有客户端的数据分布相似性对客户端进行聚类,生成客户端的相似性网络图;(3)客户端进行本地迭代训练,得到训练梯度和第一权重参数进行训练更新后的第一更新参数;(4)服务端计算所有客户端上传的第一更新参数的平均值,得到第一平均更新参数;(5)服务端自动更新迭代初始学习模型,并得到第二权重参数更新后的第二更新参数;(6)服务端更新初始学习模型;(7)重复上述步骤直至模型损失函数收敛,得到联邦学习模型。该方法存在的不足之处是:该方法中客户端只上传特征提取层参数和梯度,在服务器进行迭代更新分类层参数并发送到客户端,这使得客户端本地模型的更新方向存在较大波动,降低了本地模型的收敛效率,最终影响在数据异质性场景下客户端模型的预测性能。


技术实现思路

1、本专利技术的目的在于针对上述现有技术存在的问题,提出一种基于联邦知识蒸馏和集成学习的图像分类方法,旨在解决数据异质联邦学习中训练图像分类模型时模型难以收敛和泛化性能差的问题。

2、实现本专利技术目的的思路是,本专利技术首先提出针对联邦学习客户端本地训练的损失函数的优化。对本地模型的目标函数施加一致性约束,使得客户端在迭代更新时,平衡本地模型和全局模型的更新方向,避免了各个客户端模型在聚合时的参数方差过大,造成全局模型无法收敛的问题。其次,本专利技术提出全局模型的可学习聚合策略和知识迁移方法。在服务器端,先进行客户端模型的加权聚合,得到全局模型。随后使用基于集成学习的策略,通过构建可学习的集成模型,结合多个本地模型,使用分段模型蒸馏训练,挖掘并学习在各个非独立同分布数据的上训练得到的模型的潜藏知识来提高模型的泛化性能。由于全局模型吸收了各个客户端的知识,因此在差异化数据样本空间上进行优化的时候,具有较强的抵抗数据异质性的鲁棒性。因此,本专利技术通过在联邦学习中应用集成学习策略和蒸馏学习方法,实现了在缺乏大量监督数据、类别分布不平衡、数据呈非独立同分布情况下学习高性能模型的目标。

3、实现本专利技术目的的具体步骤如下:

4、步骤1,生成和分配样本集:

5、生成训练样本集和辅助样本集,为每个客户端分配各自的客户端样本集;

6、步骤2,在服务端构建一个卷积神经网络和一个多层感知机网络,初始化网络参数,分别作为联邦学习全局模型和全局集成模型;

7、步骤3,服务器确定参与客户端:

8、服务器随机选择ns个客户端,并将其确定为下个轮次将要参与联邦学习的客户端,随后将联邦学习全局模型分发给被选择的客户端,ns≥3;

9、步骤4,客户端进行本地训练:

10、将每个参与客户端的样本集输入到其对应的模型中,使用有监督损失和一致性损失作为本地模型学习的联合损失函数,采用随机梯度下降算法作为优化器,进行梯度的反向传播计算更新模型参数,直到本地模型达到收敛;最终每个客户端得到训练好的联邦学习客户端模型,在客户端本地暂存一个模型副本,并将训练好的模型上传到服务器;

11、步骤5,服务器进行客户端模型的集成和知识迁移:

12、步骤5.1,服务器对本轮次接收到的从客户端上传的模型参数进行加权聚合,获得聚合后的联邦学习全局模型:

13、步骤5.2,将辅助样本集输入到客户端上传的模型,得到每个本地模型的中间输出向量和类别预测向量;

14、步骤5.3,将辅助数据集输入聚合后的联邦学习全局模型,得到中间输出向量和类别预测向量;

15、步骤5.4,计算步骤5.2和步骤5.3中得到的模型中间输出向量分布的kl散度,并求平均,得到基于模型特征输出分布的蒸馏损失值

16、步骤5.5,将步骤5.2中得到的类别预测向量作为输入,送入服务器全局集成模型中,得到综合类别预测向量

17、步骤5.6,将综合类别预测向量和步骤5.3中得到的联邦学习全局模型类别预测向量进行kl散度计算,得到基于模型预测软分布的蒸馏损失值

18、步骤5.7,将辅助样本集中的真实标签y和步骤5.3得到的联邦学习全局模型类别预测向量计算得到联邦学习全局模型拟合辅助样本数据的损失

19、步骤5.8,将三个损失值作为服务端损失函数使用随机梯度下降算法和梯度反向传播,对联邦学习全局模型进行再训练;

20、步骤6,判断服务器最终全局模型是否满足联邦学习训练的终止条件,若是,则得到最终训练好的联邦学习全局模型后执行步骤7,否则,将当前迭代次数加1后执行步骤3;

21、步骤7,将待分类的图像样本输入到训练好的联邦学习全局模型中,输出分类结本文档来自技高网...

【技术保护点】

1.一种基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于,基于集成学习和分段知识蒸馏,在服务器上对聚合后的全局图像分类模型进行本地知识的集成与迁移,增强了全局模型的泛化性,提高了训练效率;在客户端上,将融合本地知识的全局模型进行带有一致性约束的更新,使全局模型传递的广义知识更好地适应局部表示,减轻本地模型的漂移问题;最终得到一个具有对抗数据异质鲁棒性和泛化性的图像分类模型;该图像分类方法的具体步骤包括如下:

2.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于,步骤1中所述的训练样本集和辅助样本集指的是,生成至少包含10种类别的图像样本,其中每种类别至少6000张图像,将所选取的所有的图像组成样本集;将样本集中的每种类别随机选取至少500个样本组成辅助样本集,剩余的样本组成训练样本集;步骤1中所述的客户端样本集指的是,从训练样本集中的类别中随机选取至少3种类别,每种类别至少1000张图像,将所选取的图像组成一个客户端样本集;联邦学习系统中存在至少10个客户端,从而得到至少10个不同的客户端样本集,其中每个客户端对应一个客户端样本集。

3.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤2中所述的卷积神经网络是由12个层串联而成,其结构依次为:第一卷积层,第一批归一化层,第一激活层,第一池化层,第二卷积层,第二批归一化层,第二激活层,第二池化层,第一全连接层,第三激活层,dropout层,第二全连接层;前8层为特征提取模块,后4层为预测模块;将第一、第二卷积层的卷积核的个数分别设置为16,32,卷积核的大小均设置为5×5,步长均设置为1,填充宽度均设置为2;第一至第三激活层采用Relu激活函数,将inplace参数设置为False;单个图像样本输入网络经过第一卷积层和第二卷积层处理后的特征图维度分别为14×14,7×7;第一、第二池化层均采用最大池化方式,池化区域核的大小均设置为2×2,池化步长均设置为2;第一、第二批归一化层的eps参数设置为1×10-5,momentum参数设置为0.1,affine参数设置为True;将Dropout层的drop_rate参数设置为0.05;将第一、第二连接层的神经元的数量分别设置为512和C,其中C等于数据集样本的类别总数;步骤2中所述的多层感知机网络的结构包括3个全连接层和2个激活层,分别为第一全连接层,第一激活层,第二全连接层,第二激活层,第三全连接层;其中,第一、第二全连接层都具有128个隐藏单元,随后分别连接ReLU激活函数,并将inplace参数设置为False;第三全连接层具有C个神经元。

4.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤4中所述本地模型学习的联合损失函数如下:

5.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.4中所述的基于模型特征输出分布的蒸馏损失值如下:

6.根据权利要求5所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.5中所述的综合类别预测向量如下:

7.根据权利要求6所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.6中所述的基于模型预测软分布的蒸馏损失值如下:

8.根据权利要求5所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤5.7中所述的联邦学习全局模型拟合辅助样本数据的损失如下:

9.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤6所述的训练的终止条件指的是满足下述条件之一的情形:

...

【技术特征摘要】

1.一种基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于,基于集成学习和分段知识蒸馏,在服务器上对聚合后的全局图像分类模型进行本地知识的集成与迁移,增强了全局模型的泛化性,提高了训练效率;在客户端上,将融合本地知识的全局模型进行带有一致性约束的更新,使全局模型传递的广义知识更好地适应局部表示,减轻本地模型的漂移问题;最终得到一个具有对抗数据异质鲁棒性和泛化性的图像分类模型;该图像分类方法的具体步骤包括如下:

2.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于,步骤1中所述的训练样本集和辅助样本集指的是,生成至少包含10种类别的图像样本,其中每种类别至少6000张图像,将所选取的所有的图像组成样本集;将样本集中的每种类别随机选取至少500个样本组成辅助样本集,剩余的样本组成训练样本集;步骤1中所述的客户端样本集指的是,从训练样本集中的类别中随机选取至少3种类别,每种类别至少1000张图像,将所选取的图像组成一个客户端样本集;联邦学习系统中存在至少10个客户端,从而得到至少10个不同的客户端样本集,其中每个客户端对应一个客户端样本集。

3.根据权利要求1所述的基于联邦知识蒸馏和集成学习的图像分类方法,其特征在于:步骤2中所述的卷积神经网络是由12个层串联而成,其结构依次为:第一卷积层,第一批归一化层,第一激活层,第一池化层,第二卷积层,第二批归一化层,第二激活层,第二池化层,第一全连接层,第三激活层,dropout层,第二全连接层;前8层为特征提取模块,后4层为预测模块;将第一、第二卷积层的卷积核的个数分别设置为16,32,卷积核的大小均设置为5×5,步长均设置为1,填充宽度均设置为2;第一至第三激活层采用relu激活函数,将inplace参数设置为false;单个图像样本输入网络经过第...

【专利技术属性】
技术研发人员:毛文杰鱼滨张琛解宇刘伟明
申请(专利权)人:西安电子科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1