一种联邦学习方法技术

技术编号:38341939 阅读:16 留言:0更新日期:2023-08-02 09:22
本发明专利技术实施例提供了一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括:由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。分类准确率确定。分类准确率确定。

【技术实现步骤摘要】
一种联邦学习方法


[0001]本专利技术涉及计算机
,具体来说,涉及联邦学习领域,更具体来说,涉及一种联邦学习方法。

技术介绍

[0002]随着《个人信息保护法》的发布,企业在使用、管理、存储隐私数据方面的成本不断增加,使得数据变得更加难以流通和共享,各领域“数据孤岛”现象愈发突出。而联邦学习技术的出现,为解决“数据孤岛”问题提供了新的思路。联邦学习是一种分布式训练AI模型的一种技术方案,它保证隐私数据保留在本地,参与训练模型的各方通过自身的隐私数据建立联合模型。由于隐私数据并不直接参与到共享中,仅通过模型的形式共享数据的特征,数据的所有权、使用权均不受影响,缓冲了隐私数据保护和使用的矛盾,打通“数据孤岛”,为数据驱动型产业在高数据监管力度下带来新的解决方案。但与此同时,联邦学习仍然面临着众多挑战,其中之一便是难以衡量各节点对于联合模型的贡献。
[0003]在实际应用中,各节点的数据数量和数据质量往往存在较大差异,公平的联邦学习系统需要综合考虑各节点的数据数量和数据质量,以分配公平的贡献度和奖励,来维持联邦学习系统的稳定性。但是,各节点的数据并不透明,仅以模型的形式进行公开,难以获取各节点数据数量和数据质量情况,例如,传统联邦学习系统贡献度计算通常直接计算各节点模型与联合模型相似度,虽然模型能够体现数据特征,但模型参数存在可解释性弱的问题,且微小的数据扰动都会引起模型参数的巨大变动。因此,直接衡量各节点模型与联合模型相似度的传统方案虽然为联邦学习系统的贡献度提供了一种量化方案,但其可解释性较弱且联邦学习效果较差。
[0004]因此,亟需一种在各节点数据不透明的情况下公平地衡量各节点的贡献度,以基于各节点的贡献度更好地进行联邦学习的方法。

技术实现思路

[0005]因此,本专利技术的目的在于克服上述现有技术的缺陷,提供一种联邦学习方法。
[0006]本专利技术的目的是通过以下技术方案实现的:
[0007]根据本专利技术的第一方面,提供一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括:由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。
[0008]在本专利技术的一些实施例中,所述每个客户端的贡献度根据该客户端的数据标签分
布情况和该客户端模型对各类别下仿真样本的分类准确率确定,其中,所述数据标签分布情况指示对应客户端的各类别下的非仿真样本的占比。
[0009]在本专利技术的一些实施例中,所述每个客户端的贡献度是该客户端的数据标签分布情况中每个类别下的非仿真样本的占比与该客户端模型对该类别的仿真样本的分类准确率的乘积之和。
[0010]在本专利技术的一些实施例中,由中心节点按照以下方式获得当前轮更新后的联合模型:每轮联邦训练前,获取最新更新的每个客户端的贡献度,比较各客户端中最新更新的每个客户端的贡献度和预设阈值的大小,剔除贡献度小于预设阈值所对应的客户端,得到当前轮参与训练的多个客户端;对当前轮参与训练的多个客户端上传的当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型。
[0011]在本专利技术的一些实施例中,所述仿真样本按照以下方式生成:获取基于生成对抗方式训练得到的经训练的生成模型;利用经训练的生成模型针对每种类别对应生成多个仿真样本,其中,每种类别下的仿真样本的数据标签向量与为该类别预设的标签向量的距离小于预设阈值,数据标签向量是基于将仿真样本输入当前轮更新后的联合模型中得到的。
[0012]在本专利技术的一些实施例中,基于所述生成对抗方式进行一轮或者多轮迭代对抗训练,每轮对抗训练包括:获取对抗生成网络,其包括生成模型和判别模型;获取第一训练集训练判别模型,得到当轮训练的判别模型,所述第一训练集包括多个第一样本和每个第一样本对应的指示其是非仿真样本的置信度标签,单个第一样本为仿真样本或者非仿真样本,该置信度标签基于将第一样本输入当轮更新后的联合模型得到的输出结果确定;将生成的仿真样本输入当轮训练的判别模型,利用判别模型对生成的仿真样本的判别损失更新生成模型的参数。
[0013]在本专利技术的一些实施例中,所述第一训练集中每个第一样本的生成方式包括:获取数据集中的非仿真样本或生成模型基于随机数生成的仿真样本;将非仿真样本或仿真样本输入当轮更新后的联合模型中,得到数据标签向量,计算数据标签向量与为多种类别中每种类别预设的标签向量的距离,得到多个距离,从多个距离中选择与数据标签向量最小的距离;获取预设的超参数,根据预设的超参数以及计算的最小距离确定第一样本的置信度标签。
[0014]在本专利技术的一些实施例中,第一样本的置信度标签按照以下方式确定:
[0015][0016]其中,H表示置信度标签,α为预设的超参数,d表示计算的非仿真样本或仿真样本输入当前轮更新后的联合模型中得到的数据标签向量与为各类别预设的标签向量的最小距离,β表示预设的超参数,γ表示预设的超参数。
[0017]在本专利技术的一些实施例中,各客户端的贡献度是每隔预设的轮次周期性更新的,或者在所有的联邦训练的轮次中指定的联邦训练的轮次被间隔更新的。
[0018]根据本专利技术的第二方面,提供一种图像分类方法,所述方法包括:获取根据本专利技术的第一方面任一项所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为图像数据,标签为对应样本数据对应的类别;基于最
终的联合模型对输入的图像数据进行图像分类。
[0019]在本专利技术的一些实施例中,样本数据对应的类别为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船或卡车。
[0020]根据本专利技术的第三方面,提供一种用户分类方法,所述方法包括:获取根据本专利技术的第一方面任一项所述的联邦学习方法得到最终的联合模型,其中,联邦学习方法中,各个客户端的本地训练集中的样本数据为用户特征数据,标签为对应样本数据对应的用户类别;基于最终的联合模型对输入的用户特征数据进行用户分类。
[0021]在本专利技术的一些实施例中,所述样本数据对应的用户类别为优质用户、良好用户或普通用户。
[0022]根据本专利技术的第四方面,提供一种电子设备,包括:一个或多个处理器;以及存储器,其中存储器用于存储可执行指令;所述一个或多个处理器被配置为经由执行所述可执行指令以实现本专利技术的第一方面、第二方面、第三方面中任一项所述方法的步骤。
[0023]与现有技术相比,本专利技术的优点在于:
[0024]本专利技术利用中心节点预设的仿真样本为联邦学习中各客户端的贡本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种联邦学习方法,所述方法包括由中心节点将初始化的联合模型分发给多个客户端作为初始的客户端模型,并由中心节点和客户端配合完成多轮联邦训练,得到最终的联合模型,其中,每轮联邦训练包括:由中心节点获取每个客户端上传的当前轮训练后的客户端模型,其中,当前轮训练后的客户端模型是利用客户端的本地训练集以最新获得的客户端模型为基础训练得到的;由中心节点基于各客户端的贡献度对多个客户端当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型,并将该联合模型分发给多个客户端作为各客户端下一轮训练的基础,其中,各客户端的贡献度基于各客户端模型分别对预设的仿真样本的分类准确率确定。2.根据权利要求1所述的方法,其特征在于,所述每个客户端的贡献度根据该客户端的数据标签分布情况和该客户端模型对各类别下仿真样本的分类准确率确定,其中,所述数据标签分布情况指示对应客户端的各类别下的非仿真样本的占比。3.根据权利要求2所述的方法,其特征在于,所述每个客户端的贡献度是该客户端的数据标签分布情况中每个类别下的非仿真样本的占比与该客户端模型对该类别的仿真样本的分类准确率的乘积之和。4.根据权利要求2所述的方法,其特征在于,由中心节点按照以下方式获得当前轮更新后的联合模型:每轮联邦训练前,获取最新更新的每个客户端的贡献度,比较各客户端中最新更新的每个客户端的贡献度和预设阈值的大小,剔除贡献度小于预设阈值所对应的客户端,得到当前轮参与训练的多个客户端;对当前轮参与训练的多个客户端上传的当前轮训练后的客户端模型进行聚合,得到当前轮更新后的联合模型。5.根据权利要求1所述的方法,其特征在于,所述仿真样本按照以下方式生成:获取基于生成对抗方式训练得到的经训练的生成模型;利用经训练的生成模型针对每种类别对应生成多个仿真样本,其中,每种类别下的仿真样本的数据标签向量与为该类别预设的标签向量的距离小于预设阈值,数据标签向量是基于将仿真样本输入当前轮更新后的联合模型中得到的。6.根据权利要求5所述的方法,其特征在于,基于所述生成对抗方式进行一轮或者多轮迭代对抗训练,每轮对抗训练包括:获取对抗生成网络,其包括生成模型和判别模型;获取第一训练集训练判别模型,得到当轮训练的判别模型,所述第一训练集包括多个第一样本和每个第一样本对应的指示其是非仿真样本的置信度标签,单个第一样本为仿真样本或者非仿真样本,该置信度标签基于将第一样本输入当轮更新后的联合模型得到的输出结果确定;将生成的仿真样本输入当轮训练的判别模型,利...

【专利技术属性】
技术研发人员:史红周余孙婕曾辉
申请(专利权)人:中国科学院计算技术研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1