一种多服务器加速分割学习模型训练速度的方法技术

技术编号:38733987 阅读:14 留言:0更新日期:2023-09-08 23:22
本发明专利技术提出一种多服务器加速分割学习模型训练速度的方法,属于分布式深度学习领域;具体为:每个客户端将数据划分为数据1与数据2;客户端a使用其数据1与服务器s1进行训练,并更新模型给客户端b,s1更新其模型;客户端b使用其数据1与s1按上述方法训练,同时,客户端a使用其数据2与服务器s2按上述方法训练;客户端c继续按上述方法训练,同时,客户端b将得到的两个客户端模型聚合,基于此使用其数据2与s2继续训练。持续上述训练过程直至本轮训练结束,之后服务器通过客户端的聚合转发得到全局服务器模型,客户端通过模型转发得到最新模型,用于下一轮训练;直至完成实验设定的训练轮数。本发明专利技术加速了训练速度同时保护了模型的隐私。隐私。隐私。

【技术实现步骤摘要】
一种多服务器加速分割学习模型训练速度的方法


[0001]本专利技术属于分布式深度学习领域,涉及分割学习的模型效率和隐私安全,具体是一种多服务器加速分割学习模型训练速度的方法。

技术介绍

[0002]分割学习模型的特点之一是模型的更新方式会导致多个客户端顺序执行。
[0003]在传统多用户分割学习系统中,同一时间只有一个客户端与服务器交互进行训练;与联邦学习中多客户端并行训练的模型相比,会导致模型整体运行效率大大降低。
[0004]目前针对分割学习运行效率问题的解决方法主要是通过服务器聚合、客户端局部并行或集成学习等。如文献1:Thapa C,Arachchige P C M,Camtepe S,et al.Splitfed:When federated learning meets split learning[J].arXiv preprint arXiv:2004.12088,2020.提出用单独的聚合服务器来聚合客户端侧模型,其本质是一个分割学习与联邦学习结合起来的框架,称为SplitFed。
[0005]在该框架中,客户端和主服务器之间通过模型的分割共同执行训练过程,另外还引入了Fed服务器来对客户端的模型进行聚合操作,以此来结合分割学习和联邦学习的优势之处。通过引入分割学习对完整的模型进行分割,使部分算力不足的客户端能够参与训练,同时也有着更好的模型隐私性。通过引入Fed服务器对客户端模型进行聚合,使客户端部分可以并行训练,保留了联邦学习的优势。但是,该框架违背了分割学习的主要目标,即避免服务器与客户端之间的模型共享。
[0006]文献2:Gao Y,Kim M,Thapa C,et al.Evaluation and optimization of distributed machine learning techniques for internet of things[J].IEEE Transactions on Computers,2021,71(10):2538

2552.针对SplitFed提出了更为通用的框架generalized splitfed learning(SFLG);SFLG是在SplitFed的SFLV2的基础之上做出了改进。
[0007]其主要改进是在主服务器端进行分组,规定一个组接收来自几个客户端的数据并对它们进行训练,在每个组内部训练是顺序进行的,然后组与组之间并行地进行模型聚合,以此来提升训练速度。同SplitFed框架一样,该方法依然使用服务器接收了客户端的模型,不利于模型的隐私安全。
[0008]文献3:Jeon J,Kim J.Privacy

sensitive parallel split learning[A].//2020International Conference on Information Networking(ICOIN)[C].IEEE,2020:7

9.)认为分割学习顺序训练的方式容易产生过拟合现象,因此提出一种并行训练的框架,来防止因各客户端训练顺序和数据集大小的差异而产生的过拟合现象。
[0009]该方法将客户端本地数据拆分成多个小批量数据,其中每个批量大小都与其本地数据集大小成正比。所有客户端先同步初始化模型,在每一个轮次的训练中,每个客户端都在一个小批量的数据上训练并前向传播至分割层,将分割层输出交给服务器。服务器进行前向传播和反向传播,每个客户端也根据服务器传递来的损失值进行反向传播,得到本地
模型的梯度更新。然后每个客户端都将本地梯度更新再发送给服务器,服务器对收到的梯度进行聚合,并将聚合完毕的梯度信息再发送给客户端进行模型更新。
[0010]经过这样一个过程,所有客户端的模型数据将是相同的。但是,该方法依然使用了服务器聚合客户端模型的方法,会对模型的隐私安全造成一定威胁。

技术实现思路

[0011]针对上述问题,本专利技术提出了一种多服务器加速分割学习模型训练速度的方法,首先对客户端数据进行分割,将客户端不同的数据交给不同的服务器进行训练,使模型在保障隐私的情况下加快了训练速度,达到了提升收敛速度的技术效果。
[0012]所述多服务器加速分割学习模型训练速度的方法,具体步骤如下:
[0013]步骤一、搭建包含客户端和服务器的通信场景,每个客户端将各自的数据随机划分为第一部分数据和第二部分数据;
[0014]服务器的数量与客户端的数据份数相同;
[0015]步骤二、针对第一个客户端a,首先使用其第一部分数据a_1在本地客户端初始模型W
C
上进行前向传播训练,得到分割层的输出数据发送给第一个服务器s1。
[0016]步骤三、第一个服务器s1在本地服务器初始模型W
S
上对分割层输出数据继续训练,并反向传播,将返回给分割层的梯度发送给客户端a;
[0017]步骤四、客户端a基于返回的梯度继续进行反向传播更新其初始模型W
C
为并发送给下一个客户端b。
[0018]同时,服务器s1更新其初始模型W
S

[0019]步骤五、客户端b在得到模型后,使用其第一部分数据b_1进行前向传播训练,得到分割层输出数据发送给服务器s1;同时,客户端a的模型使用其第二部分数据a_2进行前向传播训练,得到分割层输出数据发送给服务器s2。
[0020]步骤六、服务器s1使用客户端b发送来的分割层输出数据在本地服务器模型上进行前向传播和反向传播,将梯度发送给客户端b,客户端b继续训练得到更新后的模型并将此模型发送给客户端c,服务器s1继续将其所拥有的模型进行更新。同时,服务器s2与客户端a的反向传播,训练完成后客户端a得到模型并将其发送给客户端b,服务器s2更新其模型
[0021]步骤七、客户端c使用数据c_1在模型上与服务器s1交互训练,客户端b将得到的两份客户端模型:模型与模型进行聚合,在聚合完成的模型上使用数据b_2与服务器s2继续进行训练,得到模型并将此模型发送给下一个客户端。
[0022]步骤八、持续上述训练过程直至服务器s1将模型更新过K次,服务器s2将模型更新过K次,此时两个服务器将各自的模型发送给客户端,由客户端将服务器模型进行聚合后便得到了全局服务器模型W
S
,再转发给两个服务器用于后续下一轮的训练。最后一个客户端也将其最终的模型发送给客户端a,用于后续下一轮的训练。
[0023]K的取值与客户端数量保持一致。
[0024]步骤九、直至完成实验设定的训练轮数,最终加速了分割学习模型的训练速度。
[0025]本专利技术的优点在于:
[0026]一种多服务器加速分割学习模型训练速度的方法,在避免服务器直接接触客户端模型的情况下提高模型的收敛效率;同时,除了加速训练速度外,相比于现有其他提高训练速度的方法,本专利技术在一定程度上保护了模型的隐私。本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种多服务器加速分割学习模型训练速度的方法,其特征在于,具体步骤如下:步骤一、搭建包含客户端和服务器的通信场景,每个客户端将各自的数据随机划分为第一部分数据和第二部分数据;步骤二、针对第一个客户端a,首先使用其第一部分数据a_1在本地客户端初始模型W
C
上进行前向传播训练,得到分割层的输出数据发送给第一个服务器s1;步骤三、第一个服务器s1在本地服务器初始模型W
S
上对分割层输出数据继续训练,并反向传播,将返回给分割层的梯度发送给客户端a;步骤四、客户端a基于返回的梯度继续进行反向传播更新其初始模型W
C
为并发送给下一个客户端b;同时,服务器s1更新其初始模型W
S
为步骤五、客户端b在得到模型后,使用其第一部分数据b_1进行前向传播训练,得到分割层输出数据发送给服务器s1;同时,客户端a的模型使用其第二部分数据a_2进行前向传播训练,得到分割层输出数据发送给服务器s2;步骤六、服务器s1使用客户端b发送来的分割层输出数据在本地服务器模型上进行前向传播和反向传播,将梯度...

【专利技术属性】
技术研发人员:芦效峰李颖慧闫彩虹
申请(专利权)人:北京邮电大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1