当前位置: 首页 > 专利查询>浙江大学专利>正文

一种基于无数据蒸馏的联邦学习聚合方法和系统技术方案

技术编号:38576733 阅读:11 留言:0更新日期:2023-08-26 23:24
本发明专利技术公开了一种基于无数据蒸馏的联邦学习聚合方法和系统,属于联邦学习和隐私保护研究领域。该方法通过引入分布式生成式对抗网络来实现无数据的联邦知识蒸馏,从而解决基于参数平均的联邦学习方法所存在的不支持模型异构、隐私泄露以及基于知识蒸馏的联邦学习方法所存在的难以获取公共数据集的问题。该方法主要包括本地对抗训练、生成式对抗网络聚合、联邦蒸馏三个步骤。本发明专利技术的方法提高了预测准确性,尤其是在一些极端的非独立同分布场景下(如客户端类别极度不均衡、数据量极度不均衡等)的预测准确性。同时,本发明专利技术的方法相较于现有的联邦学习方法实现了对客户端模型异构的支持,提高了隐私保护能力,增强了泛化性能。增强了泛化性能。增强了泛化性能。

【技术实现步骤摘要】
一种基于无数据蒸馏的联邦学习聚合方法和系统


[0001]本专利技术属于深度学习研究领域,具体涉及一种基于无数据蒸馏的联邦学习聚合方法和系统。

技术介绍

[0002]重视个人数据的安全以及强调对个人信息的保护已经成为世界性的趋势,现有的大部分深度学习方法需要有大数据的支撑,传统的集中数据进行训练的方案不再适用于数据保护的新场景,联邦学习(FL)作为一种分布式的机器学习算法为这个问题提供了新的解决方案。
[0003]在实际的工业场景中,客户端之间的计算能力、存储能力及通信能力往往存在较大的差异,且各自所拥有的数据分布及训练模型都存在异构的现象,这给现有的联邦学习算法带来了巨大的挑战。
[0004]部分研究面向数据异构或是模型异构的单一场景设计了解决方案,但是并未将两者综合起来进行考虑,很难得到有效的应用。另外,这些方法有着很强的假设限制,在现实场景中很难得到有效的满足。因此,设计一种能够同时适用于模型异构和数据异构问题的泛化性能更强的联邦学习聚合方法是本领域亟待解决的技术问题。

技术实现思路

[0005]为了解决上述技术问题,本专利技术提供了一种基于无数据蒸馏的联邦学习聚合方法和系统,提升了现有联邦学习的泛化性能、精度以及隐私保护能力。
[0006]为实现上述目的,本专利技术的技术方案是:
[0007]第一方面,本专利技术提出了一种基于无数据蒸馏的联邦学习聚合方法,包括以下步骤:
[0008]步骤1:由服务端定义生成器和判别器网络结构并发送给每个参与联邦学习的客户端;所述的客户端定义本地分类器网络结构;
[0009]步骤2:各客户端利用本地私有数据对生成器、判别器和本地分类器进行若干轮次的三方对抗训练;将训练完成后的生成器和判别器参数反馈至服务端;
[0010]步骤3:服务端接收所有客户端反馈的生成器参数和判别器参数并计算全局参数,将全局参数和预定义的噪声向量、批次大小发送给各客户端;
[0011]步骤4:客户端接收全局参数后加载得到全局生成器和全局判别器,将噪声向量输入全局生成器,得到生成样本,将生成样本输入本地分类器得到软标签,并将软标签反馈至服务端;
[0012]步骤5:服务端接收所有客户端的软标签,计算各客户端的全局平均软标签并发回给对应的客户端;
[0013]步骤6:客户端根据接收到的全局平均软标签训练本地分类器;
[0014]步骤7:重复步骤2

6,直至本地分类器收敛。
[0015]进一步的,所述的客户端根据本地的个性化训练任务自定义本地分类器模型结构。
[0016]进一步的,步骤6中,各客户端之间通过服务端回传的同一批生成样本的软标签进行知识蒸馏,每一个客户端接收到的全局平均软标签为除自身之外的其他客户端生成的软标签的均值。
[0017]进一步的,所述的蒸馏操作的损失为:
[0018][0019]其中,L表示蒸馏损失,p(z)表示噪声向量z服从的正态分布,表示生成标签服从的均匀分布,G表示全局生成器,C
k
表示客户端k的本地分类器,CE表示交叉熵,KL表示Kullback

Leibler散度,表示客户端k的全局平均软标签。
[0020]第二方面,本专利技术提出了一种基于无数据蒸馏的联邦学习聚合系统,用于实现上述的联邦学习聚合方法,包括:
[0021]服务端,其用于定义生成器和判别器网络结构;根据所有客户端反馈的生成器参数和判别器参数计算全局参数;以及根据客户端发送的软标签计算各客户端的全局平均软标签;
[0022]客户端,其用于根据本地的个性化训练任务自定义本地分类器模型结构;以及接收服务端发送的噪声向量和全局生成器,利用全局生成器得到生成样本,将生成样本输入本地分类器得到软标签;
[0023]本地训练模块,其置于每一个客户端内,用于利用本地私有数据对生成器、判别器和本地分类器进行若干轮次的三方对抗训练;以及,利用全局平均软标签对本地分类器进行蒸馏操作;
[0024]数据传输模块,其用于服务端与客户端时间的信息传输,包括服务端向客户端发送初始化生成器网络结构、初始化判别器网络结构、全局参数、预定义的噪声向量、批次大小和全局平均软标签;以及,客户端向服务端发送本地训练完成后的生成器和判别器参数和软标签。
[0025]相较于现有技术,本专利技术具有以下有益效果:本专利技术通过引入分布式生成式对抗网络以及知识蒸馏解决了基于参数平均的联邦学习方法所存在的不支持模型异构、隐私泄露以及基于知识蒸馏的联邦学习方法所存在的难以获取公共数据集的问题。同时,本专利技术在各种极端的非独立同分布联邦学习场景中,尤其是在一些特殊情况下(如客户端类别极度不均衡、数据量极度不均衡等),相较于现有的联邦学习聚合方法均取得了更加优异的结果。
附图说明
[0026]图1为本专利技术实施例示出的基于无数据蒸馏的联邦学习聚合方法的流程示意图。
[0027]图2为本专利技术实施例示出的本地三方对抗训练的方法示意图。
[0028]图3为本专利技术实施例示出的相较于现有的联邦学习方法在类别极度不均衡场景下的计算精度对比示意图。
[0029]图4为本专利技术实施例示出的相较于现有的联邦学习方法在数量极度不均衡场景下的计算精度对比示意图。
具体实施方式
[0030]下面结合附图对本专利技术做进一步的说明。
[0031]本专利技术提出了一种基于无数据蒸馏的联邦学习聚合方法,通过引入分布式生成式对抗网络来实现无数据的联邦知识蒸馏,从而解决基于参数平均的联邦学习方法所存在的不支持模型异构、隐私泄露以及基于知识蒸馏的联邦学习方法所存在的难以获取公共数据集的问题。该方法主要包括本地对抗训练、生成式对抗网络聚合、联邦蒸馏三个步骤。本专利技术可用于人脸活体检测等联邦学习的应用场景。
[0032]本专利技术提出的基于无数据蒸馏的新型联邦学习聚合方法,包括以下步骤:
[0033]步骤1:服务端定义联邦学习所用的生成器和判别器网络结构,每个参与联邦学习的客户端定义个性化的本地分类器网络结构;服务端将生成器和判别器发送给参与联邦学习的客户端。生成器和判别器可以采用任意的生成式对抗网络结构,分类器可以采用任意的分类网络结构。
[0034]步骤2:客户端下载服务端的生成器和判别器模型,并和本地分类器一起在本地的私有数据上进行若干轮次的三方对抗训练。
[0035]如图2所示,将服从正态分布的噪声向量z~p(z)和服从均匀分布的生成标签作为生成器G
k
的输入,得到服从真实样本分布的生成样本将真实样本和生成样本作为判别器D
k
的输入,判别器在理想情况下要将真实样本分类为真,并输出A表示接受,将生成样本分类为假,并输出R表示拒绝,根据分类结果计算损失;利用真实样本(x,y)和生成样本作为分类器C
k
的输入,根据分类结果计算损失;根据总损失执行三方对抗训练,梯度反向传播更新参数。
[0036]本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于无数据蒸馏的联邦学习聚合方法,其特征在于,包括以下步骤:步骤1:由服务端定义生成器和判别器网络结构并发送给每个参与联邦学习的客户端;所述的客户端定义本地分类器网络结构;步骤2:各客户端利用本地私有数据对生成器、判别器和本地分类器进行若干轮次的三方对抗训练;将训练完成后的生成器和判别器参数反馈至服务端;步骤3:服务端接收所有客户端反馈的生成器参数和判别器参数并计算全局参数,将全局参数和预定义的噪声向量、批次大小发送给各客户端;步骤4:客户端接收全局参数后加载得到全局生成器和全局判别器,将噪声向量输入全局生成器,得到生成样本,将生成样本输入本地分类器得到软标签,并将软标签反馈至服务端;步骤5:服务端接收所有客户端的软标签,计算各客户端的全局平均软标签并发回给对应的客户端;步骤6:客户端根据接收到的全局平均软标签训练本地分类器;步骤7:重复步骤2

6,直至本地分类器收敛。2.根据权利要求1所述的基于无数据蒸馏的联邦学习聚合方法,其特征在于,所述的客户端根据本地的个性化训练任务自定义本地分类器模型结构。3.根据权利要求1所述的基于无数据蒸馏的联邦学习聚合方法,其特征在于,在步骤2所述的三方对抗训练中,生成器的损失函数为:其中,表示客户端k的生成器损失,x表示本地真实样本,表示客户端k的本地真实样本分布,p(z)表示噪声向量z服从的正态分布,表示生成标签服从的均匀分布,G
k
表示客户端k的生成器,D
k
表示客户端k的判别器。4.根据权利要求1所述的基于无数据蒸馏的联邦学习聚合方法,其特征在于,在步骤2所述的三方对抗训练中,判别器的损失函数为:其中,表示客户端k的判别器损失,p(z)表示噪声向量z服从的正态分布,表示生成标签服从的均匀分布,G
k
表示客户端k的生成器,C
k
表示客户端k的本地分类器,CE表示交叉熵。5.根据权利要求1所述的基于无数据蒸馏的联邦学习聚合方法,其特征在于,在步骤2所述的三方对抗训练中,本地分类器的损失函数为:其中,表示客户端k的本地分类器损失,表示客户端k的本地真实样本分布,x、y表示本地真实样本和标签,p(z)表示噪声向量z服从的正态分布,表示生成标签服从的...

【专利技术属性】
技术研发人员:吴超张真源李皓
申请(专利权)人:浙江大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1