一种基于超网络的分层联邦学习系统技术方案

技术编号:38529855 阅读:12 留言:0更新日期:2023-08-19 17:04
本发明专利技术公开了一种基于超网络的分层联邦学习系统,包括以下步骤:构建包括一个中央服务器、多个聚合服务器和多个客户端的三层联邦学习系统;各个客户端利用本地数据对模型进行训练,并将训练后的模型参数上传至聚合服务器;聚合服务器利用更新后的本地模型参数更新超网络模型参数,并将更新后的超网络模型上传至中央服务器;中央服务器对超网络模型参数进行联邦聚合,并将聚合更新后的超网络重新下发至各个聚合服务器。本发明专利技术将联邦学习扩展到三层,可在复杂的数据异构场景下,有效提升联邦学习方法的性能,在提高通信效率的同时降低计算成本。算成本。算成本。

【技术实现步骤摘要】
一种基于超网络的分层联邦学习系统


[0001]本专利技术属于联邦学习技术应用领域,涉及一种基于超网络的分层联邦学习系统。

技术介绍

[0002]伴随着物联网、云计算等技术的繁荣发展,联邦学习(Federated Learning,FL)能够有效的打破数据壁障,因而逐渐成为一种新兴的分布式机器学习范式,掀起了万物互联的潮流与趋势。在FL场景中,各方数据保留在本地,不泄露隐私也不违反通用数据保护条例;多个参与者在满足身份、地位平等的条件下,联合数据建立虚拟的全局模型,组成共同利益体系,协同训练共享数据价值而不是共享原始数据。这样既保护了本地数据的隐私,又解决了因本地数据量不够充足、数据类型不够丰富引起的模型泛化能力差、模型性能不尽人意的问题。
[0003]联邦学习的一个主要挑战是如何在客户之间训练一个关于非独立同分布数据集的高效全局模型,FedAvg的大多数变体致力于解决非独立同分布问题中的标签分布偏斜和数量不平衡的问题。现在假设这样一个典型的联邦学习场景。边缘设备(如传感器、监视器和可穿戴设备)希望利用自己的本地私有数据协作训练共享模型,但直接使用传统的联邦学习方法可能会导致三个问题:(1)由于边缘设备无处不在,传统联合学习方法的模型聚合方案将导致极其昂贵的通信成本,甚至可能导致模型无法收敛。(2)传统的联邦学习算法将在中央服务器中聚合所有模型参数,这将导致单个节点上的计算成本过高。(3)由于功能和位置的差异,位于不同区域的边缘设备可能具有特征分布偏斜的私有数据集,因此传统的联邦学习算法可能无法获得良好的准确性
[0004]由于存在上述问题,复杂数据异构场景下,现有的联邦学习方法效果仍没有达到最佳,性能提升空间很大,客户端之间的混合数据异构类型时模型聚合精度不高的问题亟待解决。

技术实现思路

[0005]针对上述技术问题,本专利技术提出一种基于超网络的分层联邦学习系统,将联邦学习扩展到三层,在下游设置聚合服务器并使用超网络为客户端生成本地模型参数,在上游利用中央服务器聚合并更新超网络的参数,实现了上层和下层传输参数量的解耦,在提高通信效率和降低计算成本的同时能够获得较好的模型性能。为实现上述目的,本专利技术采用的技术方案为:
[0006]步骤一、构建包括一个中央服务器、J个聚合服务器和J
×
N个客户端的三层联邦学习系统:
[0007]步骤101、中央服务器与J个聚合服务器相连,每个聚合服务器与N个客户端相连,同一聚合服务器下的客户端本地数据存在标签分布不平衡,不同聚合服务器下的客户端本地数据存在特征分布不平衡;
[0008]步骤102、在每个聚合服务器中初始化一个超网络,用于生成该聚合服务器下所有
客户端的本地模型;
[0009]步骤103、聚合服务器利用客户端嵌入向量为客户端生成其本地模型并将该本地模型下发给与其相连的客户端,即:
[0010][0011]其中,代表超网络,其参数为上标t为中央服务器与聚合服务器之间的通信回合数,上标r为聚合服务器与客户端之间的通信回合数;
[0012]步骤二、各个客户端利用本地数据对模型进行训练,并将训练后的模型参数上传至聚合服务器:
[0013]步骤201、本地客户端采用随机梯度下降法对模型进行E次本地更新,即:
[0014][0015]其中,是本地损失函数,η是本地学习率;
[0016]步骤202、客户端将更新后的本地模型上传至与其相连的聚合服务器;
[0017]步骤三、聚合服务器利用更新后的本地模型参数更新超网络模型参数,并将更新后的超网络模型上传至中央服务器:
[0018]步骤301、聚合服务器接收到本地模型后,利用对该聚合服务器内的超网络和客户端嵌入向量进行更新,步骤如下:
[0019][0020][0021][0022]其中,是超网络模型参数,是客户端嵌入向量,α是超网络学习率;
[0023]步骤302、在聚合服务器与客户端通信R次后(即超网络和客户端嵌入向量进行了R次更新),将更新后的超网络模型上传至中央服务器;
[0024]步骤四、中央服务器对超网络模型参数进行联邦聚合,并将聚合更新后的超网络重新下发至各个聚合服务器:
[0025]步骤401、中央服务器接收到J个聚合服务器上传的超网络模型后,执行联邦平均算法对J个超网络模型进行聚合,即:
[0026][0027]其中,是客户端C
j,i
的本地数据集,n是所有客户端的样本数之和;
[0028]步骤402、中央服务器将更新后的超网络下发至各个聚合服务器,进入下一个通信回合,T次通信后输出最终的超网络模型和本地模型
[0029]本专利技术提供的基于超网络的分层联邦学习系统,与现有技术相比,具有以下特点:
[0030](1)基于分层结构,本专利技术设计了一种新的聚合机制,该机制对底层边缘设备的私有数据进行多次采样,然后将其发送到中央服务器进行处理,以降低总通信成本;
[0031](2)本专利技术使用更轻的超网络为每个边缘设备生成本地模型,以降低单个服务器的计算成本;
[0032](3)本专利技术所提出的联邦学习系统减轻了客户端之间非IID数据的影响,在测试精度、收敛性和鲁棒性方面表现良好。
附图说明
[0033]图1为本专利技术的流程框图;
[0034]图2为本专利技术的系统架构图;
具体实施方式
[0035]下面结合附图及本专利技术的实施对本专利技术的方法作进一步的详细的说明。显然,所描述的实施例仅仅是本专利技术一部分实施例,而不是全部的实施例。基于本专利技术中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本专利技术保护的范围。
[0036]图1表示的是一种基于超网络的分层联邦学习系统流程图。
[0037]图2表示的是一种基于超网络的分层联邦学习系统架构图。
[0038]为便于对本专利技术实施例的理解,下面结合附图说明本专利技术的合理性与有效性,包含具体步骤如下:
[0039]步骤一、构建包括一个中央服务器、J=5个聚合服务器和50个客户端的三层联邦学习系统:
[0040]步骤101、中央服务器与J=5个聚合服务器相连,每个聚合服务器与N=10个客户端相连,同一聚合服务器下的客户端本地数据存在标签分布不平衡,不同聚合服务器下的客户端本地数据存在特征分布不平衡;
[0041]步骤102、在每个聚合服务器中初始化一个超网络,用于生成该聚合服务器下所有客户端的本地模型;
[0042]步骤103、聚合服务器利用客户端嵌入向量为客户端生成其本地模型并将该本地模型下发给与其相连的客户端,即:
[0043][0044]其中,代表超网络,其参数为上标t为中央服务器与聚合服务器之间的通信回合数,上标r为聚合服务器与客户端之间的通信回合数;
[0045]步骤二、各个客户端利用本地数据对模型进行训练,并将训练后的模型参数上传至聚合服务器:
[本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于超网络的分层联邦学习系统,其特征在于,包括以下步骤:步骤一、构建包括一个中央服务器、J个聚合服务器和J
×
N个客户端的三层联邦学习系统:步骤101、中央服务器与J个聚合服务器相连,每个聚合服务器与N个客户端相连,同一聚合服务器下的客户端本地数据存在标签分布不平衡,不同聚合服务器下的客户端本地数据存在特征分布不平衡;步骤102、在每个聚合服务器中初始化一个超网络,用于生成该聚合服务器下所有客户端的本地模型;步骤103、聚合服务器利用客户端嵌入向量为客户端生成其本地模型并将该本地模型下发给与其相连的客户端,即:其中,代表超网络,其参数为上标t为中央服务器与聚合服务器之间的通信回合数,上标r为聚合服务器与客户端之间的通信回合数;步骤二、各个客户端利用本地数据对模型进行训练,并将训练后的模型参数上传至聚合服务器:步骤201、本地客户端采用随机梯度下降法对模型进行E次本地更新,即:其中,是本地损失函数,η是本地学习率;步骤202、客户端将更新后的本地模型上传...

【专利技术属性】
技术研发人员:蒋雯杨季皓聂来森
申请(专利权)人:西北工业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1