当前位置: 首页 > 专利查询>复旦大学专利>正文

一种基于知识蒸馏的面向设备异构的联邦学习方法技术

技术编号:37322625 阅读:13 留言:0更新日期:2023-04-21 23:02
本发明专利技术属于数据信息隐私保护技术领域,具体为一种基于知识蒸馏的面向设备异构的联邦学习方法。本发明专利技术涉及的系统包括K个客户端、1个服务器;每个客户端上有一个分类模型;为高效地进行客户端模型表示层之间知识蒸馏,把每轮通信分为服务器训练阶段和本地训练阶段;在服务器训练阶段,首先以推断样本低维表示的后验分布为目标在服务器上建立生成模型,然后把训练好的生成模型传递给客户端;在本地训练阶段,客户端一方面用私有样本计算任务损失,一方面用生成模型输出的均值样本计算调优表示层的损失。这样在多轮迭代之后,各个客户端可以得到比传统训练方法精度更高的模型。以得到比传统训练方法精度更高的模型。以得到比传统训练方法精度更高的模型。

【技术实现步骤摘要】
一种基于知识蒸馏的面向设备异构的联邦学习方法


[0001]本专利技术属于数据信息隐私保护
,具体涉及一种面向设备异构的联邦学习方法。

技术介绍

[0002]随着数据量的快速增加,以及出于隐私保护的需求,联邦学习已经发展成为一个非常有前景的方向。联邦学习一般由许多仅能访问私有数据的客户端以及一个可以协调学习过程而不能访问任何原始数据的中央服务器构成。它的目标是在不显式地分享私有数据的前提下,利用分布存储的数据在中央服务器上训练一个全局模型。这种方法面临的一个挑战是数据异构,当各个客户端拥有的数据分布不同时模型的性能会下降。现有的解决数据异构问题的方法大多基于梯度的整合,必须在本地模型同构的条件下进行。
[0003]本专利技术申请专注于设备异构的联邦学习问题。在这个问题背景下,各个客户端的存储、计算和通信能力的不同,造成本地模型的结构也会不同,现有方法会遇到严峻的挑战。在一些实际的联邦学习场景中,需要在硬件差异很大的设备上训练。当模型结构设计的较复杂时,资源较少的设备无法参与训练;当模型结构较简单时,另一些资源充足的设备又未充分利用。
[0004]为此,本专利技术提出一种基于知识蒸馏的联邦学习算法,它允许每个客户端建立个性化模型,可以同时解决模型异构和数据异构两种挑战。算法把每轮通信分为两个阶段,在服务器训练阶段,首先以推断样本低维表示的后验分布为目标在服务器上建立生成模型,然后把训练好的生成模型传递给客户端;在本地训练阶段,客户端一方面用私有样本计算任务损失,一方面用生成模型输出的均值样本计算调优表示层的损失。这样在多轮迭代之后,各个客户端可以得到比传统训练方法精度更高的模型。

技术实现思路

[0005]本专利技术的目的在于提出一种面向设备异构的联邦学习方法,以便在客户端资源差异很大的场景下进行联邦学习,从而为挖掘数据信息提供有力保障。
[0006]本专利技术提出的面向设备异构的联邦学习方法,是基于知识蒸馏技术的;其涉及的系统包括有K个客户端、1个服务器;其中;
[0007]每个客户端上有1个根据软硬件资源设置的分类模型,客户端的分类模型划分为表示层和决策层,表示层用于把样本映射为低维表示,决策层用于把低维表示映射为概率向量;客户端之间知识蒸馏的目标函数定义式:
[0008][0009]其中,K是客户端数量;X
k
是第k个客户端的私有数据集,x、y是样本和标签;L
k
是分类任务的损失函数;φ(
·
)是客户端模型的决策层函数,f
k
(
·
)是客户端模型的表示层函数,θ
k
是表示层参数;γ是超参数;各个客户端上的分类模型的表示层结构和参数不同,而
决策层的结构和初始化参数均相同。
[0010]所述服务器上设置1个生成模型,由共享输入的均值函数和方差函数构成。系统的目标是高效地求解(1)式,为此,本专利技术方法把每轮通信分为两个阶段:服务器训练阶段和本地训练阶段;在服务器训练阶段,首先以推断样本低维表示的后验分布为目标在服务器上建立生成模型,然后把训练好的生成模型传递给客户端;在本地训练阶段,客户端一方面用私有样本计算任务损失,一方面用生成模型输出的均值样本计算调优表示层的损失。这样在多轮迭代之后,各个客户端可以得到比传统训练方法精度更高的模型。
[0011]具体地:
[0012]在服务器训练阶段,服务器首先收集所有客户端模型的表示层;收集客户端采集的低维表示,组成集合Z。然后,用变分推断法求解后验分布具体地:
[0013]假设是高斯分布,生成模型把Z作为输入,均值函数和方差函数分别输出的均值和方差,即:
[0014][0015]其中,Z是低维表示的集合,是生成模型估计的样本集合;U(
·
)和θ
U
分别是生成模型的均值函数和参数;V(
·
)和θ
V
分别是生成模型的方差函数和参数;是均值u方差v的高斯分布;
[0016]从中采样的样本分别输入所有收集的表示层,得到K个低维表示;如果样本的标签是j,则以为权重求这些低维表示的加权均值,计算加权均值与Z的欧式距离损失(其中N
j
是所有客户端第j类样本的总数量,N
k,j
是第k个客户端上第j类样本的总数量)。
[0017]另一方面,计算与标准高斯分布的KL散度损失。以上过程即以(3)式为损失函数用随机梯度下降法训练生成模型:
[0018][0019]其中,f
k
(
·
)是客户端模型的表示层函数,θ
k
是表示层参数;λ是超参数;是标准高斯分布;KL(
·
)是KL散度(Kullback

Leibler divergence);w
k
是客户端权重;其他符号的含义与(2)中的相同;
[0020]最后,服务器把训练后的生成模型的均值函数传给所有客户端。
[0021]在本地训练阶段,客户端首先接收服务器传来的均值函数。然后,一方面用私有数据集X
k
计算分类任务损失;另一方面,收集计算分类任务损失过程中产生的低维表示,把低维表示输入均值函数得到均值样本,计算均值样本与X
k
的欧式距离损失,以上过程即以(4)式为损失函数用随机梯度下降法训练本地分类模型:
[0022][0023]其中,X
k
是第k个客户端的私有数据集,x、y是样本和标签;L
k
是分类任务的损失函数;φ(
·
)是客户端模型的决策层函数,f
k
(
·
)是客户端模型的表示层函数,θ
k
是表示层参数;λ

是超参数;其他符号的含义与(2)中的相同;
[0024]接着,客户端采集低维表示,具体地说,客户端继续执行随机梯度下降过程,每执行q轮,把这期间得到的低维表示分标签求均值,在得到至少c个低维表示均值后停止采集。
[0025]最后,客户端把分类模型表示层和采集的低维表示均值上传给服务器。
[0026]重复服务器训练阶段和本地训练阶段,这样在多轮迭代之后,各个客户端可以得到比传统训练方法精度更高的模型。
[0027]本专利技术的特点和优势主要有:
[0028]第一它允许不同客户端有不同分布的数据和不同结构的模型表示层,可以同时解决模型异构和数据异构两种挑战,拓展了应用场景;
[0029]第二,它允许每个客户端建立个性化模型,在数据异构的场景下,相比建立全局模型的其它方法,它可以使系统获得更高的平均精度;
[0030]第三,它基于知识蒸馏技术优化表示层,在从相关客户端获取信息的同时减少不相关客户端的干扰,使本地模型的精度比优化决策层的方法或其它传统方法更高;第四,在通信过程中,客户端不是上传低维表示的原值而是均值,本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于知识蒸馏的面向设备异构的联邦学习方法,所涉及的系统包括有K个客户端、1个服务器;其特征在于:每个客户端上有1个根据软硬件资源设置的分类模型,客户端的分类模型划分为表示层和决策层,表示层用于把样本映射为低维表示,决策层用于把低维表示映射为概率向量;客户端模型表示层之间知识蒸馏的目标函数定义式:其中,K是客户端数量;X
k
是第k个客户端的私有数据集,x、y是样本和标签;L
k
是分类任务的损失函数;φ(
·
)是客户端模型的决策层函数,f
k
(
·
)是客户端模型的表示层函数,θ
k
是表示层参数;γ是超参数;各个客户端上的分类模型的表示层结构和参数不同,而决策层的结构和初始化参数均相同;所述服务器上设置1个生成模型,由共享输入的均值函数和方差函数构成;系统的目标是高效地求解(1)式;为此,把每轮通信分为两个阶段:服务器训练阶段和本地训练阶段;在服务器训练阶段,首先以推断样本低维表示的后验分布为目标在服务器上建立生成模型,然后把训练好的生成模型传递给客户端;在本地训练阶段,客户端一方面用私有样本计算任务损失,一方面用生成模型输出的均值样本计算调优表示层的损失;这样在多轮迭代之后,各个客户端可以得到比传统训练方法精度更高的模型。2.根据权利要求1所述的基于知识蒸馏的面向设备异构的联邦学习方法,其特征在于,在服务器训练阶段,服务器首先收集所有客户端模型的表示层;收集客户端采集的低维表示,组成集合Z;然后,用变分推断法求解后验分布具体地:假设是高斯分布,生成模型把Z作为输入,均值函数和方差函数分别输出的均值和方差,即:其中,Z是低维表示的集合,是生成模型估计的样本集合;U(
·
)和θ
U
分别是生成模型的均值函数和参数;V(
·
)和θ
V
分别是生成模型的方差函数和参数;是均值u方差v的高斯分布;从中采样的样本分别输入所有收集的表示层,得到K个低维表示;如果样本的标签是...

【专利技术属性】
技术研发人员:王智慧焦孟骁
申请(专利权)人:复旦大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1