一种基于联邦学习模型的训练方法技术

技术编号:39420550 阅读:15 留言:0更新日期:2023-11-19 16:09
本发明专利技术属于联邦学习领域,提供了一种基于联邦学习模型的训练方法,包括以下步骤:S11,定义问题:确定需要解决的机器学习问题、本地数据的来源、以及参与联邦学习的设备或节点;S12,模型选择和初始化:选择相应的模型,并在所有的本地设备或节点上初始化相应的模型,并下发至所有用户端;S13,本地训练:每个本地设备或节点使用其本地数据集对初始化的模型进行训练,得到一个本地模型;S14,模型聚合:在中央服务器上聚合本地模型;本发明专利技术通过在每个本地设备或节点都可以进行本地模型的训练和更新,进一步分散计算负载,提高训练速度和效率;通过设定停止条件来控制模型更新的频率,进而避免过度拟合等问题。避免过度拟合等问题。避免过度拟合等问题。

【技术实现步骤摘要】
一种基于联邦学习模型的训练方法


[0001]本专利技术属于联邦学习领域,具体地说是一种基于联邦学习模型的训练方法。

技术介绍

[0002]联邦学习是一种新兴的机器学习技术,它使用节点(例如边缘服务器及终端设备)的本地数据集进行分布式模型训练。与传统的在远程云端服务器上上传节点原始数据的机器学习相比,联邦学习中的节点只共享模型参数而不上传原始数据,因此它能够为网络节点提供隐私保护。
[0003]中国专利CN112348199B公开了一种基于联邦学习与多任务学习的模型训练方法,通过对几个任务同时训练共同或相关的网络层,使得多个任务互相促进训练准确性,提高了模型收敛速率与泛化能力,获得目标神经网络高效训练的同时,能够保证基础数据的私密性与安全性,通过将目标神经网络模型进行了拆分更新、传输,本专利技术设计对用户隐私保护能力较传统联邦学习有所提高,目标神经网络模型中后部模型自适应的按照任务的相关性进行更新,整个模型拥有较好的个性化,并且模型加入异步的模型后半段参数共享,使得带宽富裕节点得到更有效地利用。
[0004]上述专利虽然通过对几个任务同时训练,可以达到高效训练的效果,但是上述专利在训练时只有中央服务器进行模型的聚合和更新,会增大计算负载,而且很难控制各个本地设备或节点的模型更新频率,容易出现模型过度拟合,并且在训练时直接计算所有本地数据,会增大内存的消耗等问题。
[0005]为此,本领域技术人员提出了一种基于联邦学习模型的训练方法来解决
技术介绍
提出的问题。

技术实现思路

[0006]为了解决上述技术问题,本专利技术提供一种基于联邦学习模型的训练方法,以解决现有技术中在训练时只有中央服务器进行模型的聚合和更新,会增大计算负载,而且很难控制各个本地设备或节点的模型更新频率,容易出现模型过度拟合,并且在训练时直接计算所有本地数据,会增大内存的消耗等问题。
[0007]一种基于联邦学习模型的训练方法,包括以下步骤:
[0008]S11,定义问题:确定需要解决的机器学习问题、本地数据的来源、以及参与联邦学习的设备或节点;
[0009]S12,模型选择和初始化:选择相应的模型,并在所有的本地设备或节点上初始化相应的模型,并下发至所有用户端;
[0010]S13,本地训练:每个本地设备或节点使用其本地数据集对初始化的模型进行训练,得到一个本地模型;
[0011]S14,模型聚合:在中央服务器上聚合本地模型;
[0012]S15,循环迭代:重复进行本地训练和模型聚合直到满足预设的停止条件;
[0013]S16,发布模型:停止训练后,将全局模型发布给所有本地设备或节点,使它们可以使用最新的模型进行推断或预测。
[0014]优选的,所述步骤S13中本地训练的方法包括以下步骤:
[0015]S21,下载全局模型:从中央服务器下载共享的全局模型,作为本地训练的起点;
[0016]S22,加载数据:将本地数据集加载到内存中,并将其拆分成多个小批次;
[0017]S23,初始化模型:在本地设备上初始化全局模型并设置准确率P;
[0018]S24,周期训练:使用本地数据集对初始化的模型进行训练,迭代执行多个小批次的训练,同时对模型参数进行更新并保存到本地磁盘上;
[0019]S25,检查训练结果:检查当前客户端训练的准确率P


[0020]S26,将本地模型上传到服务器:每个客户端将训练好的模型上传到服务端;
[0021]S27,更新客户端模型:服务器将新的全局模型推送到各个客户端,用于更新本地模型;
[0022]S28,重复迭代:重复以上步骤,直到模型收敛或达到预设的停止条件为止;
[0023]S29,返回本地模型:训练完成后,返回训练得到的本地模型给中央联邦服务器进行模型聚合。
[0024]优选的,所述步骤S25中检查训练结果的方法包括以下步骤:
[0025]S31,计算当前客户端的准确率P


[0026]S32,判断前客户端的准确率P

是否大于等于预定准确率P;
[0027]S33,若是,返回训练成功,将本地模型上传至中央服务器;
[0028]S34,若否,重新训练,设置训练轮次N;
[0029]S35,判断实际训练轮次N

是否大于等于训练轮次N;
[0030]S36,若是,返回训练失败;
[0031]S37,若否,重复步骤S32

S35。
[0032]优选的,所述步骤S36中训练失败的处理方法为:
[0033]S41,数据加密:将训练失败的设备的本地数据进行加密;
[0034]S42,传输本地数据:将加密的本地数据传输给备用设备或者上传至中央联邦服务器;
[0035]S43,分配数据:中央联邦服务器根据每个本地设备的处理能力以及每个本地设备剩余数据的处理数量进行分配;
[0036]S44,上传:将训练失败的设备的数据训练好的模型上传到服务端。
[0037]优选的,所述步骤S43中分配数据的分配方法为:
[0038]x_i=(total_data

failed_data)*(C_s/P_s)*(P_i/sum(P_j

C_j))*(1

D_i/max(D_j));
[0039]其中,x_i表示第i个设备需要处理的数据量;total_data表示所有待处理的数据总量;failed_data表示训练失败设备的所有数据总和;C_s表示中央服务器的处理能力;P_s表示中央服务器和其他设备的处理能力之和;表示第i个设备的处理能力;sum(P_j

C_j)表示其他设备的总处理能力减去它们已经完成的任务量之和,即剩余的可用处理能力;D_i表示第i个设备剩余未处理的数据量;max(D_j)表示所有设备中最大的未处理数据量。
[0040]优选的,所述步骤S41中数据加密的方法为同态加密。
[0041]优选的,所述步骤S41中数据加密的方法为加密叠加。
[0042]与现有技术相比,本专利技术具有如下有益效果:
[0043]1、本专利技术通过在每个本地设备或节点都可以进行本地模型的训练和更新,进一步分散计算负载,提高训练速度和效率。
[0044]2、本专利技术通过设定停止条件来控制模型更新的频率,进而避免过度拟合等问题。
[0045]3、本专利技术通过将本地数据拆分成多个小批次,可以带来以下好处:
[0046]降低内存消耗:当数据量很大时,一次性读入所有数据可能会导致内存不足,因此将数据划分成小批次可以避免这种情况发生。
[0047]提高训练速度:将数据划分成小批次可以使训练过程更加高效,因为在每个小批次中仅需计算一部分样本的梯度,从而减少了计算量。
[0048]模型收敛更本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于联邦学习模型的训练方法,其特征在于,包括以下步骤:S11,定义问题:确定需要解决的机器学习问题、本地数据的来源、以及参与联邦学习的设备或节点;S12,模型选择和初始化:选择相应的模型,并在所有的本地设备或节点上初始化相应的模型,并下发至所有用户端;S13,本地训练:每个本地设备或节点使用其本地数据集对初始化的模型进行训练,得到一个本地模型;S14,模型聚合:在中央服务器上聚合本地模型;S15,循环迭代:重复进行本地训练和模型聚合直到满足预设的停止条件;S16,发布模型:停止训练后,将全局模型发布给所有本地设备或节点,使它们可以使用最新的模型进行推断或预测。2.如权利要求1所述一种基于联邦学习模型的训练方法,其特征在于:所述步骤S13中本地训练的方法包括以下步骤:S21,下载全局模型:从中央服务器下载共享的全局模型,作为本地训练的起点;S22,加载数据:将本地数据集加载到内存中,并将其拆分成多个小批次;S23,初始化模型:在本地设备上初始化全局模型并设置准确率P;S24,周期训练:使用本地数据集对初始化的模型进行训练,迭代执行多个小批次的训练,同时对模型参数进行更新并保存到本地磁盘上;S25,检查训练结果:检查当前客户端训练的准确率P

;S26,将本地模型上传到服务器:每个客户端将训练好的模型上传到服务端;S27,更新客户端模型:服务器将新的全局模型推送到各个客户端,用于更新本地模型;S28,重复迭代:重复以上步骤,直到模型收敛或达到预设的停止条件为止;S29,返回本地模型:训练完成后,返回训练得到的本地模型给中央联邦服务器进行模型聚合。3.如权利要求2所述一种基于联邦学习模型的训练方法,其特征在于:所述步骤S25中检查训练结果的方法包括以下步骤:S31,计算当前客户端的准确率P

;S32,判断前客户端的准确率P

是否大于等于预定准确率P;S33,若是,返回训练成功...

【专利技术属性】
技术研发人员:刘睿霖程娇杜金浩张震石瑾高圣翔刘发强
申请(专利权)人:国家计算机网络与信息安全管理中心
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1