当前位置: 首页 > 专利查询>厦门大学专利>正文

一种面向长尾异构数据的联邦学习方法技术

技术编号:33303449 阅读:11 留言:0更新日期:2022-05-06 12:11
本发明专利技术公开了一种面向长尾异构数据的联邦学习方法包括如下步骤:步骤一、服务器端随机初始化全局模型w并将模型参数发给各个客户端,各个客户端利用收到的模型参数进行模型更新,并将更新后的模型参数上传至服务器端;步骤二、服务器端对收到的本地模型参数后进行聚合得到教师模型和学生模型;步骤三、服务器端对步骤二中得到的教师模型进行校准,让教师模型在无偏知识上进行学习,以此教出好的学生模型;步骤四、使用知识蒸馏将教师模型的无偏知识传递给学生模型,随后将学生模型发给各个客户端开始下一轮联邦训练。户端开始下一轮联邦训练。户端开始下一轮联邦训练。

【技术实现步骤摘要】
Loss去减弱不平衡带来的影响。但是这个方法的性能随着数据异构程度的加深而急剧下降。

技术实现思路

[0008]为解决现有技术的不足,实现在满足用户隐私保护、数据安全的同时,提升联邦学习下模型性能,从而提高图像识别效率的目的,本专利技术采用如下的技术方案:
[0009]一种面向长尾异构数据的联邦学习方法,包括如下步骤:
[0010]S1,服务器端随机初始化全局模型w,并将模型参数下发至各个客户端,各个客户端利用收到的模型参数进行本地模型更新,并将更新后的本地模型参数上传至服务器端;
[0011]S2,服务器端对收到的本地模型参数后进行聚合,得到教师模型和学生模型;
[0012]S3,服务器端对教师模型进行校准,使教师模型在无偏知识上进行学习,以此教出好的学生模型;
[0013]S4,通过知识蒸馏,将教师模型的无偏知识传递给学生模型,随后将学生模型下发至各个客户端开始下一轮联邦训练。
[0014]进一步地,步骤S1中,服务器端初始化全局模型参数w,随机选择参与本轮训练的客户端集合S,并将模型参数广播给参与本轮训练的客户端集合S,S中的每个客户端,均利用收到的全局模型参数w和本地的数据,执行随机梯度下降(SGD),以更新本地模型,客户端k更新得到的本地模型参数为w
k
,待更新之后,各个客户端将其更新的模型参数发还给服务器端。
[0015]进一步地,步骤S2包括如下步骤:
[0016]S21,服务器端对本地模型参数进行平均加权,得到学生模型,计算公式如下:
[0017][0018]φ
s
(x)=φ
w
(x)
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式2)
[0019]其中,|D
k
|表示第k个客户端拥有的图像数据量,|D|表示所有客户端拥有的图像数据总量,K表示客户端数量,x表示输入图像数据,φ
w
(
·
)表示联邦平均模型的网络,φ
s
(
·
)表示学生模型的网络;
[0020]S22,服务器端对本地模型参数进行加权聚合,得到教师模型,计算公式如下:
[0021][0022]其中,φ
t
(
·
)表示教师模型的网络,e
k
表示赋给客户端k的权重,表示第k个客户端的网络。
[0023]进一步地,步骤S3中,由于本地模型是在具有不同分布的本地数据上进行训练的,每个本地模型在尾部类上的表现可能不同,因此我们要为在尾部类表现更好的本地模型分配更高的权重,然而,服务器端不知道哪些图像的类是尾部类,并且哪些客户端本地模型在上面表现良好,因此,我们不是给每个客户端一个固定的权重,相反,我们提出基于客户端的权重分配策略,以此来计算每个客户端本地模型的权重e
k
,最后将e
k
归一化使其总和等于1,即为最终权重,权重e
k
的计算公式如下:
[0024][0025]其中,a
e
∈R
c
和b
e
表示可被学习的网络参数,R
c
表示c维向量,T为转置符号,基于客户端的校准就像自注意力机制一样,根据模型的原始输出logits对本地模型计算权重,再将权重乘回原始输出logits。
[0026]进一步地,步骤S3中,若没有一个客户端本地模型可以很好地处理尾类,那么加权集成得到的教师模型仍偏向于头部类,为解决该问题,我们提出基于类的原始输出logits校准策略,以进一步提升模型在尾部类的性能,设被校准后的模型输出logits为z
cl
,计算公式如下:
[0027]z
cl
=a
z

φ
t
(x)+b
z
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式5)
[0028]其中,a
z
和b
z
表示可被学习的网络参数,

表示哈达玛积。
[0029]进一步地,步骤S3中,上述对logits校准策略有效的前提是本地模型对输入的图像数据提取的表征信息足够好,若客户端本地模型对输入的图像数据的特征提取因为长尾分布而受到严重影响,那么仅仅只对输出logits做校准是不够的,因此,我们需要更新特征提取器来进一步提升模型性能,我们在服务器端利用额外的平衡有标签的图像,构成的平衡有标签数据集在全局模型w上进行微调,得到微调模型因为的数据分布是平衡的,所以微调模型可以获得无偏的特征提取器,然后,我们可以获得对于输入的图像数据为x的微调模型输出logits为其中z
ft
表示微调模型对x的输出,表示微调模型的网络。
[0030]进一步地,微调模型其中,η表示学习率,表示损失函数,表示求导。
[0031]进一步地,步骤S3中,z
cl
和z
ft
是从两个不同的层面去校准教师模型,z
cl
是对教师模型输出logits层面进行校准,其模型的特征提取器被固定,然而z
ft
是对特征提取器微调的结果,以此提升模型特征提取能力,为充分结合二者优势,我们提出通过一个校准门控网络对z
cl
和z
ft
做权衡,校准门控网络以集成特征作为输入,经由一个非线性层输出权重,使得每个样本根据自身的特征不同而获得不同的权重,权重计算公式如下:
[0032]σ=sigmoid(u
T
v)
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(公式6)
[0033]其中,表示集成特征,表示第k个客户端的特征提取器,u∈R
d
表示可被学习的网络参数,R
d
表示d维向量,因此,通过校准门控网络的最终校准模型输出logits为z

,计算公式如下:
[0034]z

=σz
cl
+(1

σ)z
ft
ꢀꢀꢀꢀꢀꢀꢀ
(公式7)
[0035]其中σ∈(0,1)用于权衡z
cl
和z
ft
两个模型输出logits。
[0036]进一步地,集成校准的整个过程中,所有可被学习的参数都通过在上的交叉熵损失进行更新,损失函数如下:
[0037][0038]其中,C表示类别数,y
j
表示输入图像数据的真实标签,j表示y中第j维的值,exp(
·
)表示以自然常数e为底的指数函数,z

j
表示最终校准z

中第j维的值,z

i
表示最终校准z

中第i维的值,z
本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种面向长尾异构数据的联邦学习方法,其特征在于包括如下步骤:S1,服务器端随机初始化全局模型w,并将模型参数下发至客户端,客户端利用收到的模型参数进行本地模型更新,并将更新后的本地模型参数上传至服务器端;S2,服务器端对本地模型参数进行聚合,得到教师模型和学生模型;S3,服务器端对教师模型进行校准,使教师模型在无偏知识上进行学习;S4,通过知识蒸馏,将教师模型的无偏知识传递给学生模型,随后将学生模型下发至客户端开始下一轮联邦训练。2.根据权利要求1所述的一种面向长尾异构数据的联邦学习方法,其特征在于所述步骤S1中,服务器端初始化全局模型参数w,随机选择参与本轮训练的客户端集合S,并将模型参数广播给参与本轮训练的客户端集合S,S中的客户端,利用收到的全局模型参数w和本地的数据,执行随机梯度下降,以更新本地模型,客户端k更新得到的本地模型参数为w
k
,待更新之后,客户端将其更新的模型参数发还给服务器端。3.根据权利要求2所述的一种面向长尾异构数据的联邦学习方法,其特征在于所述步骤S2包括如下步骤:S21,服务器端对本地模型参数进行平均加权,得到学生模型,计算公式如下:φ
s
(x)=φ
w
(x)
ꢀꢀꢀꢀꢀꢀꢀ
(公式2)其中,|D
k
|表示第k个客户端拥有的数据量,|D|表示所有客户端拥有的数据总量,K表示客户端数量,x表示输入数据,φ
w
(
·
)表示联邦平均模型的网络,φ
s
(
·
)表示学生模型的网络。S22,服务器端对本地模型参数进行加权聚合,得到教师模型,计算公式如下:其中,φ
t
(
·
)表示教师模型的网络,e
k
表示客户端k的权重,表示第k个客户端的网络。4.根据权利要求3所述的一种面向长尾异构数据的联邦学习方法,其特征在于所述步骤S3中,提出基于客户端的权重分配策略,以此来计算每个客户端本地模型的权重e
k
,最后将e
k
归一化使其总和等于1,即为最终权重,权重e
k
的计算公式如下:其中,a
e
∈R
c
和b
e
表示可被学习的网络参数,R
c
表示c维向量,T为转置符号,根据模型的原始输出对本地模型计算权重,再将权重乘回原始输出。5.根据权利要求4所述的一种面向长尾异构数据的联邦学习方法,其特征在于所述步骤S3中,提出基于类的原始输出校准策略,校准后的模型输出为z
cl
,计算公式如下:z
cl
=a
z

φ
t
(x)+b
z
ꢀꢀꢀꢀꢀꢀꢀ
(公式5)其中,a
z
和b
z

【专利技术属性】
技术研发人员:卢杨尚心怡黄刚华炜王菡子
申请(专利权)人:厦门大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1