一种横向联邦学习方法、装置及存储介质制造方法及图纸

技术编号:34125831 阅读:11 留言:0更新日期:2022-07-14 14:11
本发明专利技术提供一种横向联邦学习方法、装置及存储介质。其中方法主要包括:中心服务器选择初始联邦学习模型和参数并下发给各个客户端;各客户端分别基于各自本地训练数据同时开始模型训练,进而得到本地的学习模型;各客户端之间采用循环通信方式将本地的学习模型发送至下一个客户端进行训练,所有客户端完成训练后更新循环通信顺序,继续传递模型并训练直至达到设定的训练次数;各客户端将最后训练的模型发送至中心服务器进行聚合,而后使用聚合结果更新联邦学习模型直至损失函数收敛,训练完成。本发明专利技术在训练传递模型或者梯度参数时,首先进行环型模型传播,再进行星型传播到服务器,从而构建了一种新的横向联邦学习模型训练架构。架构。架构。

【技术实现步骤摘要】
一种横向联邦学习方法、装置及存储介质


[0001]本专利技术涉及人工智能
,具体而言,尤其涉及一种横向联邦学习方法、装置及存储介质。

技术介绍

[0002]随着诸如电话,传感器和可穿戴设备之类的计算基板的功能和普及度的增长,与将数据移至数据中心相比,在分布式设备的网络中本地学习统计模型越来越有吸引力。这种技术被称为联邦学习。联邦学习定义如下:在进行机器学习的过程中,各参与方可借助其他方数据进行联合建模,各方无需共享数据资源,即数据不出本地的情况下,进行联合数据训练,建立共享的机器学习模型,而横向联邦学习则是在各个用户数据类型相似但是用户群体不同的情况下所进行的联邦学习算法。
[0003]联邦学习主要分为横向联邦学习和纵向联邦学习。在多个数据集的用户特征重叠较多而用户重叠较少的情況下,我们把数据集按照横向(即用户维度)切分,并取出多方用户特征相同而用户不完全相同的那部分数据进行训练,这种方法叫做横向联邦学习;在多个数据集的用户重叠较多而用户特征重疊较少的情况下,我们把数据集按照纵向(即特征维度)切分,并取出多方用户相同而用户特征不完全相同的那部分数据进行训练,这种方法叫做纵向联邦学习。
[0004]横向联邦学习的应用场景是用户不同但是用户数据结构相同,例如不同地区的医院有着不同的患者群体,但是其CT图像等则有着相同的数据形式。横向联邦学习通过使用模型或者梯度传输来达到不交换原始数据进行训练的目的。
[0005]目前横向联邦学习的训练架构有两种,一种是星型架构,在联邦网络中将多方用户参数汇总到中央服务器进行聚合。星型架构中多个用户并发运行,计算机算力利用率较高,但是在对于与no

iid数据上的分布式运算上,鲁棒性较差,常常会因为各模型权重的不同导致聚合模型精度大幅下降。另一种是环状架构,各用户形成环状结构,互相传递参数,使用从上家用户得到的模型参数来进行训练。环形结构有着较强的训练鲁棒性,在不同的数据集上训练时可以获得更强的适应性,但是在某一用户进行训练时,剩余的所有用户均需要进行等待,训练效率极低,在整个环形网络中只有一台机器的算力在被利用,导致训练时间太长。

技术实现思路

[0006]鉴于现有技术的不足,本专利技术提供一种横向联邦学习方法、装置及存储介质。本专利技术主要在横向联邦学习中融合使用星型架构和环型架构来进行训练。由于本方法融合了增量学习的训练方法,与星型训练结构相比提升了训练的鲁棒性;而因为本专利技术特殊的并行循环训练机制,在训练过程中也大大减少了环形训练架构的计算机算力闲置问题。相比于传统参数计算方法,本算法在应对非独立同分布(IID)数据的情况时可以取得较好的效果。
[0007]本专利技术采用的技术手段如下:
[0008]一种横向联邦学习方法,包括:
[0009]S1、中心服务器根据训练任务以及数据样本类型选择初始联邦学习模型和参数,并且将所述初始联邦学习模型和权重参数下发给所有参与联邦学习的各个客户端;
[0010]S2、各客户端分别基于各自本地训练数据同时开始模型训练,进而得到本地的学习模型;
[0011]S3、各客户端之间采用循环通信方式将本地的学习模型发送至下一个客户端,并基于所述下一个客户端的本地训练数据进行训练,所有客户端完成训练后更新循环通信顺序,继续传递模型并训练直至达到设定的训练次数;
[0012]S4、各客户端将最后训练的模型发送至中心服务器进行聚合,而后使用聚合结果更新联邦学习模型,反复执行S2

S4直至损失函数收敛,训练完成。
[0013]进一步地,S3、各客户端之间采用循环通信方式将本地的学习模型发送至下一个客户端,并基于所述下一个客户端的本地训练数据进行训练,所有客户端完成训练后更新循环通信顺序,继续传递模型并训练直至达到设定的训练次数,包括:
[0014]S301、中心服务器随机生成链表s,用于存储各客户端的训练顺序,将所述链表s下发到各客户端;
[0015]S302、各客户端按照所述链表s中的顺序将本地的学习模型发送至下一个客户端,链表s中的最后一个客户端将其本地模型发送给链表s中的第一个客户端;
[0016]S303、各客户端基于本地的训练数据对接收到的模型继续训练;
[0017]S304、中心服务器随机更新链表s,并将更新后的链表s下发到各客户端;
[0018]S305、反复执行S302

S304,直到到达设定的训练次数。
[0019]进一步地,S3中所述设定的训练次数为客户端个数的整数倍。
[0020]进一步地,S1中中心服务器首次下发的权重参数为随机权重。
[0021]本专利技术还提供了一种横向联邦学习装置,用于实现上述任意一项所述的横向联邦学习方法,包括:
[0022]初始化模块,用于通过中心服务器根据训练任务以及数据样本类型选择初始联邦学习模型和参数,并且将所述初始联邦学习模型和权重参数下发给所有参与联邦学习的各个客户端;
[0023]初始训练模块,用于通过各客户端基于各自本地训练数据同时开始模型训练,进而得到本地的学习模型;
[0024]循环训练模块,用于通过各客户端之间的循环通信方式将本地的学习模型发送至下一个客户端,并基于所述下一个客户端的本地训练数据进行训练,所有客户端完成训练后更新循环通信顺序,继续传递模型并训练直至达到设定的训练次数;
[0025]模型聚合模块,用于通过各客户端将最后训练的模型发送至中心服务器进行聚合,而后使用聚合结果更新联邦学习模型,反复执行循环训练过程直至损失函数收敛,训练完成。
[0026]本专利技术还提供了一种存储介质,所述存储介质包括存储的程序,其中,所述程序运行时,执行上述任一项权利要求所述的横向联邦学习方法。
[0027]较现有技术相比,本专利技术具有以下优点:
[0028]本专利技术提供了一种横向联邦学习模型训练架构,更新了横向联邦模型传递模式。
具体在训练传递模型或者梯度参数时,首先进行环型模型传播,再进行星型传播到服务器,服务器端进行模型聚合从而得到联邦学习全局模型,完成一轮训练。之后循环进行训练,直到模型收敛。本算法受到了增量学习的启发,在各个客户端中同步循环进行模型训练。通过同步开始的训练模式,充分利用了所有用户的本地算力,避免环状连接上所有客户端都在等待正在训练的客户端的情况所造成的算力浪费。在处理非IID数据时,与传统算法比较可获得更好更稳定的效果。
附图说明
[0029]为了更清楚地说明本专利技术实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图做以简单地介绍,显而易见地,下面描述中的附图是本专利技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0030]图1为本专利技术横向联邦学习训练架构。
[0031]图2为本专利技术横向联邦学习方法流程示意图。
[0032]图3为本发本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种横向联邦学习方法,其特征在于,包括:S1、中心服务器根据训练任务以及数据样本类型选择初始联邦学习模型和参数,并且将所述初始联邦学习模型和权重参数下发给所有参与联邦学习的各个客户端;S2、各客户端分别基于各自本地训练数据同时开始模型训练,进而得到本地的学习模型;S3、各客户端之间采用循环通信方式将本地的学习模型发送至下一个客户端,并基于所述下一个客户端的本地训练数据进行训练,所有客户端完成训练后更新循环通信顺序,继续传递模型并训练直至达到设定的训练次数;S4、各客户端将最后训练的模型发送至中心服务器进行聚合,而后使用聚合结果更新联邦学习模型,反复执行S2

S4直至损失函数收敛,训练完成。2.根据权利要求1所述的一种横向联邦学习方法,其特征在于,S3、各客户端之间采用循环通信方式将本地的学习模型发送至下一个客户端,并基于所述下一个客户端的本地训练数据进行训练,所有客户端完成训练后更新循环通信顺序,继续传递模型并训练直至达到设定的训练次数,包括:S301、中心服务器随机生成链表s,用于存储各客户端的训练顺序,将所述链表s下发到各客户端;S302、各客户端按照所述链表s中的顺序将本地的学习模型发送至下一个客户端,链表s中的最后一个客户端将其本地模型发送给链表s中的第一个客户端;S303、各客户端基于本地的训练数据对接收到的模型继续训练;S304、中心服务器随机更新链表s,并将更...

【专利技术属性】
技术研发人员:申岩王湾湾黄一珉何浩刘航付海燕郭艳卿
申请(专利权)人:深圳市洞见智慧科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1