随机采样共识联合半监督学习制造技术

技术编号:37842337 阅读:22 留言:0更新日期:2023-06-14 09:46
提供了用于在非IID设置中随机采样共识联合(RSCFed)学习的方法和系统。该方法包括:随机采样本地客户端,在同步轮次开始时将当前全局模型分配给随机采样的本地客户端以进行初始化,对随机采样的本地客户端进行本地训练,从随机采样的本地客户端中收集本地模型并对收集的本地模型进行距离重加权模型聚合(DMA),以得到子共识模型,多次重复上述步骤以得到一组子共识模型,以及基于子共识模型聚合新模型作为下一全局模型。新模型作为下一全局模型。新模型作为下一全局模型。

【技术实现步骤摘要】
随机采样共识联合半监督学习

技术介绍

[0001]联合学习(FL)的核心思想是在分布于不同位置或设备上的独立数据集上训练机器学习模型,从而在一定程度上保护本地数据隐私。在过去几年里,FL已经成为重要研究工具,并且研究FL在例如医学图像诊断[10,14,28]、图像分类[16]和目标检测[22]等领域的应用方面引起关注。已经提出了各种FL方法,例如FedAvg[23]、SCAFFOLD[12]和MOON[16],并且获得的初步结果很有前景。然而,由于需要在每个本地客户端上完全标记图像,这些方法在实际实践中的应用仍然受到限制。
[0002]最近,旨在利用未标记图像来增强FL的全局模型开发的联合半监督学习(FSSL)[8,19,21,28]已成为新的研究领域。FSSL的一条线专注于具有部分标记图像和未标记图像的每个客户端。例如,Jeong等人[8]引入了客户端间一致性损失的概念,其通过鼓励来自多个客户端的一致输出来改进全局模型。FSSL的另一条线[21,28]做出以下假设:一些本地客户端具有完全标记图像并被表示为标记客户端,而其它本地客户端具有未标记图像并被表示为未标记客户端。然而,这些方法在两个方面受到限制。首先,它们没有考虑本地客户端之间的非独立同分布(not independent and identically distributed,非IID)数据,这导致[9,15]的性能准确性下降。其次,一些方法[21]在本地客户端之间共享相关矩阵,这可能导致信息泄露。
[0003]该问题的一种潜在解决方案是将现有FSSL方法(例如FedIRM[21]和Fed

Consist[28])的应用扩展到非IID设置。然而,该方法未能将这些方法推广到非IID设置。FedIRM方法在客户端之间共享类间相关矩阵,由于本地客户端之间的异构数据而无法正确学习该类间相关矩阵,从而影响模型性能。Fed

Consist方法平等地平均了来自标记客户端和未标记客户端的模型权重。然而,当未标记客户端的数量增加时,随着全局模型可能由未标记客户端主导,模型性能显著下降。另一个可能的解决方案是调整标记客户端和未标记客户端的聚合权重,其中标记客户端的权重增加,而未标记客户端的权重降低。然而,初步结果仅显示出有限的性能改进。

技术实现思路

[0004]本领域仍需要改进的设计和技术,用于在非IID设置中执行随机采样共识联合(random sampling consensus federated,RSCFed)学习的方法和系统。
[0005]根据本专利技术的一个实施例,提供了一种随机采样共识联合(RSCFed)学习方法。该方法包括:随机采样本地客户端;在同步轮次开始时,将当前全局模型分配给随机采样的本地客户端进行初始化;对随机采样的本地客户端进行本地训练;从随机采样的本地客户端中收集本地模型,并对收集的本地模型执行距离重加权模型聚合(DMA),以得到子共识模型;多次重复上述步骤,以得到一组子共识模型;并
[0006]基于子共识模型聚合新模型作为下一全局模型。本地客户端包括具有标记的本地数据的标记本地客户端、和具有未标记的本地数据的未标记本地客户端。此外,将当前全局模型分配给随机采样的客户端以初始化包括:用当前全局模型初始化本地模型以对随机采
样的客户端执行本地训练。进行本地训练包括对标记本地客户端和未标记本地客户端分别进行标准监督训练和无监督训练。利用作为由以下方程定义的主要目标的交叉熵损失L
CE
对所述标记本地客户端进行所述本地训练:
[0007][0008]其中是对来自相应本地模型的随机采样的本地客户端的预测。通过基于mean

teacher的一致性正则化框架,并将学生模型视为所述本地模型,来对所述未标记本地客户端进行所述本地训练。此外,距离重加权模型聚合(DMA)被配置为动态调整所收集的模型的权重。在对未标记本地客户端进行本地训练期间,在生成来自学生模型和教师模型的预测后,配置锐化方法来提高教师模型的预测热度。当本地训练完成后,将学生模型作为本地模型提供给相应的未标记本地客户端。此外,执行距离重加权模型聚合(DMA)包括:为每个子集计算子集内平均模型;在每个子集中为本地客户端缩放权重;并将子集内模型权重归一化到[0,1]的范围内。
[0009]本专利技术的某些实施例中,提供了一种用于执行随机采样共识联合(RSCFed)学习的系统。该系统包括通过通信网络耦接到多个本地客户端的联合服务器,并且联合服务器被配置为:随机采样多个本地客户端,在同步轮次开始时将当前全局模型分配给随机采样的本地客户端进行初始化,对随机采样的本地客户端进行本地训练;从随机采样的本地客户端中收集本地模型,对收集的本地模型执行距离重加权模型聚合(DMA),以得到子共识模型,多次重复上述步骤,以得到一组子共识模型,并基于子共识模型聚合新模型作为下一全局模型。本地客户端包括具有标记的本地数据的标记本地客户端、和具有未标记的本地数据的未标记本地客户端。此外,将当前全局模型分配给随机采样的客户端以初始化包括:用当前全局模型初始化本地模型以对随机采样的客户端执行本地训练。本地训练包括对标记本地客户端和未标记本地客户端分别进行标准监督训练和无监督训练。利用作为由以下方程定义的主要目标的交叉熵损失L
CE
对所述标记本地客户端进行所述本地训练:
[0010][0011]其中是对来自相应本地模型的随机采样的本地客户端的预测。通过基于mean

teacher的一致性正则化框架,并将学生模型视为所述本地模型,来对所述未标记本地客户端进行所述本地训练。此外,距离重加权模型聚合(DMA)被配置为动态调整所收集的模型的权重。在对未标记本地客户端进行本地训练期间,在生成来自学生模型和教师模型的预测后,配置锐化方法来提高教师模型的预测热度。当本地训练完成后,将学生模型作为本地模型提供给相应的未标记本地客户端。此外,执行距离重加权模型聚合(DMA)执行:为每个子集计算子集内平均模型;在每个子集中为本地客户端缩放权重;并将子集内模型权重归一化到[0,1]的范围内。
附图说明
[0012]图1是示出了根据本专利技术实施例的随机采样共识联合(RSCFed)学习方法的步骤的流程图。
[0013]图2是根据本专利技术实施例的RSCFed方法的概述的示意图,其中标记本地客户端和未标记本地客户端分别通过监督交叉熵损失(supervised cross

entropy loss)L
CE
和基于
mean

teacher的一致性损失(mean

teacher

based consistency loss)L
MSE
来优化,其中RSCFed使用距离重加权模型聚合(distance

reweighted model aggregation,DMA)在所有客户端之间执行本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种随机采样共识联合(RSCFed)学习的方法,包括:随机采样本地客户端;在同步轮次开始时,将当前全局模型分配给所述随机采样的本地客户端,以进行初始化;对所述随机采样的本地客户端进行本地训练;从所述随机采样的本地客户端中收集本地模型,并对所述收集的本地模型执行距离重加权模型聚合(DMA),以得到子共识模型;多次重复上述步骤,以得到一组子共识模型;以及基于所述子共识模型聚合新模型作为下一全局模型。2.根据权利要求1所述的方法,其中,所述本地客户端包括具有标记的本地数据的标记本地客户端和具有未标记的本地数据的未标记本地客户端。3.根据权利要求1所述的方法,其中,将当前全局模型分配给所述随机采样的客户端以进行初始化的步骤包括:用所述当前全局模型初始化所述本地模型,以对所述随机采样的客户端执行本地训练。4.根据权利要求2所述的方法,其中,进行本地训练的步骤包括:对所述标记本地客户端和所述未标记本地客户端分别进行标准监督训练和无监督训练。5.根据权利要求1所述的方法,其中,利用作为由以下方程定义的主要目标的交叉熵损失L
CE
对所述标记本地客户端进行所述本地训练:其中是对来自相应本地模型的所述随机采样的本地客户端的预测。6.根据权利要求2所述的方法,其中,通过基于mean

teacher的一致性正则化框架,并将学生模型视为所述本地模型,来对所述未标记本地客户端进行所述本地训练。7.根据权利要求1所述的方法,其中,所述距离重加权模型聚合(DMA)被配置为动态调整所述收集的模型的权重。8.根据权利要求2所述的方法,其中,在对所述未标记本地客户端进行所述本地训练期间,在生成来自所述学生模型和所述教师模型的预测之后,配置锐化方法以提高所述教师模型的所述预测的热度。9.根据权利要求8所述的方法,其中,当所述本地训练完成时,将所述学生模型作为所述本地模型提供给对应的未标记本地客户端。10.根据权利要求1所述的方法,其中,执行距离重加权模型聚合(DMA)包括:计算每个子集的子集内平均模型;在每个子集中为所述本地客户端缩放权重;和将所述子集内模型权重归一化到[0,1]的范围内。11.一种用于执行随机采样共识联合(RSCFed)学习的系统,包括:...

【专利技术属性】
技术研发人员:李小萌
申请(专利权)人:香港科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1