一种类自适应联邦半监督学习方法技术

技术编号:38526664 阅读:7 留言:0更新日期:2023-08-19 17:02
一种类自适应联邦半监督学习方法,首先,基于客户端采集的样本数据,边缘服务器训练半监督学习模型,并仅将模型参数发送给聚合服务器。然后,聚合服务器根据收到的梯度以及边缘服务器的数量推算出全局数据类别数目的变化。其中,若有类别增加,则平均聚合收到的参数得到新的全局参数,无需做其他改变。若有类别减少,则使用上一轮收到的边缘参数平均聚合得到新的全局参数,无需做其他改变。若类别数目不变,某类样本数量变化,则将全局数据类别数目的变化发送至边缘服务器并依据全局数据类别数目的变化进行样本扩充。最后,边缘服务器使用调整后的数据与新的全局参数进行下一轮次的边缘训练。的边缘训练。的边缘训练。

【技术实现步骤摘要】
一种类自适应联邦半监督学习方法


[0001]本专利技术属于人工智能
,涉及联邦学习,特别涉及一种类自适应联邦半监督学习方法。

技术介绍

[0002]随着世界范围内隐私保护法律的日益严格实施,联邦学习(Federated Learning,FL)受到了极大的关注,并成为近期的热门研究课题。在边缘网络中,FL通过聚合多个设备的局部模型参数,使多个设备能够协同学习一个全局模型,不需要从本地上传数据到云端。它既减少了云端的资源消耗,也增强了客户端的隐私。在边缘智能场景中,边缘设备所采集到的数据是具有差异性的。因为边缘设备所处位置的不同,所以其采集到的数据在类别上也是区别与其他边缘设备。同时,边缘设备往往具有移动性和不稳定性,因此在FL模型迭代训练中,设备的动态进入或退出会影响全局模型的性能。
[0003]在FL期间,由于FL训练是以加密的形式交换梯度来执行的,训练数据对客户端或服务器来说都是无法完全观察到的,客户端和服务器之间的通信仅限于模型训练参数。出于隐私方面的考虑,服务器最好不要要求客户上传有关其本地数据的额外信息,这可能会暴露出潜在的攻击后门,导致隐私泄漏。因此,收集所有本地数据的信息并在全局范围内进行汇总分析是不可行的。同时在FL过程中,由于边缘网络的网络波动、设备宕机和设备移动等各方面问题,一些参与训练的客户端可能会断联,无法继续参与联邦训练,将会造成部分类别数据的丢失;一些新的客户端会加入到联邦训练中,将会增加新类别的数据。进而导致FL模型迭代训练的数据类别数目上发生变化,这将可能会造成新的数据分布不平衡。如果这种不平衡不能被及时发现,就会在早期训练阶段诱导全局模型走向错误的方向,从而毒害全局模型,恶化性能。

技术实现思路

[0004]为了克服上述现有技术的缺点,本专利技术的目的在于提供一种类自适应联邦半监督学习方法,利用每轮训练中客户端的数量与梯度变化检测数据中各类别的变化情况,并使用监测出的类别变化情况及时从全局与局部角度调整方案,从而优化这种类可变的不平衡数据,能够进一步提高模型训练的准确率,使得联邦学习训练的全局模型更加准确。
[0005]为了实现上述目的,本专利技术采用的技术方案是:
[0006]一种类自适应联邦半监督学习方法,包括如下步骤:
[0007]步骤1,基于客户端采集的样本数据,边缘服务器训练半监督学习模型,并仅将模型参数发送给聚合服务器;
[0008]步骤2,待一个全局训练轮次结束后,聚合服务器收到边缘服务器更新的模型参数,并平均聚合本轮收到的模型参数得到全局参数;根据边缘服务器全局参数前后两轮之间的变化,监测到每个类别的全局样本数目变化比率;
[0009]根据每个类别的全局样本数目变化比率,聚合服务器选择用于下一轮边缘训练的
全局参数:若有新增类,则正常聚合由聚合服务器收到的各个边缘服务器发送的该新增类的参数;若有缺失类,则使用聚合服务器收到的上一轮各个边缘服务器发送的该缺失类的参数更新;若类别数目无变化,某类别中样本数量改变造成数据不平衡,则将数据中每个类别的全局样本数目变化比率发送给边缘服务器,边缘服务器依据样本数量改变的类别的全局样本数目的变化进行样本扩充;
[0010]步骤3,将根据步骤2得到的全局参数发送给各个边缘服务器;
[0011]步骤4,各个边缘服务器利用本地数据的标签集和无标签集实现类再平衡自训练;
[0012]步骤5,各个边缘服务器利用平衡后的数据再次进行本地小批次梯度下降,将更新的参数发送给聚合服务器,重复迭代步骤1,步骤2,步骤3和步骤4,直到联邦学习模型收敛。
[0013]在一个实施例中,聚合服务器平均聚合本轮收到的边缘服务器的模型参数得到全局参数,若本轮边缘服务器的数量较上一轮有变化,则聚合服务器根据两轮之间边缘服务器数量的变化与两轮之间全局参数变化,监测到每个类别的全局样本数目变化比率,无需边缘服务器提供相关信息。
[0014]在一个实施例中,根据检测到的每个类别全局样本数目变化比率,与上轮比较,若本轮数据有类别增加,则说明有新增类,直接使用平均聚合后的全局参数作为下一轮边缘服务器的训练参数。
[0015]在一个实施例中,根据检测到的每个类别全局样本数目变化比率,与上轮比较,若本轮数据有类别减少,则说明有缺失类,使用上一轮聚合服务器平均聚合后的全局参数作为下一轮边缘服务器的训练参数。
[0016]在一个实施例中,根据检测到的每个类别全局样本数目变化比率,与上轮比较,若本轮数据类别数目无变化,但是存在有类别内样本数量的变化造成的全局数据不平衡,则在边缘服务器上使用无标签数据进行有选择的数据扩充。
[0017]在一个实施例中,所述步骤4.3,边缘服务器基于步骤2中的数据类别变化从伪标签集中选取相应比例的伪标签样本生成一个伪标签子集。利用伪标签子集扩展标签集,解决由于联邦学习中边缘设备动态加入或退出造成的数据类别不平衡,实现数据集的全局类分布平衡,在新的标签集上重新训练联邦学习模型。
[0018]与现有技术相比,本专利技术的有益效果是:
[0019](1)目前的联邦学习研究大多在关注类别数目固定的数据非平衡,而很少关注类别数目变化的数据非平衡,本专利技术可以缓解联邦学习中类别数目变化的数据非平衡。
[0020](2)提出一个类别数量监测方案,可以通过梯度推测出FL过程中各类别的变化情况,如果类别变化出现,通过监测结果可以及时采取措施以减轻负面影响。
[0021](3)本专利技术可以根据联邦学习训练过程中的类别变化及时调整数据,避免由于边缘设备动态退出与加入造成的数据不平衡对联邦学习中模型精度与准确率造成的影响。
附图说明
[0022]图1是本专利技术流程示意图。
[0023]图2是在不同FL训练阶段(图左为在第50轮、图右为在第100轮后边缘服务器断联)断联后全局模型准确率的变化曲线。
[0024]图3是本专利技术多层联邦学习结构图。
具体实施方式
[0025]为了使本专利技术的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本专利技术的面向联邦学习基于类别数目变化监测的数据再平衡方法进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本专利技术,并不用于限定本专利技术。
[0026]如图1所示,本专利技术类自适应联邦半监督学习方法,包括如下步骤:
[0027]步骤1,客户端采集样本数据,边缘服务器基于该样本数据训练半监督学习模型,在几轮边缘训练后,边缘服务器仅将模型参数发送给聚合服务器。在本专利技术中,半监督学习模型的训练方法为小批次梯度下降算法。
[0028]步骤2,待一个全局训练轮次结束后,聚合服务器收到边缘服务器更新的模型参数,根据边缘服务器全局参数前后两轮之间的变化,监测到类别样本变化比率R,即全局数据中每个类别的全局样本数目变化比率。本专利技术中,类别即数据类别,通过本专利技术方法维持全局数据类别的平衡。
[0029]其中,聚合服务器平均聚合本轮收到的边缘服务器的模型参数得到全局参本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种类自适应联邦半监督学习方法,其特征在于,包括如下步骤:步骤1,基于客户端采集的样本数据,边缘服务器训练半监督学习模型,并仅将模型参数发送给聚合服务器;步骤2,待一个全局训练轮次结束后,聚合服务器收到边缘服务器更新的模型参数,并平均聚合本轮收到的模型参数得到全局参数;根据边缘服务器全局参数前后两轮之间的变化,监测到每个类别的全局样本数目变化比率;根据每个类别的全局样本数目变化比率,聚合服务器选择用于下一轮边缘训练的全局参数:若有新增类,则正常聚合由聚合服务器收到的各个边缘服务器发送的该新增类的参数;若有缺失类,则使用聚合服务器收到的上一轮各个边缘服务器发送的该缺失类的参数更新;若类别数目无变化,某类别中样本数量改变造成数据不平衡,则将数据中每个类别的全局样本数目变化比率发送给边缘服务器,边缘服务器依据样本数量改变的类别的全局样本数目的变化进行样本扩充;步骤3,将根据步骤2得到的全局参数发送给各个边缘服务器;步骤4,各个边缘服务器利用本地数据的标签集和无标签集实现类再平衡自训练;步骤5,各个边缘服务器利用平衡后的数据再次进行本地小批次梯度下降,将更新的参数发送给聚合服务器,重复迭代步骤1,步骤2,步骤3和步骤4,直到联邦学习模型收敛。2.根据权利要求1所述类自适应联邦半监督学习方法,其特征在于,所述步骤1中,边缘服务器使用小批次梯度下降算法训练半监督学习模型,在几轮边缘训练后,边缘服务器仅将模型参数发送给聚合服务器。3.根据权利要求1所述类自适应联邦半监督学习方...

【专利技术属性】
技术研发人员:许志伟刘思远尹德辉
申请(专利权)人:北京崇实允升科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1