一种联邦学习方法技术

技术编号:38400759 阅读:12 留言:0更新日期:2023-08-07 11:12
本发明专利技术属于边缘人工智能技术领域,具体涉及一种联邦学习方法。该方法中各个客户端对服务器下传的网络模型进行稀疏化训练,在每一次迭代训练过程中,若网络模型精度大于精度阈值且剪枝率小于目标剪枝率,则进行剪枝,且对于置为0的权重,相应掩码中对应位置也置为0,直至达到训练的迭代终止条件,得到各个客户端的本地模型和掩码,并上传至服务器;服务器对各个客户端上传到本地模型进行聚合,得到全局模型,进而根据全局模型和各个客户端的掩码,得到全局模型的子网并进行下发。本发明专利技术中的客户端对网络模型进行稀疏化训练后再剪枝,使得网络模型更为轻量化,与直接剪枝的启发式方式相比,加快了收敛速度,更好地解决客户端资源受限的问题。限的问题。限的问题。

【技术实现步骤摘要】
一种联邦学习方法


[0001]本专利技术属于边缘人工智能
,具体涉及一种联邦学习方法。

技术介绍

[0002]据《全球移动市场报告》预测,2023年全球智能手机用户或将突破四十亿,这些设备每秒钟都会产生大量数据。将数据上传到云端处理会引起较大的时延,即使应用对实时性的要求不高,云端也会随着设备规模的不断扩大而面临更大的压力,将计算下沉到边缘端是一个缓解云端压力的方法。边缘计算在靠近用户或数据输入的地方提供服务,在边缘端完成部分工作可以更好地保护用户的数据隐私。
[0003]许多研究将机器学习应用到其他领域后取得了很大突破。边缘设备更靠近用户,能收集到许多数据,而机器学习能从数据中提取有用的信息,在边缘计算中使用机器学习的理论和技术可以为用户提供更好的服务。联邦学习能在保护数据安全的前提下,联合客户端和服务器训练神经网络模型。
[0004]然而,客户端在服务器的协助下训练模型有一定的困难。神经网络模型规模较大,在训练中需要多次迭代,而边缘设备的资源、计算能力受限,可能难以承受训练和通信开销。为了降低成本,许多研究通过减少参数量、通信数据量或通信次数来轻量化模型,但这忽略了客户端之间数据的非独立同分布事实。此外,由各个客户端之间异构的数据训练出的模型往往也具有异构性,在服务器中聚合后可能会出现模型性能下降的情况。个性化联邦学习是针对数据异构性的很好的解决方案,常见的构建个性化联邦学习的方法包括元学习、迁移学习、自适应调整等,但这些方法都没有考虑客户端的计算和通信能力等都受限的问题。
[0005]以上减少模型开销的方法会对模型性能有一定的损伤,而针对异构数据的方法需要用额外的参数表示数据,增加了模型的开销。在联邦学习中使用剪枝策略可以在去除冗余参数的同时为客户端训练个性化模型,但是通常采用的剪枝策略是启发式的,没有考虑到权重在训练中的动态性,会造成不小的精度损失,这需要额外的再训练来恢复,这种方法的代价仍然比较高。

技术实现思路

[0006]本专利技术的目的在于提供一种联邦学习方法,用以解决采用启发式的剪枝策略造成模型精度损失的问题。
[0007]为解决上述技术问题,本专利技术提供了一种联邦学习方法,包括如下步骤:
[0008]1)各个客户端对服务器下传的网络模型进行稀疏化训练,在每一次稀疏化迭代训练过程中均对网络模型精度和剪枝率进行判断,若网络模型精度大于精度阈值且剪枝率小于目标剪枝率,则将网络模型中部分趋近于0的权重置为0以进行剪枝,且对于置为0的权重,相应掩码中对应位置也置为0,直至达到稀疏化训练的迭代终止条件,从而得到各个客户端的本地模型和掩码,并上传至服务器;
[0009]2)服务器对各个客户端上传到本地模型进行聚合,得到全局模型;根据全局模型和各个客户端的掩码,得到全局模型的子网并下送至各个客户端。
[0010]上述技术方案的有益效果为:本专利技术中的客户端对网络模型进行稀疏化训练后再剪枝,使得网络模型更为轻量化,与直接剪枝的启发式方式相比,加快了收敛速度,更好地解决客户端资源受限的问题,最终使网络模型能更好地运行在客户端中。而且,与没有稀疏化训练的方法相比,剪枝对模型精度的影响更小,且在剪枝之前对网络模型精度和剪枝率进行判断,在两者均达到要求的情况下再剪枝,防止对模型的精度造成永久性损伤。
[0011]进一步地,在进行本地模型聚合时,仅对本地模型中未进行剪枝的部分进行聚合。
[0012]上述技术方案的有益效果为:仅对未剪枝部分进行聚合,可以减少数据异构对模型性能的影响。
[0013]进一步地,将各个本地模型中未进行剪枝的部分相应位置取平均值以实现聚合。
[0014]进一步地,步骤2)中将全局模型与客户端各自的掩码的对应位置进行相乘得到各个客户端的子网。
[0015]进一步地,将损失函数转化具有稀疏度要求的优化问题,并采用ADMM算法求解优化目标,以实现对网络模型进行稀疏化训练。
[0016]上述技术方案的有益效果为:ADMM能够很好地解决稀疏约束下的非凸优化问题,因此使用ADMM进行求解。
附图说明
[0017]图1是本专利技术的ADMM权重剪枝的个性化联邦学习过程图;
[0018]图2是本专利技术进行ADMM权重剪枝的流程图;
[0019]图3是本专利技术的聚合本地模型的示意图;
[0020]图4是本专利技术的获取子网的示意图;
[0021]图5(a)是针对CIFAR

10设置不同目标剪枝率时得到的精度图;
[0022]图5(b)是针对MNIST设置不同目标剪枝率时得到的精度图;
[0023]图6(a)是针对CIFAR

10采用不同算法得到的精度图;
[0024]图6(b)是针对MNIST采用不同算法得到的精度图;
[0025]图7(a)是针对CIFAR

10采用不同算法的通信成本对比图;
[0026]图7(b)是针对MNIST采用不同算法的通信成本对比图;
[0027]图8(a)是系统平台的主界面图;
[0028]图8(b)是系统平台的花卉识别功能界面图;
[0029]图9是选择模型界面图;
[0030]图10是选择图像界面图;
[0031]图11(a)和图11(b)分别是同一张图片下服务器和客户端模型的花卉识别结果图。
具体实施方式
[0032]为了使本专利技术的目的、技术方案及优点更加清楚明了,以下结合附图及实施例,对本专利技术进行进一步详细说明。
[0033]方法实施例:
[0034]假设网络中有1个服务器和n个客户端,客户端1、客户端2和客户端3为与服务器通信的任意3个客户端,服务器与这3个客户端协同训练模型。本实施例以服务器、客户端1、客户端2和客户端3为例对训练模型过程给出详细说明,其整个学习过程如图1所示,主要包括ADMM权重剪枝、聚合本地模型和获取子网三个过程。
[0035]步骤一,各客户端对服务器下传的网络模型进行ADMM稀疏化训练以及剪枝操作,得到各个客户端的本地模型和掩码,并上传至服务器。如图2所示,具体的:
[0036]对网络模型进行稀疏化训练。将模型的损失函数转化为具有稀疏度要求的优化问题,对网络模型(本实施例为神经网络)进行稀疏化训练,表示如下:
[0037][0038]上式表示客户端i在模型权重数量的稀疏约束下希望损失函数L
i
最小。其中,W
i
表示客户端i的神经网络权重,由W
g

M
i
得到;W
g
表示服务器中全局模型的参数;M
i
表示客户端i的掩码;S
i
={W
i
|card(W
i
)≤N
i
},为非凸集合,card(W
i
)返回客户端i中神经网络的非零元素数量,N
本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种联邦学习方法,其特征在于,包括如下步骤:1)各个客户端对服务器下传的网络模型进行稀疏化训练,在每一次稀疏化迭代训练过程中均对网络模型精度和剪枝率进行判断,若网络模型精度大于精度阈值且剪枝率小于目标剪枝率,则将网络模型中部分趋近于0的权重置为0以进行剪枝,且对于置为0的权重,相应掩码中对应位置也置为0,直至达到稀疏化训练的迭代终止条件,从而得到各个客户端的本地模型和掩码,并上传至服务器;2)服务器对各个客户端上传到本地模型进行聚合,得到全局模型;根据全局模型和各个客户端的掩码,得到全局模型的子网并下送至各个客户端。...

【专利技术属性】
技术研发人员:袁培燕石玲赵晓焱张俊娜刘春红
申请(专利权)人:河南师范大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1