System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种基于分布式蒸馏的联邦学习方法技术_技高网
当前位置: 首页 > 专利查询>复旦大学专利>正文

一种基于分布式蒸馏的联邦学习方法技术

技术编号:40171794 阅读:6 留言:0更新日期:2024-01-26 23:41
本发明专利技术公开了一种基于分布式蒸馏的联邦学习方法。本发明专利技术通过分布式知识蒸馏DKD模块,将在异质数据集上良训练的所有局部模型的知识迁移到全局模型上,提高了全局模型表现,丰富了全局模型知识,并得到下一轮局部训练更好的初始参数;本发明专利技术在已有联邦学习框架的基础上,通过优化DKD模块的目标函数使得全局模型从训练好的局部模型中学习,更好地优化了全局损失函数的上界,通过DKD模块消除了深度学习模型重参数化不变性对在参数空间求平均的影响,逼近在函数映射空间上平均的全局模型。本发明专利技术对异质数据集的联邦学习训练更加适用,可以取得更好的模型效果。当数据集十分异质时,在相同通讯代价下仍能得到更好的模型表现。

【技术实现步骤摘要】

本专利技术属于人工智能,具体的说,涉及一种基于分布式蒸馏的联邦学习方法


技术介绍

1、随着大数据与人工智能的兴起,人们隐私意识的增强,个人和机构的数据的隐私和安全成为人们越来越关心的问题,联邦学习在该背景下被提出,作为一个强调隐私保护的分布式训练框架,可以根据边端设备的计算、存储能力进行有效的模型训练,从而有效解决数据需求和隐私保护之间的矛盾。已在智慧医疗、无线通信、边端设备和推荐系统等业界场景被广泛应用。

2、联邦学习中,数据的异质性是其主要挑战之一,异质性的来源之一是局部数据集在标签分布上是异质的,另外一种则是不同来源的数据在特征上是异质的。数据的异质性会导致联邦训练收敛速度较慢,联邦模型精度较低等问题。

3、目前的联邦学习框架如图1所示,给定共k个客户端,第k个客户端上的局部数据集记为其单个样本为(x,y),在图像分类问题中,x为输入图像,y为该图像的标签,记标签的类别集合为要求局部数据集不在客户端之间或客户端与服务器之间进行共享,通过局部模型在局部数据集上的训练和模型参数或梯度等信息交流,训练全局模型,全局模型与各客户端上的模型具有相同的模型结构。其优化目标函数为:

4、

5、其中w为全局模型,pk≥0,∑k pk=1,可以用表示,lk为客户端k上局部模型平均每个训练样本的损失函数。

6、在上述目前的联邦学习框架中,局部模型局部训练结束后,上传参数到中心服务器,中心服务器对参数进行(加权)平均后得到全局参数,下发到客户端,这样的一个过程称为一个dkd轮次。该联邦学习框架仅为在每个客户端上并行地训练局部模型的过程优化。


技术实现思路

1、针对上述现有技术的不足,本专利技术通过引入网络配置的知识蒸馏模块,提出一种全新的联邦学习训练方法来解决中心化联邦学习的异质性问题。本专利技术通过分布式知识蒸馏dkd(distributed knowledge distillation,dkd)模块,将在异质数据集上良训练的所有局部模型的知识迁移到全局模型上,提高了全局模型的表现,丰富了全局模型的知识,并得到一个下一轮局部训练更好的初始参数实验证明,本专利技术对异质数据集的联邦学习训练更加适用,可以取得更好的模型效果。当数据集十分异质时,在相同通讯代价下仍能得到更好的模型表现。

2、本专利技术的技术方案具体介绍如下。

3、本专利技术提供一种基于分布式蒸馏的联邦学习方法,其通过网络配置的分布式知识蒸馏dkd模块,将在异质数据集上良训练的所有局部模型的知识迁移到全局模型上,局部模型和全局模型采用相同的神经网络模型结构;其在第t个dkd轮次训练时,包括以下步骤:

4、步骤1.dkd模块先获取在线的客户端集合这些客户端的局部模型参数被初始化为即第t-1个dkd轮次对全局模型的估计;

5、步骤2.客户端基于本地的局部数据集并行地优化训练局部模型,上传更新后的局部模型参数到服务器;

6、步骤3.服务器平均局部模型参数得到最新的全局模型,客户端下载服务器上最新的全局模型,下载平均参数;

7、步骤4.客户端进行若干轮局部的蒸馏,以本地局部模型为教师模型,以当前全局模型为学生模型,计算局部的蒸馏梯度,将局部蒸馏参数上传到服务器上进行联合蒸馏更新全局模型,下载更新的全局模型参数;

8、步骤5.更新后的全局模型作为第t个dkd轮次的算法输出,参与第t+1个dkd轮次的训练,下载本轮最终的全局参数。

9、本专利技术中,该联邦学习方法建立一个参数化的全局模型φ(x;w),其参数为w,来近似理想的φ*(x),其优化目标函数为

10、minw∫div(φ*(x),φ(x;w))p(dx);

11、其中:φ(x)为假设存在的一个全局的真实的函数映射,其使得对于任意样本x可以给出真实的标签y,即y=φ*(x)。

12、本专利技术中,服务器的全局模型的损失函数包括用于优化在每个客户端上并行地训练局部模型的过程的损失函数以及用于优化分布式知识蒸馏dkd模块的目标函数来使得全局模型从训练好的局部模型中学习的损失函数。

13、和现有技术相比,本专利技术的有益效果在于:

14、1)在已有联邦学习框架优化在每个客户端上并行地训练局部模型的过程的基础上,通过优化dkd模块的目标函数使得全局模型从训练好的局部模型中学习,更好地优化了全局损失函数的上界。

15、2)如图3所示,通过dkd模块,将在异质数据集上良训练的所有局部模型的知识迁移到全局模型上,提高了全局模型的表现,丰富了全局模型的知识,并得到下一轮局部训练更好的初始参数。

16、3)通过dkd模块,消除了深度学习模型重参数化不变性对在参数空间求平均的影响,逼近在函数映射空间上平均的全局模型。

17、4)数值实验证明,本专利技术提出的框架对异质数据集的联邦学习训练更加适用,可以取得更好的模型效果。当数据集十分异质时,本框架在相同通讯代价下仍能得到更好的模型表现。

本文档来自技高网...

【技术保护点】

1.一种基于分布式蒸馏的联邦学习方法,其特征在于,其通过网络配置的分布式知识蒸馏DKD模块,将在异质数据集上良训练的所有局部模型的知识迁移到全局模型上,局部模型和全局模型采用相同的神经网络模型结构;其在第t个DKD轮次训练时,包括以下步骤:

2.根据权利要求1所述的基于分布式蒸馏的联邦学习方法,其特征在于,该联邦学习方法建立一个参数化的全局模型Φ(x;w),其参数为w,来近似理想的模型映射Φ*(x),其优化目标函数为:

3.根据权利要求1所述的基于分布式蒸馏的联邦学习方法,其特征在于,服务器的全局模型的损失函数包括用于优化在每个客户端上并行地训练局部模型的过程的损失函数以及用于优化分布式知识蒸馏DKD模块的目标函数来使得全局模型从训练好的局部模型中学习的损失函数。

【技术特征摘要】

1.一种基于分布式蒸馏的联邦学习方法,其特征在于,其通过网络配置的分布式知识蒸馏dkd模块,将在异质数据集上良训练的所有局部模型的知识迁移到全局模型上,局部模型和全局模型采用相同的神经网络模型结构;其在第t个dkd轮次训练时,包括以下步骤:

2.根据权利要求1所述的基于分布式蒸馏的联邦学习方法,其特征在于,该联邦学习方法建立一个参数化...

【专利技术属性】
技术研发人员:卢文联李欣嘉
申请(专利权)人:复旦大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1