一种基于长短梯度值的联邦学习恶意客户端检测方法技术

技术编号:38682862 阅读:10 留言:0更新日期:2023-09-02 22:55
本发明专利技术公开了一种基于长短梯度值的联邦学习恶意客户端检测方法。首先该方法引入了长短梯度值用于检测恶意客户端和不可靠客户端,短梯度值用于平滑训练客户端的本地梯度值,可以消除单轮本地梯度值的随机性,长梯度值是记录所有时间内本地梯度值的总和,可以反映训练客户端的累积影响。同时,本发明专利技术考虑了符号翻转、附加噪声、单标签翻转和多标签翻转等攻击类型,分类型进行检测,可以区分恶意客户端的目标攻击和无目标攻击。本发明专利技术提供的方法用于检测联邦学习的恶意客户端,从而保证联邦学习的学习安全。的学习安全。的学习安全。

【技术实现步骤摘要】
一种基于长短梯度值的联邦学习恶意客户端检测方法


[0001]本专利技术涉及联邦学习领域,具体来说,本专利技术具体涉及基于长短梯度值的联邦学习恶意客户端检测方法。

技术介绍

[0002]近年来,随着计算能力增强的智能设备的普及,为在大量分布式设备上训练机器学习模型奠定了坚实的基础。分布式机器学习需要一个安全的传输数据的环境,所以分布式机器学习要求较高的网络带宽,但可能引发巨大的隐私问题,为此,谷歌引入了联邦学习的概念,联邦学习允许本地客户端协作训练一个全局模型,而不是让本地客户端的数据离开本地设备。
[0003]然而由于客户端的不受控制和分布式特点,还有服务器无法访问客户端的数据等问题,容易造成联邦学习容易受到来自客户端发起的恶意攻击。目前来说,恶意客户端的攻击有两种:1、无目标攻击,恶意客户端试图破坏全局模型的收敛;2、目标攻击,恶意客户端只在特定目标情况下破坏全局模型的收敛。同时,联邦学习的客户端中还存在良性的不可靠客户端,由于数据质量低或者学习不可靠,会影响全局模型的整体性能,而由于不可靠客户端可能拥有有价值的数据,所以并不能直接排除不可靠客户端。
[0004]综上所述,现有联邦学习架构中存在有如下缺陷:(1)由于联邦学习的客户端不受控制以及服务器无法访问客户端的数据,使得联邦学习容易受到客户端发起的敌对攻击。(2)外部攻击者可以从联邦学习客户端所传递的参数中重构出原始的训练数据,从而导致隐私泄露。因此利用长短梯度值分类型检测恶意客户端和不可靠客户端可以使联邦学习更安全可靠。

技术实现思路

[0005]为解决上述现有技术中存在的不足,本专利技术提供一种基于长短梯度值的联邦学习恶意客户端检测方法,解决联邦学习中客户端不受控制的问题,提高联邦学习的鲁棒性。
[0006]为了实现上述技术目的,本专利技术的技术方案为:
[0007]基于长短梯度值的联邦学习恶意客户端检测方法,其特征在于以下步骤,包括:
[0008]步骤一:中心服务器C发送初始梯度值w1给所有训练客户端K={k1,k2,...,k
n
},每个训练客户端k
i
∈K,i=1,2,...,n根据各自本地的数据集利用初始梯度值w1进行训练得到本地梯度值
[0009]步骤二:训练客户端k
i
将本地梯度值发送给中央服务器C,中央服务器C根据梯度值的长短梯度值区分检测恶意客户端k
m
和不可靠客户端k
u

[0010]步骤三:中央服务器排除恶意客户端后,将正常客户端k
r
和不可靠客户端k
u
上传的梯度值进行加权聚合,得到全局梯度值w2。
[0011]进一步的,根据权利要求1所述的基于长短梯度值的联邦学习恶意客户端检测方法,所述步骤二为了检测恶意客户端和不可靠客户端,引入了长短梯度值用于区分恶意客
户端和不可靠客户端:
[0012](1)利用表示在第r轮训练之前中央服务器从客户端k
i
接收到的梯度值集合。那么在第r轮中训练客户端k
i
的短梯度值可以表示为:
[0013][0014]其中,x是一个动态的数值,这样可以选取需要的梯度值的学习轮数,短梯度值用来消除单轮梯度值的随机性。
[0015]那么在第r轮中训练客户端k
i
的长梯度值可以表示为:
[0016][0017]长梯度值是记录所有时间内梯度值的总和,所以长梯度值可以反映训练客户端对全局模型的累积影响。
[0018]进一步的,根据权利要求1所述的基于长短梯度值的联邦学习恶意客户端检测方法,所述步骤二可以按照无目标攻击、目标攻击、不可靠客户端和正常客户端的顺序来检测客户端:
[0019](1)由于无目标攻击客户端的目的是为了破坏整个模型,所以它的短梯度值与正常客户端的短梯度值有较大的不同。在符号翻转的无目标攻击中,恶意客户端的梯度值将改变到相反方向,会导致其短梯度值与所有客户端的短梯度值中位值存在较大的角度偏差,所以可以使用短梯度值计算余弦距离用于检测符号翻转的恶意客户端。如果一个客户端k
i
的短梯度值和中位值的余弦距离时,其中d
cos
(
·
)表示计算余弦距离,那么该客户端k
i
将被检测为符号翻转的无目标攻击客户端;
[0020](2)在排除了符号翻转的恶意客户端后,添加噪声的恶意客户端与剩余客户端拥有类似的短梯度值,所以在排除了符号翻转的恶意客户端后,使用了一种基于短梯度值的聚类方法来验证添加噪声的恶意客户端,由于添加噪声的恶意客户端的短梯度值相对于不可靠客户端更加远离其他客户端的短梯度值,所以利用DBSCAN作为聚类方法,通过寻找最大欧氏距离进行分类。通过对剩余客户端的短梯度值进行DBSCAN聚类,会得到两组客户端,一组是由添加噪声的恶意客户端和不可靠客户端组成的集合,另一组是由其他客户端组成的集合。本专利技术的方法将计算所有客户端短梯度值的中位值再通过计算其中d
Euc
(
·
)表示计算欧氏距离,通过找到由添加噪声的恶意客户端和不可靠客户端的短梯度值中任意两个连续值之间的最大距离,并使用该最大距离的中点作为分离边界值d
m
,如果那么该客户端k
i
将被检测为添加噪声的恶意客户端;
[0021](3)在排除无目标攻击的恶意客户端后,目标攻击的恶意客户端的目的是将全局模型操纵到一个特定的收敛点。在目标攻击的情况下,一些恶意客户端可能在某些迭代轮次中表现正常,以逃避检测,但在长梯度值的整个记录中,可以被检测到。该方法中,在排除检测到的无目标攻击的恶意客户端后,使用K=2的K

means聚类算法对长梯度值进行分类可以检测出目标攻击恶意客户端;
[0022](4)在排除所有检测到的恶意客户端后,不可靠的客户端的短梯度值会远离所有客户端的短梯度值的中位值。在这种情况下,利用余弦距离可以检测出不可靠客户端,通过
找到所有客户端短梯度值中任意两个连续值之间的最大距离,并使用该最大距离的中点作为分离边界值d
m
,那么用表示为每个客户端和短梯度值中位值的余弦距离,如果客户端满足那么该客户端被检测为不可靠客户端。
附图说明
[0023]图1是本专利技术的流程示意图。
[0024]图2是本专利技术验证恶意客户端和不可靠客户端的过程图。
[0025]图3是实施例中本专利技术与联邦学习平均算法的准确率对比图。
[0026]图4是实施例中本专利技术与联邦学习平均算法的精确度对比图。
具体实施方式
[0027]以下将结合附图,对本专利技术的技术方案进行详细说明。
[0028]本专利技术设计了一种基于长短梯度值的联邦学习恶意客户端检测方法,如图1所示,步骤如下:
[0029]步骤1:中心服务器C发送初始梯度值w1给所有训练客户端K={本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.基于长短梯度值的联邦学习恶意客户端检测方法,其特征在于以下步骤,包括:步骤一:中心服务器C发送初始梯度值w1给所有训练客户端K={k1,k2,...,k
n
},每个训练客户端k
i
∈K,i=1,2,...,n根据各自本地的数据集利用初始梯度值w1进行训练得到本地梯度值步骤二:训练客户端k
i
将本地梯度值发送给中央服务器C,中央服务器C根据梯度值的长短梯度值区分检测恶意客户端k
m
和不可靠客户端k
u
;步骤三:中央服务器排除恶意客户端后,将正常客户端k
r
和不可靠客户端k
u
上传的梯度值进行加权聚合,得到全局梯度值w2。2.根据权利要求1所述的基于长短梯度值的联邦学习恶意客户端检测方法,所述步骤二为了检测恶意客户端和不可靠客户端,引入了长短梯度值用于区分恶意客户端和不可靠客户端:步骤一:利用表示在第r轮训练之前中央服务器从客户端k
i
接收到的梯度值集合。那么在第r轮中训练客户端k
i
的短梯度值可以表示为:其中,x是一个动态的数值,这样可以选取需要的梯度值的学习轮数,短梯度值用来消除单轮梯度值的随机性。那么在第r轮中训练客户端k
i
的长梯度值可以表示为:长梯度值是记录所有时间内梯度值的总和,所以长梯度值可以反映训练客户端对全局模型的累积影响。3.根据权利要求1所述的基于长短梯度值的联邦学习恶意客户端检测方法,所述步骤二可以按照无目标攻击、目标攻击、不可靠客户端和正常客户端的顺序来检测客户端:步骤一:由于无目标攻击客户端的目的是为了破坏整个模型,所以它的短梯度值与正常客户端的短梯度值有较大的不同。在符号翻转的无目标攻击中,恶意客户端的梯度值将改变到相反方向,会导致其短梯度值与所有客户端的短梯度值中位值存在较大的角度偏差,所以可以使用短梯度值计算余弦距离用于检测符号翻转的恶意客户端。如果一个客户端k
i
的短梯度值和中位值的余...

【专利技术属性】
技术研发人员:孙永亮吴雪峰
申请(专利权)人:南京工业大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1