一种基于强化学习的非独立同分布环境中联邦学习优化方法技术

技术编号:37196195 阅读:12 留言:0更新日期:2023-04-20 22:54
本发明专利技术所解决的技术问题是非独立同分布环境对联邦学习性能影响的问题,提出一种基于深度强化学习的客户端节点选择算法,在每一轮联邦学习通信中,选择有利于全局模型收敛的客户端子集进行模型聚合,用有限的通信轮次达到更高的目标精度;将联邦学习建模为马尔可夫决策过程,通过使用深度强化学习中的Double Deep Q

【技术实现步骤摘要】
一种基于强化学习的非独立同分布环境中联邦学习优化方法


[0001]本专利技术以深度学习为基础提出一种基于深度强化学习的非独立同分布环境中联邦学习优化方法;在每一轮联邦学习通信中,通过深度强化学习智能体选择有利于全局模型收敛的客户端子集参与联邦学习的模型聚合过程,以更少的通信轮次实现目标精度;可以使联邦学习更好的部署在非独立同分布的现实场景中。

技术介绍

[0002]联邦学习已经成为一种新的分布式机器学习范式,多个客户端在中央服务器的协调下协作培训模型;其中,客户端只需要在每一轮通信过程中将训练好的本地模型参数上传到服务器,而不需要与他人共享本地数据;服务器收集客户端所上传的模型参数并执行聚合算法,更新全局模型参数,然后将全局模型参数下发给每个客户端,进行下一轮训练;重复该过程直到全局模型达到目标精度为止。
[0003]相比传统的集中式机器学习,联邦学习在隐私问题和通信效率方面展现出了巨大的优势;尽管有其优势,但联邦学习的一个关键且常见的挑战就是各方之间的本地数据是非独立同分布的,即设备的本地数据不能代表总体分布;具有相同初始参数的局部模型在本地训练阶段会向着各自的局部最优值更新,而数据异构性会导致局部最优偏离全局最优,从而使服务器聚合的全局模型远离全局最优值,因此大大降低了全局模型的准确性和通信效率。

技术实现思路

[0004]本专利技术所解决的技术问题是针对非独立同分布环境中联邦学习收敛速度慢、通信效率低的问题,将联邦学习中的客户端选择问题建模为马尔可夫决策过程,提出一种基于Double Deep Q

Learning(DDQN)的联邦学习参与节点选择算法,通过智能地选择参与模型聚合的客户端,减小全局模型参数与理想模型参数的偏差程度,在有限的通信轮次内最大限度地提高全局模型的精度。
[0005]为了实现上述目的,本专利技术的技术方案如下:一种基于深度强化学习的非独立同分布环境中联邦学习优化算法,包括以下步骤:步骤一:服务器初始化全局模型参数,并将全局模型参数发送给所有客户端;初始化评估网络、目标网络和记忆存储器;步骤二:每个客户端使用本地数据进行一轮模型训练;步骤三:每个客户端将本地训练损失值上传给服务器;步骤四:服务器中的深度强化学习智能体收集客户端上传的本地训练损失值作为状态,做出相应的动作,从全部客户端中选择一个子集作为本轮参与模型聚合的客户端;并向被选择的客户端发送指令;步骤五:被选择的客户端使用本地数据集完成剩余的本地训练任务,并将训练好
的局部模型参数上传到服务器;步骤六:服务器收集全部选择的客户端上传的局部模型参数并进行聚合以更新全局模型参数;服务器将新的全局模型发送给全部客户端;步骤七:DDQN智能体获得奖励,并进入下一时刻的状态;将当前的智能体运动交互轨迹存入记忆存储器,并从记忆存储器中抽样一批经验对DDQN智能体的网络进行训练;重复步骤二到步骤七,直至全局模型测试精度达到目标要求。
[0006]所述步骤一中的具体情况如下:一开始,服务器对全局模型的参数进行随机初始化,并将全局模型的参数发给所有的N个客户端;服务器对DDQN智能体部分的评估网络、实际网络和记忆存储器进行初始化;其中,评估网络用来评估动作,目标网络用来确定动作的价值,记忆存储器用来存放联邦学习过程中的经验,从而对智能体进行训练。
[0007]所述步骤二中的具体情况如下:服务器使用本地数据集对接收到的全局模型进行一个epoch的本地训练,使用交叉熵损失函数计算本地训练损失值。
[0008]所述步骤三中的具体情况如下:客户端将本地计算得到损失值上传给服务器,并停止本地训练任务。
[0009]所述步骤四中的具体情况如下:(1)服务器收集每个客户端上传的本地训练损失值作为状态;(2) 将状态向量s输入评估网络,得到每个客户端对应的Q值;根据贪婪策略选择Top

K个Q值对应的客户端作为本轮通信中选择的客户端子集作为动作;(3)服务器向被选择的客户端发送指令,使其进行本地训练。
[0010]所述步骤五中的具体情况如下:被选择的客户端执行随机梯度下降完成本地训练任务,更新局部模型;将更新好的局部模型参数上传到服务器。
[0011]所述步骤六中的具体情况如下:服务器收集k个客户端的局部模型参数后,使用联邦平均聚合算法更新全局模型。并将新的全局模型参数发送给每个客户端。
[0012]所述步骤七中的具体情况如下:服务器对聚合后的全局模型计算测试精度,并根据奖励函数公式计算奖励;然后进入下一时刻的状态,将本轮通信中的智能体运动交互轨迹存入记忆存
储器中;在每轮联邦学习通信中,使用Double Deep Q

Learning算法,通过从记忆存储器中随机采样一小批经验样本对智能体进行训练;由奖励值和经过衰减的目标网络的Q值之和计算获得目标值;其中表示衰减系数,为目标网络的网络权重,为评估网络的网络权重;通过使用评估网络来选择最大的Q值对应的动作,然后利用这个选择出来的动作在目标网络中计算相应的Q值;智能体通过目标值与评估网络得到的Q值,计算损失函数;通过更新评估网络的权重以最小化梯度下降损失;目标网络每隔轮复制评估网络的权重;重复步骤二到步骤六的过程,直到全局模型达到目标精度为止。
[0013]与现有技术相比,本专利技术技术方案的有益效果是:(1)本专利技术的方法部署在服务器端,通过节点选择的方式减缓非独立同分布数据对联邦学习收敛的影响,不会为资源有限的客户端造成额外的计算、通信以及存储资源的消耗;(2)本专利技术基于深度强化学习方法,可以在动态的环境中为联邦学习选择合适的客户端进行模型聚合,在有限的通信轮次内最大限度的提高模型精度;(3)本专利技术在不同种类的非独立同分布环境中都可以有效的提高联邦学习的收敛速度。
附图说明
[0014]图1为本专利技术使用的联邦学习框架图。
[0015]图2为本专利技术所使用的基于DDQN算法的智能体训练图。
具体实施方式
[0016]对于本领域技术人员来说,附图中某些公知结构及其说明可能省略是可以理解的;下面结合附图和实施例对本专利技术的技术方案做进一步的说明。
[0017]本专利技术提供了一种基于深度强化学习的非独立同分布环境中联邦优化算法,该方法有效缓解了非独立同分布环境对联邦学习的负面影响,对联邦学习达到目标精度所需要的通信轮次有极大的。
[0018]图1为本专利技术所应用的联邦学习框架图,图2为智能体节点选择策略图。
[0019]具体的实现步骤为:Step1.1 服务器对全局模型的参数进行随机初始化,服务器对DDQN智能体部分的评估网络、实际网络和记忆存储器进行初始化;Step1.2 服务器将初始化的全局模型参数发送给所有的客户端;Step2 服务器使用本地数据集对接收到的全局模型进行一个epoch的本地训练,使用交叉熵损失函数计算本地训练损失值;
Step3 客户端将本地计算得到损失值上传给服务器,并停止本地训练任务;Step4.1 服务器收集每个客本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于深度强化学习的联邦学习优化算法,其特征在于,包括以下步骤:Step 1:服务器端进行初始化,并将全局模型发送给所有客户端;Step 2:客户端进行一轮本地训练;Step 3:客户端向服务器上传本地训练损失值;Step 4:服务器收集每个客户端上传的本地训练损失值作为状态,智能体根据当前的状态做出决策,选择一个客户端子集作为本轮联邦学习模型聚合的参与者;Step 5: 服务器向被选择的客户端发送指令;Step 6: 被选择的客户端完成剩余的本地训练任务,并将训练好的模型参数上传给服务器;Step 7: 服务器对收集的局部模型参数进行聚合并更新全局模型,并将新的全局模型发送给所有客户端;同时智能体将当前的动作

状态存入记忆存储器,并从记忆存储器中抽样一批经验对DDQN智能体的网络进行训练;重复Step 2到Step 7,直到全局模型达到目标精度。2.根据权利要求1所述的一种基于深度强化学习的联邦学习优化算法,其特征在于,所述Step 1中的具体过程如下:Step1.1服务器对全局模型的参数进行随机初始化,服务器对DDQN智能体部分的评估网络、实际网络和记忆存储器进行初始化;Step1.2 服务器将初始化的全局模型参数发送给所有的客户端。3.根据权利要求1所述的一种基于深度强化学习的联邦学习优化算法,其特征在于,所述Step 2中的具体过程如下:Step2 服务器使用本地数据集对接收到的全局模型进行一个epoch的本地训练,使用交叉熵损失函数计算本地训练损失值;4.根据权利要求1所述的一种基于深度强化学习的联邦学习优化算法,其特征在于,所述Step 3中的具体过程如下:Step3 客户端将本地计算得到损失值上传给服务器,并停止本地训练任务。5.根据权利要求1所述的一种基于深度强化学习的联邦学习优化算法,其特...

【专利技术属性】
技术研发人员:李勇孟续涛任翔麟凌海潮刘彤彤杜炜张振健
申请(专利权)人:长春工业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1