当前位置: 首页 > 专利查询>河海大学专利>正文

一种基于联邦学习与多任务学习的模型训练方法技术

技术编号:27313625 阅读:73 留言:0更新日期:2021-02-10 09:41
本发明专利技术涉及一种基于联邦学习与多任务学习的模型训练方法,将目标神经网络中各个全连接层划为目标神经网络的后部模型,以及将其中剩余部分划为目标神经网络的前部模型;由参数服务器负责更新各目标神经网络的前部模型,交换网络中的各个工作节点终端共同负责各目标神经网络的后部模型,如此应用联邦学习框架针对目标神经网络进行训练,通过对几个任务同时训练共同或相关的网络层,使得多个任务互相促进训练准确性,提高了模型收敛速率与泛化能力,获得目标神经网络高效训练的同时,能够保证基础数据的私密性与安全性。证基础数据的私密性与安全性。证基础数据的私密性与安全性。

【技术实现步骤摘要】
一种基于联邦学习与多任务学习的模型训练方法


[0001]本专利技术涉及一种基于联邦学习与多任务学习的模型训练方法,属于数据处理


技术介绍

[0002]在机器学习领域,数据的收集处理是一大难点,随着移动设备与人的关系越来越紧密,移动设备中存在了大量的有价值的、隐私的数据。传统的数据处理模式往往是服务提供商收集用户的数据到集中在服务器,再通过自身服务器进行清洗处理,但随着相关法律的不断完善,这种方式可能会存在法律风险。
[0003]为了能够有效、安全的利用用户的数据,人们提出了联邦学习的方法,联邦学习模型可以使用户不需要上传自身数据的前提下,只需要上传用户本地训练后的梯度,就能有有效利用用户的数据进行训练,共同训练出一个统一的模型,一定程度上保护了用户的隐私安全。例如:这种框架可以用于解决安卓手机终端用户在本地更新模型的问题,其设计目标是在保障大数据交换时的信息安全、保护终端数据和个人数据隐私、保证合法合规的前提下,在多参与方或多计算结点之间开展高效率的机器学习。
[0004]但是目前的联邦学习由于要妥协带宽不均衡的现状,该架构由一台参数服务器与多台工作节点组成,同步通讯,每天只通过数次同步参数平均化进行同步更新迭代,更新频率低,在目前的联邦学习情况下,而且该架构存在慢节点等问题,整个系统训练效率的瓶颈,浪费了部分通讯资源不受限节点的带宽。这样又会导致更多的训练节点掉队,第二个是传统的联邦学习训练出一个模型去适配所有节点的数据,它并不是针对单个用户的最优解。第三个是有些任务的数据量较小不能得到准确、泛化性能好的模型,传统的联邦学习并没有针对多任务学习有所突破利用。

技术实现思路

[0005]本专利技术所要解决的技术问题是提供一种基于联邦学习与多任务学习的模型训练方法,针对目标神经网络进行前后划分,应用联邦学习框架获得模型高效训练的同时,能够保证基础数据的私密性与安全性。
[0006]本专利技术为了解决上述技术问题采用以下技术方案:本专利技术设计了一种基于联邦学习与多任务学习的模型训练方法,用于同步实现针对至少一个目标神经网络的参数化训练,并且各个目标神经网络彼此之间具有相同结构的全连接层;基于参数服务器、以及各个工作节点终端,按如下步骤A至步骤C,同步实现各目标神经网络的参数化训练;
[0007]步骤A.分别针对各目标神经网络,将其中各个全连接层划为目标神经网络的后部模型,以及将其中剩余部分划为目标神经网络的前部模型,然后进入步骤B;
[0008]步骤B.参数服务器根据各个工作节点终端的参数属性,构建由满足预设参数要求的各个工作节点终端所组成的交换网络,由参数服务器负责各目标神经网络的前部模型,交换网络中的各个工作节点终端共同负责各目标神经网络的后部模型,然后进入步骤C;
[0009]步骤C.参数服务器与交换网络中的各个工作节点终端,根据各目标模型分别所对应的样本训练数据,应用多任务学习模式,针对各目标神经网络进行参数化训练,获得训练后的各个目标神经网络。
[0010]作为本专利技术的一种优选技术方案:所述步骤B中,参数服务器通过执行参数平均化进程与网络列表管理进程,实现所述交换网络的构建、以及参数服务器对各目标神经网络前部模型的负责、与交换网络中各个工作节点终端对各目标神经网络后部模型的负责。
[0011]作为本专利技术的一种优选技术方案:所述参数服务器按如下步骤I1至步骤I10,执行参数平均化进程;
[0012]步骤I1.参数服务器接收来自各工作节点终端分别所发送的参数列表,参数列表包括工作节点终端以及其算力电量带宽同时开启接收加入交换网络申请监听线程、以及请求对交换网络列表监听线程,然后进入步骤I2;
[0013]步骤I2.参数服务器通过概率加权的方法,选择预设数量n个工作节点终端,并向该各个工作节点终端发送确认信息,然后进入步骤I3;
[0014]步骤I3.参数服务器接收来自其在步骤I2中所交互各个工作节点终端的确认信息,其中,若超时则参数服务器重发确认信息;由此构建由满足预设参数要求的各个工作节点终端所组成的交换网络,然后进入步骤I4;
[0015]步骤I4.参数服务器将各目标神经网络后部模型的参数分发给交换网络中的各个工作节点终端,然后进入步骤I5;
[0016]步骤I5.参数服务器开启针对交换网络中各个工作节点终端所接收各目标神经网络中前部模型参数的监听,然后进入步骤I6;
[0017]步骤I6.参数服务器初始化接收列表为空列表,然后进入步骤I7;
[0018]步骤I7.参数服务器接收交换网络中各个工作节点终端所发送各目标神经网络中前部模型的参数,并应用接收列表记录所接收的各个工作节点终端,然后进入步骤I8;
[0019]步骤I8.参数服务器判断所接收工作节点终端的个数是否大于是则进入步骤I10;否则进入步骤I9;
[0020]步骤I9.判断参数服务器接收各工作节点终端发送各目标神经网络中前部模型参数的时长是否超时是则返回步骤I2;否则返回步骤I7;
[0021]步骤I10.参数服务器针对所接收交换网络中各个工作节点终端发送的各目标神经网络中前部模型参数,使用平均化计算前部模型参数,并将平均化后的前部模型参数分发给交换网络中的各个工作节点终端。
[0022]作为本专利技术的一种优选技术方案:所述参数服务器按如下步骤II1至步骤II7,执行网络列表管理进程;
[0023]步骤II1.参数服务器初始化节点交换字典并进入步骤II2;
[0024]步骤II2.参数服务器开启申请加入交换网络的监听线程,然后进入步骤II3;
[0025]步骤II3.参数服务器判断是否接收到申请加入交换网络的消息,是则进入步骤II4;否则继续执行步骤II3;
[0026]步骤II4.参数服务器分别针对将所接收各个申请加入交换网络消息,将申请加入交换网络消息所对应工作节点终端的IP地址作为key加入节点交换字典中,并初始化
该key所对应的value值为然后进入步骤II5;
[0027]步骤II5.参数服务器将节点交换字典返回给其所接收各个申请加入交换网络消息对应的各个工作节点终端,并确定所接收各个申请加入交换网络消息对应工作节点终端的数量n,然后进入步骤II6;
[0028]步骤II6.参数服务器针对节点交换字典中的各个value的值,分别执行自减(1/n)进行更新,然后进入步骤II7;
[0029]步骤II7.参数服务器删除节点交换字典中value值为0的key-value,然后返回步骤II3。
[0030]作为本专利技术的一种优选技术方案:所述步骤B至步骤C的执行过程中,交换网络中各个工作节点终端分别执行发送模型参数到其他工作节点终端的进程、以及接收其他工作节点终端所发送模型参数的进程。
[0031]作为本专利技术的一种优选技术方案:所述交换网络中各个工作节点终端分别按如下步骤III1至步骤III23,执行发送模型参数到其他工作节点终端的进程;<本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于联邦学习与多任务学习的模型训练方法,用于同步实现针对至少一个目标神经网络的参数化训练,并且各个目标神经网络彼此之间具有相同结构的全连接层;其特征在于:基于参数服务器、以及各个工作节点终端,按如下步骤A至步骤C,同步实现各目标神经网络的参数化训练;步骤A.分别针对各目标神经网络,将其中各个全连接层划为目标神经网络的后部模型,以及将其中剩余部分划为目标神经网络的前部模型,然后进入步骤B;步骤B.参数服务器根据各个工作节点终端的参数属性,构建由满足预设参数要求的各个工作节点终端所组成的交换网络,由参数服务器负责各目标神经网络的前部模型,交换网络中的各个工作节点终端共同负责各目标神经网络的后部模型,然后进入步骤C;步骤C.参数服务器与交换网络中的各个工作节点终端,根据各目标模型分别所对应的样本训练数据,应用多任务学习模式,针对各目标神经网络进行参数化训练,获得训练后的各个目标神经网络。2.根据权利要求1所述一种基于联邦学习与多任务学习的模型训练方法,其特征在于:所述步骤B中,参数服务器通过执行参数平均化进程与网络列表管理进程,实现所述交换网络的构建、以及参数服务器对各目标神经网络前部模型的负责、与交换网络中各个工作节点终端对各目标神经网络后部模型的负责。3.根据权利要求2所述一种基于联邦学习与多任务学习的模型训练方法,其特征在于:所述参数服务器按如下步骤I1至步骤I10,执行参数平均化进程;步骤I1.参数服务器接收来自各工作节点终端分别所发送的参数列表,参数列表包括工作节点终端以及其算力电量带宽同时开启接收加入交换网络申请监听线程、以及请求对交换网络列表监听线程,然后进入步骤I2;步骤I2.参数服务器通过根据概率加权的方法,选择预设数量n个工作节点终端,并向该各个工作节点终端发送确认信息,然后进入步骤I3;步骤I3.参数服务器接收来自其在步骤I2中所交互各个工作节点终端的确认信息,其中,若超时则参数服务器重发确认信息;由此构建由满足预设参数要求的各个工作节点终端所组成的交换网络,然后进入步骤I4;步骤I4.参数服务器将各目标神经网络后部模型的参数分发给交换网络中的各个工作节点终端,然后进入步骤I5;步骤I5.参数服务器开启针对交换网络中各个工作节点终端所接收各目标神经网络中前部模型参数的监听,然后进入步骤I6;步骤I6.参数服务器初始化接收列表为空列表,然后进入步骤I7;步骤I7.参数服务器接收交换网络中各个工作节点终端所发送各目标神经网络中前部模型的参数,并应用接收列表记录所接收的各个工作节点终端,然后进入步骤I8;步骤I8.参数服务器判断所接收工作节点终端的个数是否大于是则进入步骤I10;否则进入步骤I9;步骤I9.判断参数服务器接收各工作节点终端发送各目标神经网络中前部模型参数的时长是否超时是则返回步骤I2;否则返回步骤I7;步骤I10.参数服务器针对所接收交换网络中各个工作节点终端发送的各目标神经网
络中前部模型参数,使用平均化计算前部模型参数,并将平均化后的前部模型参数分发给交换网络中的各个工作节点终端。4.根据权利要求2所述一种基于联邦学习与多任务学习的模型训练方法,其特征在于:所述参数服务器按如下步骤II1至步骤II7,执行网络列表管理进程;步骤II1.参数服务器初始化空节点交换字典并进入步骤II2;步骤II2.参数服务器开启申请加入交换网络的监听线程,然后进入步骤II3;步骤II3.参数服务器判断是否接收到申请加入交换网络的消息,是则进入步骤II4;否则继续执行步骤II3;步骤II4.参数服务器分别针对将所接收各个申请加入交换网络消息,将申请加入交换网络消息所对应工作节点终端的IP地址作为key加入节点交换字典中,并初始化该key所对应的value值为然后进入步骤II5;步骤II5.参数服务器将节点交换字典返回给其所接收各个申请加入交换网络消息对应的各个工作节点终端,并确定所接收各个申请加入交换网络消息对应工作节点终端的数量n,然后进入步骤II6;步骤II6.参数服务器针对节点交换字典中的各个value的值,分别执行自减(1/n)进行更新,然后进入步骤II7;步骤II7.参数服务器删除节点交换字典中value值为0的key-value,然后返回步骤II3。5.根据权利要求1所述一种基于联邦学习与多任务学习的模型训练方法,其特征在于:所述步骤B至步骤C的执行过程中,交换网络中各个工作节点终端分别执行发送模型参数到其他工作节点终端的进程、以及接收其他工作节点终端所发送模型参数的进程。6.根据权利要求5所述一种基于联邦学习与多任务学习的模型训练方法,其特征在于:所述交换网络中各个工作节点终端分别按如下步骤III1至步骤III23,执行发...

【专利技术属性】
技术研发人员:谢在鹏陈瑞锋叶保留朱晓瑞屈志昊徐媛媛
申请(专利权)人:河海大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1