一种基于局部-全局伪标记的联邦半监督学习方法技术

技术编号:39720262 阅读:11 留言:0更新日期:2023-12-17 23:26
本发明专利技术公开了一种基于局部

【技术实现步骤摘要】
一种基于局部

全局伪标记的联邦半监督学习方法


[0001]本专利技术涉及人工智能
,尤其是一种基于局部

全局伪标记的联邦半监督学习方法


技术介绍

[0002]联邦学习作为一种分布式解决方案最近引起了广泛的研究关注,它允许多个参与客户端协作学习一个全局模型,而无需共享自己的私有数据

联邦学习现有的成果主要集中在有监督的环境中,即每个客户端都有完全标记的数据

然而在现实场景中,由于注释成本高或专业知识不足,大多数客户端可能没有足够的真实标签数据,而且必须注意的是,有些客户端可能拥有完全未标注的数据
(
例如,医院病人的医疗数据
)。
因此,有学者提出了联邦半监督学习范式,它通过利用客户端存在的大量未标记数据来增强全局模型的性能

根据标记数据在客户端和服务器之间的分布情况,可以将联邦半监督学习分为以下三大类:Ⅰ)Labels

at

client
,该场景中客户端拥有由标记数据和未标记数据组成的混合数据集;Ⅱ)Labels

at

server
,该场景中客户端包含一个未标记的数据集,而服务器则包含一个标记数据集;Ⅲ)Labels

at

Partial

Client
,该场景中一部分客户端具有标记数据集,而其余客户端则具有未标记数据的数据集

[0003]目前,联邦半监督学习领域内的通用框架是
FedMatch
,它涵盖了联邦半监学习场景中的前两类
。FedMatch
使用伪标记的方法为无标签数据生成伪标签从而利用每个客户端的未标记数据集

然而,在伪标签的生成过程中它依赖来自其它客户端的辅助模型,这就带来了潜在的隐私泄露风险,这一点与联邦学习的隐私保护性背道而驰

其次,在每一轮通信开始前,服务器都需要将
H
个辅助模型发送给每个活动的客户端,这极大地增加了通信的开销,降低了全局模型的收敛速度

最后,它是为每一个批次中的未标记训练数据生成伪标签,该做法会导致模型遗忘从标记数据中学习到的知识,虽然参数分解机制能够缓解该问题,但不能从根本上解决它

[0004]SemiFL

Labels

at

server
场景下一个效果较好的解决方案

它首次提出了“使用标记数据微调全局模型”和“使用全局模型生成伪标签”的交替训练方式
。SemiFL
同样使用伪标记法利用未标记数据,但是它为无标签数据生成伪标签时仅使用了全局模型而忽略了本地模型

当本地客户端的数据分布与服务器的差异较大时,生成的伪标签可能包含大量噪声,从而导致严重的确认偏差现象而降低全局模型的性能

尤其是标记数据量较少时,该问题会进一步加剧


技术实现思路

[0005]针对当前联邦半监督学习方法存在的导致潜在信息泄漏,只使用全局模型来生成伪标签等缺陷,本专利技术提供一种基于局部

全局伪标记的联邦半监督学习方法

[0006]本专利技术提供的基于局部

全局伪标记的联邦半监督学习方法,从客户端和服务器之间数据分布差异的角度出发,在不共享任何敏感信息的情况下,提出了一种名为
FedLGMatch
的新型联邦半监督学习框架

直观的理解是,与全局模型相比,本地模型捕捉本地数据独特特征的能力更强

因此,在为未标记数据生成伪标签时,自然会想到使用在每个客户端训练的本地模型来辅助全局模型

[0007]本专利技术提供的基于局部

全局伪标记的联邦半监督学习方法,包括以下两个步骤:
[0008]S1、
在通信轮次
t
开始时,服务器将全局模型参数传输到活动客户端;每个客户端在接收到参数后再利用全局模型和上一轮通信轮次
t
‑1中训练得到的本地模型在未标记数据的弱增强视图上生成伪标签,并将其作为本地训练强增强视图的目标用于优化交叉熵损失

[0009]S2、
本地训练结束后,每个客户端将本地模型的参数发回到服务器,服务器首先聚合这些参数,然后利用标记数据集对其进行微调,最后得到一个新的全局模型
[0010]上述交替训练过程重复多次至全局模型收敛后结束

[0011]所述步骤
S1
中,生成伪标签的方法如下:
[0012]对于客户端
C
u
的未标记数据集使用公式
(1)

(2)
一次性标记数据集内所有数据,并通过下列公式
(3)
的方式构建一个固定的伪标记数据集
[0013][0014][0015][0016]式中,
I(
·
)
是一个指示函数,
DA
w
是弱数据增强操作,
x
u
是未标记数据,是通信轮次
t
中全局模型的参数,
F(
·
)
是一个基于卷积神经网络的编码器,
τ
是置信度阈值;是上一轮通信轮次
t
‑1中训练得到的本地模型参数,
L
是未标记数据集中的类别总数

[0017]如果数据集为空,则该客户端的训练过程将直接被跳过;否则,则从数据集
D
u
中随机采样一个与等大小的数据集用于辅助训练,其定义如下:
[0018][0019]其中,是生成的伪标签,是数据集的大小

[0020]在客户端的本地训练过程中,数据集和被随机划分为大小为
B
u
的小批次数据

[0021]对于伪标记数据集定义如下训练目标:
[0022][0023]其中,
CE
表示交叉熵损失,
DA
s
是一种强数据增强操作

[0024]从数据集和被划分的小批次数据中分别采样一对样本和然后利用线性插值方法构建一对新的样本:
[0025][0026]其中
Beta(
·
)
表示
Beta
分布,
α
是其对应的超参数;
λ

Beta
分布生成的一个数值,代表插值方法构造出的新数据,
i
表示索引下标,
u
是本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.
一种基于局部

全局伪标记的联邦半监督学习方法,其特征在于,包括以下两个步骤:
S1、
在通信轮次
t
开始时,服务器将全局模型参数传输到活动客户端;每个客户端在接收到参数后再利用全局模型和上一轮通信轮次
t
‑1中训练得到的本地模型在未标记数据的弱增强视图上生成伪标签,并将其作为本地训练强增强视图的目标用于优化交叉熵损失;
S2、
本地训练结束后,每个客户端将本地模型的参数发回到服务器,服务器首先聚合这些参数,然后利用标记数据集对其进行微调,最后得到一个新的全局模型上述交替训练过程重复多次至全局模型收敛后结束
。2.
如权利要求1所述的基于局部

全局伪标记的联邦半监督学习方法,其特征在于,步骤
S1
中,生成伪标签的方法如下:对于客户端
C
u
的未标记数据集使用公式
(1)

(2)
一次性标记数据集内所有数据,并通过下列公式
(3)
的方式构建一个固定的伪标记数据集的方式构建一个固定的伪标记数据集的方式构建一个固定的伪标记数据集的方式构建一个固定的伪标记数据集式中,
I(
·
)
是一个指示函数,
DA
w
是弱数据增强操作,
x
u
是未标记数据,是通信轮次
t
中全局模型的参数,
F(
·
)
是一个基于卷积神经网络的编码器,
τ
是置信度阈值;是上一轮通信轮次
t
‑1中训练得到的本地模型参数,
L
是未标记数据集中的类别总数
。3.
如权利要求2所述的基于局部

全局伪标记的联邦半监督学习方法,其特征在于,如果数据集为空,则该客户端的训练过程将直接被跳过;否则,则从数据集
Du
中随机采样一个与等大小的数据集用于辅助训练,其定义如下:其中,是生成的伪标签,是数据集的大小;在客户端的本地训练过程中,数据集和被随机划分为大小为
B
u
的小批次数据
。4.
如权利要求3所述的基于局...

【专利技术属性】
技术研发人员:储节磊赵晴李天瑞吕凤毛
申请(专利权)人:西南交通大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1