一种联邦学习方法、装置及电子设备制造方法及图纸

技术编号:37440753 阅读:24 留言:0更新日期:2023-05-06 09:12
本申请公开了一种联邦学习方法、装置及电子设备。方法包括:向各客户端发送元模型及初始模型参数,以使各客户端基于所述元模型、初始模型参数、客户本地的训练数据以及目标补丁模型进行第一轮模型训练;从各客户端中确定出若干目标客户端,并采集各目标客户端当前第n轮训练获得的第一模型参数;基于各第一模型参数,计算获得用于进行n+1轮模型训练的第二模型参数;将第二模型参数发送给各目标客户端,以使各目标客户端基于元模型、第二模型参数、本地的训练数据及目标补丁模型进行模型再训练获得当前第二模型;直至各各客户端训练获得的当前第二模型均符合训练条件,否则重新确定出若干目标客户端。本申请能够避免训练获得模型发生偏移。型发生偏移。型发生偏移。

【技术实现步骤摘要】
一种联邦学习方法、装置及电子设备


[0001]本专利技术涉及计算机
,特别涉及一种联邦学习方法、装置及电子设备。

技术介绍

[0002]联邦学习本质上是一种分布式机器学习框架,其核心思想是通过在多个拥有本地数据的数据源之间进行分布式模型训练,在不需要交换本地个体或样本数据的前提下,仅通过交互模型中间参数进行模型联合训练,原始数据可以不出本地。
[0003]然而,现有的联邦学习方法中,虽然保证了明文数据不出本地,但存在部分客户端的模型发生偏移、导致本地模拟效果较差的问题。

技术实现思路

[0004]有鉴于此,本专利技术提供了一种联邦学习方法、装置及电子设备,主要目的在于解决目前存在现有的联邦学习方法中容易造成客户端训练获得模型发生偏移、进而导致本地模拟效果较差的问题。
[0005]为解决上述问题,本申请提供一种联邦学习方法,包括:向各客户端发送元模型以及初始模型参数,以使各客户端基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;从各所述客户端中确定出用于进行模型再训练的、若干目标客户端,并采集各目标客户端当前第n轮训练获得的第一模型的第一模型参数,所述n为正整数;基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数;将所述第二模型参数发送给各所述目标客户端,以使各所述目标客户端基于已接收的元模型、第二模型参数、客户端本地的训练数据以及目标补丁模型进行模型再训练获得当前第二模型;判断各客户端训练获得的当前第二模型是否均符合训练条件,在各客户端训练获得的当前第二模型不符合训练条件时,重新确定出用于进行模型再训练的、若干目标客户端;在各客户端训练获得第二模型符合训练条件时停止训练。
[0006]可选的,在向各客户端发送元模型以及初始模型参数之前,所述方法还包括:接收各客户端发送的各数据类型以及与各数据类型对应的各训练数据的数据标识;基于各客户端发送的数据类型以及数据标识的数量,计算获得各数据类型之间数据标识总量的数据占比;在进行每一轮模型训练之前,基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识,并将各所述目标数据标识发送给对应的客户端,以为各客户端重新分配用于进行模型训练的训练数据。
[0007]可选的,所述基于所述数据占比以及各客户端所包含的目标数据类型,从各目标
数据类型对应的若干数据标识中确定若干目标数据标识,具体包括:基于所述数据占比以及各客户端所包含的目标数据类型,分别确定与各客户端对应的目标数据占比;基于各客户端的目标数据占比,从对应客户端所发送的各数据类型的数据标识中确定出若干目标数据标识。
[0008]可选的,所述基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数,具体包括:基于各目标客户端第n轮训练获得的第一模型参数以及各目标客户端第n

1轮训练获得的历史模型参数,确定各目标客户端对应的梯度参数;基于各目标客户端的梯度参数,采用预定的计算公式计算获得目标梯度参数;基于所述目标梯度参数确定用于进行n+1轮模型训练的第二模型参数。
[0009]可选的,所述方法还包括:基于模型训练任务的任务类型、服务器数据与客户端数据的数据差异度以及服务器元模型的结构,确定映射补丁模型、残差补丁模型、内部补丁模型中的任意一种为所述目标补丁模型。
[0010]可选的,所述基于模型训练任务的任务类型、服务器数据与客户端数据的数据差异度以及服务器元模型的结构复杂度,确定映射补丁模型、残差补丁模型、内部补丁模型中的任意一种为所述目标补丁模型,具体包括:在所述任务类型为监控任务或定位任务时,确定所述映射补丁模型为所述目标补丁模型;在所述服务器数据与客户端数据的数据差异度大于预定差异度阈值时,确定所述残差补丁模型为所述目标补丁模型;在服务器元模型的结构复杂度大于预定复杂度时,确定所述内部补丁模型为所述目标补丁模型。
[0011]可选的,所述映射补丁模型包括:映射网络以及激活层;所述残差补丁模型包括:残差连接层;所述内部补丁模型包括:卷积层以及激活层。
[0012]为解决上述问题,本申请提供一种联邦学习方法,应用于各客户端,包括:接收服务端发送的元模型、初始模型参数,以基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;接收服务端发送的第二模型参数,所述第二模型参数是由服务端基于若干目标客户端第n轮训练获得的第一模型参数所计算获得的;基于服务端发送的所述第二模型参数、已接收的元模型、客户端本地的训练数据以及目标补丁模型进行第n+1轮模型训练。
[0013]可选的,在接收服务端发送的元模型、初始模型参数之前,所述方法还包括:将各数据类型以及与各数据类型对应的各训练数据的数据标识发送给服务端,以使服务端基于各客户端发送的数据类型以及数据标识的数量,计算获得各数据类型之间数据标识总量的数据占比;并使服务端在进行每一轮模型训练之前,基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识;接收服务端发送的各目标数据标识,基于各所述目标数据标识从本地的训练数据
中确定与目标数据标识对应的目标训练数据,以基于重新分配获得的所述目标训练数据进行模型训练。
[0014]可选的,所述目标补丁模型包括如下任意一种:映射补丁模型、残差补丁模型、内部补丁模型;其中,所述映射补丁模型包括:映射网络以及激活层;所述残差补丁模型包括:残差连接层;所述内部补丁模型包括:卷积层以及激活层。
[0015]为解决上述问题,本申请提供一种联邦学习装置,包括:第一发送模块,用于向各客户端发送元模型以及初始模型参数,以使各客户端基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;采集模块,用于从各所述客户端中确定出用于进行模型再训练的、若干目标客户端,并采集各目标客户端当前第n轮训练获得的第一模型的第一模型参数,所述n为正整数;计算模块,用于基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数;第二发送模块,用于将所述第二模型参数发送给各所述目标客户端,以使各所述目标客户端基于已接收的元模型、第二模型参数、客户端本地的训练数据以及目标补丁模型进行模型再训练,获得当前第二模型;判断模块,用于判断各客户端训练获得的当前第二模型是否均符合训练条件,在各客户端训练获得的当前第二模型不符合训练条件时,基于所述采集模块重新确定出用于进行模型再训练的、若干目标客户端;在各客户端训练获得第二模型符合训练条件时停止训练。
[0016]为解决上述问题,本申请提供一种联邦学习装置,包括:接收模块以及模型训练模块;所述接收模块用于,接收服务端发送的元模型、初始模型参数,以及用于接收服务端发送的第二模型参数,所述第二模型参数是由服务端基于若干目标客户本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种联邦学习方法,应用于服务端,其特征在于,包括:向各客户端发送元模型以及初始模型参数,以使各客户端基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;从各所述客户端中确定出用于进行模型再训练的、若干目标客户端,并采集各目标客户端当前第n轮训练获得的第一模型的第一模型参数,所述n为正整数;基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数;将所述第二模型参数发送给各所述目标客户端,以使各所述目标客户端基于已接收的元模型、第二模型参数、客户端本地的训练数据以及目标补丁模型进行模型再训练获得当前第二模型;判断各客户端训练获得的当前第二模型是否均符合训练条件,在各客户端训练获得的当前第二模型不符合训练条件时,重新确定出用于进行模型再训练的、若干目标客户端;在各客户端训练获得第二模型符合训练条件时停止训练。2.如权利要求1所述的方法,其特征在于,在向各客户端发送元模型以及初始模型参数之前,所述方法还包括:接收各客户端发送的各数据类型以及与各数据类型对应的各训练数据的数据标识;基于各客户端发送的数据类型以及数据标识的数量,计算获得各数据类型之间数据标识总量的数据占比;在进行每一轮模型训练之前,基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识,并将各所述目标数据标识发送给对应的客户端,以为各客户端重新分配用于进行模型训练的训练数据。3.如权利要求2所述的方法,其特征在于,所述基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识,具体包括:基于所述数据占比以及各客户端所包含的目标数据类型,分别确定与各客户端对应的目标数据占比;基于各客户端的目标数据占比,从对应客户端所发送的各数据类型的数据标识中确定出若干目标数据标识。4.如权利要求1所述的方法,其特征在于,所述基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数,具体包括:基于各目标客户端第n轮训练获得的第一模型参数以及各目标客户端第n

1轮训练获得的历史模型参数,确定各目标客户端对应的梯度参数;基于各目标客户端的梯度参数,采用预定的计算公式计算获得目标梯度参数;基于所述目标梯度参数确定用于进行n+1轮模型训练的第二模型参数。5.如权利要求1所述的方法,其特征在于,所述方法还包括:基于模型训练任务的任务类型、服务器数据与客户端数据的数据差异度以及服务器元模型的结构,确定映射补丁模型、残差补丁模型、内部补丁模型中的任意一种为所述目标补丁模型。6.如权利要求5所述的方法,其特征在于,所述基于模型训练任务的任务类型、服务器数据与客户端数据的数据差异度以及服务器元模型的结构复杂度,确定映射补丁模型、残
差补丁模型、内部补丁模型中的任意一种为所述目标补丁模型,具体包括:在所述任务类型为监控任务或定位任务时,确定所述映射补丁模型为所述目标补丁模型;在所述服务器数据与客户端数据的数据差异度大于预定差异度阈值时,确定所述残差补丁模型为所述目标补丁模型;在服务器元模型的结构复杂度大于预定复杂度时,确定所述内部补丁模型为所述目标补丁模型。7.如权利要求5所述的方法,其特征在于,所述映射补丁模型包括:映射网络以及激活层;所述残差补丁模型包括:残差连...

【专利技术属性】
技术研发人员:谢翀陈永红兰鹏罗伟杰赵豫陕
申请(专利权)人:深圳前海环融联易信息科技服务有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1