当前位置: 首页 > 专利查询>深圳大学专利>正文

联邦学习模型训练方法、装置、计算机设备及存储介质制造方法及图纸

技术编号:36965274 阅读:49 留言:0更新日期:2023-03-22 19:25
本发明专利技术实施例公开了联邦学习模型训练方法、装置、计算机设备及存储介质。所述方法包括:获取来自服务器端的全局模型梯度和部分全局模型参数,以得到初始参数;利用所述初始参数更新本地模型;随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数;发送本地模型梯度和部分本地模型参数至服务器端,以由服务器端更新全局模型,当全局模型未收敛时,发送所述全局模型梯度以及部分全局模型参数,并执行所述获取来自服务器端的全局模型梯度和部分全局模型参数,以得到初始参数。通过实施本发明专利技术实施例的方法可实现采用低时间成本获取性能良好的全局模型,缩短联邦学习整体训练时间。短联邦学习整体训练时间。短联邦学习整体训练时间。

【技术实现步骤摘要】
联邦学习模型训练方法、装置、计算机设备及存储介质


[0001]本专利技术涉及计算机,更具体地说是指联邦学习模型训练方法、装置、计算机设备及存储介质。

技术介绍

[0002]近年来机器学习、深度学习技术在计算机视觉、自然语言处理等领域得到了迅猛发展。特别是深度学习往往需要大量的训练数据才可以得到性能良好的深度学习模型。联邦学习是一种新的机器学习范式,其目的是保护数据隐私安全的同时解决“数据孤岛”问题,旨在让多个参与方共同训练机器学习模型,同时确保各参与方的本地数据分散化,即各参与方之间的数据不可互相访问。其中FedAvg是最常用的联邦学习算法框架,首先参与训练的客户端从服务器下载全局模型用于本地训练,其次客户端让本地模型在本地数据上进行多次迭代训练,再将本地模型的信息,如模型梯度上传至服务器,然后服务器将接收到的模型梯度加权平均后用于更新全局模型,再将新的全局模型信息发送至各客户端,最后重复上述过程,直至全局模型收敛或达到期望性能。
[0003]传统的联邦学习算法框架如FedAvg等,在本地训练中本地模型遍历本地数据至少一次,通常会遍历本地数据多次,客户端才会与服务器通进行通信并传递模型信息,这种方式会造成本地训练时间长,进而造成联邦学习整体训练时间长。特别是面向非独立同分布的训练数据场景时,即客户端间的训练数据是非独立同分布,不同客户端的本地训练数据分布与全局分布存在差异,本地模型目标的最优解与全局模型目标的最优解不一致,这种情况会阻碍联邦学习模型收敛,使其需要更多的通信轮次才能获得最优的全局模型,这种情况导致获得性能良好的全局模型会消耗更多的时间,也就是造成联邦学习整体训练时间长。
[0004]因此,有必要设计一种新的方法,实现采用低时间成本获取性能良好的全局模型,缩短联邦学习整体训练时间。

技术实现思路

[0005]本专利技术的目的在于克服现有技术的缺陷,提供联邦学习模型训练方法、装置、计算机设备及存储介质。
[0006]为实现上述目的,本专利技术采用以下技术方案:联邦学习模型训练方法,应用于一客户端,包括:
[0007]获取来自服务器端的全局模型梯度和部分全局模型参数,以得到初始参数;
[0008]利用所述初始参数更新本地模型;
[0009]随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数;
[0010]发送本地模型梯度和部分本地模型参数至服务器端,以由服务器端更新全局模型,当全局模型未收敛时,发送所述全局模型梯度以及部分全局模型参数,并执行所述获取
来自服务器端的全局模型梯度和部分全局模型参数,以得到初始参数。
[0011]其进一步技术方案为:所述随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,包括:
[0012]将样本数据划分为若干个部分样本数据,以得到若干组数据;
[0013]随机选取一组数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数。
[0014]其进一步技术方案为:所述发送本地模型梯度和部分本地模型参数至服务器,以由服务器更新全局模型,当全局模型未收敛时,发送所述全局模型梯度以及部分全局模型参数,并执行所述获取来自服务器的全局模型梯度和部分全局模型参数,以得到初始参数,包括:
[0015]发送本地模型梯度和部分本地模型参数至服务器,以由服务器更新全局模型,当全局模型未收敛时,将不同客户端上传的本地模型梯度和部分本地模型参数分别实施加权平均,并利用加权平均后的模型梯度和少量模型参数更新全局模型,并执行所述获取来自服务器的全局模型梯度和部分全局模型参数,以得到初始参数。
[0016]本专利技术还提供了联邦学习模型训练方法,应用于一服务器端,包括:
[0017]初始化全局模型;
[0018]发送所述全局模型梯度和部分全局模型参数至客户端,以使得客户端根据所述全局模型梯度和部分全局模型参数更新本地模型,并随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,上传本地模型梯度和部分本地模型参数至服务器端;
[0019]接收各个客户端上传的本地模型梯度和部分本地模型参数;
[0020]对各个客户端上传的本地模型梯度和部分本地模型参数分别进行加权求平均值,以得到加权平均结果;
[0021]利用所述加权平均结果更新全局模型;
[0022]判断所述全局模型是否收敛;
[0023]若所述全局模型未收敛,则执行发送所述全局模型梯度和部分全局模型参数至客户端,以使得客户端根据所述全局模型梯度和部分全局模型参数更新本地模型,并随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,上传本地模型梯度和部分本地模型参数至服务器端。
[0024]其进一步技术方案为:所述发送所述全局模型梯度和部分全局模型参数至客户端,以使得客户端根据所述全局模型梯度和部分全局模型参数更新本地模型,并随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,上传本地模型梯度和部分本地模型参数至服务器端,包括:
[0025]发送所述全局模型梯度和部分全局模型参数至客户端,以使得客户端根据所述全局模型梯度和部分全局模型参数更新本地模型,并将数据划分为若干个部分样本数据,以得到若干组数据;随机选取一组数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,上传本地模型梯度和部分本地模型参数至服务器端。
[0026]本专利技术还提供了联邦学习模型训练装置,包括用于执行上述方法的单元。
[0027]本专利技术还提供了一种计算机设备,所述计算机设备包括存储器以及与所述存储器
相连的处理器;所述存储器用于存储计算机程序;所述处理器用于运行所述存储器中存储的计算机程序,以执行上述方法的步骤。
[0028]本专利技术还提供了一种存储介质,所述存储介质存储有计算机程序,所述计算机程序包括程序指令,所述程序指令当被处理器执行时可实现上述方法的步骤。
[0029]本专利技术与现有技术相比的有益效果是:本专利技术通过获取来自服务器端的全局模型梯度和部分全局模型参数,并利用获取的内容更新本地模型,利用少量数据对本地模型进行迭代训练,并发送训练后的本地模型梯度和部分本地模型参数至服务器端,以由服务器端进行加权求均值,并更新全局模型,实现采用低时间成本获取性能良好的全局模型,缩短联邦学习整体训练时间;由于训练数据的减少,减少客户端的计算成本,有利于部署到计算能力弱的设备上;通过客户端与服务器端间传递模型梯度和少量模型参数,如未参与梯度下降优化,从数据统计中获得的参数,使得各客户端本地模型保持同步更新,所有客户端的本地模型完全一样,确保各客户端本地模型从同一起点开始训练,减轻非独立同分布训练数据的消极影响,进而缩短联邦学习整体训练时间。
[0030]下面结合附图和具体实施例对本专利技术作进一步描述。
附图说明
[0031]为了更清楚地说明本发本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.联邦学习模型训练方法,应用于一客户端,其特征在于,包括:获取来自服务器端的全局模型梯度和部分全局模型参数,以得到初始参数;利用所述初始参数更新本地模型;随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数;发送本地模型梯度和部分本地模型参数至服务器端,以由服务器端更新全局模型,当全局模型未收敛时,发送所述全局模型梯度以及部分全局模型参数,并执行所述获取来自服务器端的全局模型梯度和部分全局模型参数,以得到初始参数。2.根据权利要求1所述的联邦学习模型训练方法,其特征在于,所述随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,包括:将样本数据划分为若干个部分样本数据,以得到若干组数据;随机选取一组数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数。3.根据权利要求1所述的联邦学习模型训练方法,其特征在于,所述发送本地模型梯度和部分本地模型参数至服务器,以由服务器更新全局模型,当全局模型未收敛时,发送所述全局模型梯度以及部分全局模型参数,并执行所述获取来自服务器的全局模型梯度和部分全局模型参数,以得到初始参数,包括:发送本地模型梯度和部分本地模型参数至服务器,以由服务器更新全局模型,当全局模型未收敛时,将不同客户端上传的本地模型梯度和部分本地模型参数分别实施加权平均,并利用加权平均后的模型梯度和少量模型参数更新全局模型,并执行所述获取来自服务器的全局模型梯度和部分全局模型参数,以得到初始参数。4.联邦学习模型训练方法,应用于一服务器端,其特征在于,包括:初始化全局模型;发送所述全局模型梯度和部分全局模型参数至客户端,以使得客户端根据所述全局模型梯度和部分全局模型参数更新本地模型,并随机选取部分样本数据迭代训练所述本地模型,以得到本地模型梯度和部分本地模型参数,上传本地模型梯度和部分本地模型参数至服务器端;接收各个客户端上传的本地模型梯度和部分本地模型参数;对各个客户端上传的本地模型梯度和部分本地模型参数分别进行...

【专利技术属性】
技术研发人员:杜杰李炜刘鹏汪天富
申请(专利权)人:深圳大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1