System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 基于联邦学习和知识蒸馏的多分类模型优化方法和装置制造方法及图纸_技高网

基于联邦学习和知识蒸馏的多分类模型优化方法和装置制造方法及图纸

技术编号:41207686 阅读:3 留言:0更新日期:2024-05-09 23:29
本发明专利技术属于机器学习技术领域,公开了一种基于联邦学习和知识蒸馏的多分类模型优化方法和装置,所述方法包括:响应于多分类任务,将初始化后的全局模型作为教师模型发送给全体客户端;在随机获取的活跃客户端中筛选当前边际收益最大的客户端,直至获得满足预定大小的边际收益最大的客户端子集;利用本地数据样本分别计算学生模型的预测概率和教师模型的预测概率分布;利用学生模型与教师模型输出的预测概率分布之间的KL散度损失和本地模型的损失进行训练,以得到本地模型;中央服务器接收各个客户端更新后的模型参数,通过模型聚合得到优化后的全局模型。经过优化的模型在多分类任务上实现了更好的性能,同时模型收敛速度也得到了显著提升。

【技术实现步骤摘要】

本专利技术涉及机器学习,特别涉及一种基于联邦学习和知识蒸馏的多分类模型优化方法和装置


技术介绍

1、本部分的陈述仅仅是提供了与本公开相关的
技术介绍
信息,不必然构成在先技术。

2、随着信息革命的发展,海量的数据在不断地产生,如何合理有效地利用这些数据成为一个热点方向。由于隐私政策的保护,很多数据不能被轻易的获取,数据间相互隔离,形成了一个个数据孤岛。如何建立数据孤岛间沟通的桥梁,打破数据之间的界限,成为一个热点问题,联邦学习(federated learning)为解决该问题提供了一个新的方向。

3、联邦学习是一种分布式机器学习方法,它允许多个客户端在不共享本地数据的情况下共同训练机器学习模型。在联邦学习中,每个客户端训练一个本地模型,而模型参数通过加密和安全的通信传输给中央服务器,以进行模型的全局聚合和改进。这种方式使得数据保持在本地,不必传输到中心服务器,从而有效保护了隐私。联邦学习在打破数据孤岛之间的隔离、促进跨界协作方面具有广泛的应用潜力,尤其适用于电力数据、医疗数据等敏感数据的分析。

4、在联邦学习的每一轮通信中,客户端的子集将其模型更新传回至服务器,然后服务器将它们的参数进行聚合。如何确定客户端的最佳子集目前仍是一项难题,最常见的策略是随机抽样和选择权方法(the power-of-choice approach)。随机抽样是指服务器从所有可用的客户端中随机选择一个子集,这个子集将参与每一轮的训练。其主要优点是简单性和公平性,每个客户端有平等的机会参与训练,不会发生偏向某个客户端的情况。但是,它可能无法充分考虑客户端的数据贡献,导致收敛速度慢或者需要更多轮的通信才能达到最佳性能。选择权方法尝试根据客户端的贡献或性能来有针对性地选择子集。一种常见的选择权方法是选择具有最大训练损失的客户端,其优势在于更倾向于选择那些可以提供更好模型参数更新的客户端,从而提高了收敛速度。然而,这可能会导致一些客户端频繁地被选择,从而丧失了公平性。

5、与此同时,知识蒸馏是一种模型压缩技术,旨在通过将一个复杂模型的知识转移到一个小型模型上来提高小型模型的性能。在联邦学习中,数据存储在本地设备上,无法直接传递到中心服务器,这使得模型的训练和优化变得更具挑战性。知识蒸馏通过在全局模型(教师模型)和本地模型(学生模型)之间传递知识,可以帮助较小的学生模型达到更好的性能;同时因为传递给全局模型的更新更小,减小了通信开销。目前知识蒸馏已经在联邦学习、迁移学习和模型压缩等领域得到了广泛应用。

6、因此,提供一种基于联邦学习和知识蒸馏的多分类模型优化方法和装置,以解决现有技术在客户端选择上的不足和模型收敛速度慢的问题。


技术实现思路

1、本专利技术实施例提供了一种基于联邦学习和知识蒸馏的多分类模型优化方法和装置,以解决现有技术存在的问题。为了对披露的实施例的一些方面有一个基本的理解,下面给出了简单的概括。该概括部分不是泛泛评述,也不是要确定关键/重要组成元素或描绘这些实施例的保护范围。其唯一目的是用简单的形式呈现一些概念,以此作为后面的详细说明的序言。

2、根据本专利技术实施例的第一方面,提供了一种基于联邦学习和知识蒸馏的多分类模型优化方法。

3、在一些实施例中,所述方法包括:

4、响应于多分类任务,对待优化的全局模型进行初始化,并将初始化后的全局模型作为教师模型发送给全体客户端;

5、在随机获取的活跃客户端中筛选当前边际收益最大的客户端,直至获得满足预定大小的边际收益最大的客户端子集,并将所述客户端子集作为学生模型;

6、利用本地数据样本分别计算所述学生模型的预测概率和所述教师模型的预测概率分布;

7、根据预设的知识蒸馏比例,利用学生模型与教师模型输出的预测概率分布之间的kl散度损失和本地模型的损失进行训练,以得到本地模型;

8、基于得到的本地模型,中央服务器接收各个客户端更新后的模型参数,通过模型聚合得到优化后的全局模型。

9、在一些实施例中,响应于多分类任务,对待优化的全局模型进行初始化,具体包括:

10、基于客户端上的多分类任务,中央服务器进行全局模型的初始化。

11、在一些实施例中,所述多分类任务y的表达式为:

12、y=argmax(softmax(wtx+b))

13、其中,w为可学习的参数矩阵,b为可学习的偏置项,x为输入样本的特征向量,wt矩阵w的转置矩阵。

14、在一些实施例中,应用随机贪婪算法,在随机获取的活跃客户端中筛选当前边际收益最大的客户端。

15、在一些实施例中,筛选当前边际收益最大的客户端,具体包括:

16、设定经过筛选后得到的客户端子集为m,并限制所选客户端的数量不超过l;

17、定义由所有客户端构成的集合n到子集m的映射α:n→m;

18、在联邦学习的每一轮通信中用客户端子集m代表所有客户端集合n并将其更新信息传递给中央服务器;

19、将集合n中每个客户端的梯度信息和它在集合m中的映射的梯度信息之间的差异降到最小为目的进行筛选。

20、在一些实施例中,所述本地模型的训练过程,具体包括:

21、客户端保存教师模型作为本地模型的初始模型,利用本地数据集d,计算教师模型的分布平缓程度λ、教师模型的预测概率分布p(x)和学生模型的预测概率分布q(x);

22、以q(x)中概率最高的类别作为模型的预测结果;

23、根据教师模型的kl散度损失和学生模型的多分类任务损失,计算本体模型的损失函数。

24、在一些实施例中,利用第一表达式,计算教师模型的分布平缓程度λ;

25、所述第一表达式为:

26、

27、其中,n是多分类任务的类别数量,zt,i(x)代表单个类别教师模型的logit向量,mean(zt(x))代表各个类别预测概率的平均值,i代表单个类别。

28、在一些实施例中,利用第二表达式,计算教师模型的预测概率分布p(x);

29、所述第二表达式为:

30、

31、其中,n是多分类任务的类别数量,zt,i(x)代表单个类别教师模型的logit向量,λ是教师模型的分布平缓程度,zt,j(x)代表单个类别教师模型的logit向量,j代表单个类别。

32、在一些实施例中,基于教师模型的预测概率分布p(x)和学生模型的预测概率分布q(x),利用kl散度计算学生模型和教师模型在同一数据样本上的预测概率分布的差异。

33、在一些实施例中,kl散度计算的表达式为:

34、

35、其中,dkl(p∣∣q)表示kl散度,pi表示教师模型对类别i的预测概率,qi表示学生模型对类别i的预测概率。

36、在一些实施例中,使用交叉熵损失函数来度量真实样本标签y和模型预测之间的本文档来自技高网...

【技术保护点】

1.一种基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,所述方法包括:

2.根据权利要求1所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,响应于多分类任务,对待优化的全局模型进行初始化,具体包括:

3.根据权利要求2所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,所述多分类任务y的表达式为:

4.根据权利要求1所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,应用随机贪婪算法,在随机获取的活跃客户端中筛选当前边际收益最大的客户端。

5.根据权利要求4所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,筛选当前边际收益最大的客户端,具体包括:

6.根据权利要求5所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,所述本地模型的训练过程,具体包括:

7.根据权利要求6所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,利用第一表达式,计算教师模型的分布平缓程度λ;

8.根据权利要求6所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,利用第二表达式,计算教师模型的预测概率分布p(x);

9.根据权利要求6所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,基于教师模型的预测概率分布p(x)和学生模型的预测概率分布q(x),利用KL散度计算学生模型和教师模型在同一数据样本上的预测概率分布的差异。

10.根据权利要求9所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,KL散度计算的表达式为:

11.根据权利要求6所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,使用交叉熵损失函数来度量真实样本标签y和模型预测之间的差异。

12.根据权利要求11所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,交叉熵损失函数的表达式为:

13.根据权利要求12所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,本体模型的损失函数L为:

14.根据权利要求13所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,优化后的全局模型wt+1为:

15.一种基于联邦学习和知识蒸馏的多分类模型优化装置,其特征在于,所述装置包括:

16.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至14中任一项所述的方法的步骤。

...

【技术特征摘要】

1.一种基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,所述方法包括:

2.根据权利要求1所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,响应于多分类任务,对待优化的全局模型进行初始化,具体包括:

3.根据权利要求2所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,所述多分类任务y的表达式为:

4.根据权利要求1所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,应用随机贪婪算法,在随机获取的活跃客户端中筛选当前边际收益最大的客户端。

5.根据权利要求4所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,筛选当前边际收益最大的客户端,具体包括:

6.根据权利要求5所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,所述本地模型的训练过程,具体包括:

7.根据权利要求6所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,利用第一表达式,计算教师模型的分布平缓程度λ;

8.根据权利要求6所述的基于联邦学习和知识蒸馏的多分类模型优化方法,其特征在于,利用第二表达式,计算教师模型的预测概率分布p(x);

9.根据权利要求6所述的基于...

【专利技术属性】
技术研发人员:孙莉莉刘冬兰刘新常英贤张昊王睿赵鹏王勇胡恒瑞陈剑飞赵丽娜王岳史玉良吕梁张方哲马雷赵夫慧姚洪磊于灏秦佳峰赵洺哲孙梦谦苏冰许善杰金玉辉
申请(专利权)人:国网山东省电力公司电力科学研究院
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1