一种联邦知识蒸馏方法、装置、设备及存储介质制造方法及图纸

技术编号:37119586 阅读:18 留言:0更新日期:2023-04-01 05:14
本申请公开了一种联邦知识蒸馏方法、装置、设备及存储介质,涉及机器学习技术领域,包括:分别利用每个用户端的私有数据集对本地模型进行训练得到训练后本地模型;分别将相同公共数据输入至各个训练后本地模型,得到模型输出结果,并将所有模型输出结果上传至中心服务器,以便中心服务器利用模型输出结果训练判别器模型;通过中心服务器计算所有模型输出结果的平均值得到平均模型输出,并将平均模型输出和判别器模型的损失函数相对于用户端本地模型输出的梯度下发至对应的用户端,以便用户端利用平均模型输出和梯度对用户端本地模型进行知识迁移。本申请利用位于中心服务器的判别器模型能够实现多用户端间灵活高效的知识迁移,并降低通信成本。并降低通信成本。并降低通信成本。

【技术实现步骤摘要】
一种联邦知识蒸馏方法、装置、设备及存储介质


[0001]本申请涉及机器学习
,特别涉及一种联邦知识蒸馏方法、装置、设备及存储介质。

技术介绍

[0002]随着移动通信技术和人工智能的快速发展,用户对于智能应用的需求越来越迫切。为充分利用广泛分布在网络边缘的用户数据,联邦学习应运而生。在联邦学习中,每个用户首先利用私有数据训练一个本地模型,并将模型参数发送给中心服务器,接着中心服务器通过对所有用户模型参数进行平均实现全局模型聚合,该全局模型再广播给所有用户继续本地模型训练。通过周期性的本地模型训练和全局模型聚合,联邦学习能够实现大规模用户设备之间的合作机器学习,同时保证用户数据的隐私性。
[0003]然而,传统的联邦学习具有以下局限性:首先,当用户模型较大时,向中心服务器传输模型参数将消耗大量通信时间,从而降低了整体模型训练的速度;其次,传统联邦学习要求所有用户都具有相同的模型结构,从而进行模型参数平均,然而实际场景中用户设备具有各不相同的计算和存储资源容量,支持不同的本地模型结构,因此,面向异构的用户本地模型以及模型的隐私性需求,传统联邦学习方法难以适用。
[0004]而联邦知识蒸馏可通过交换数据量更小的用户本地模型输出,利用知识蒸馏技术实现黑盒异构模型间的知识迁移,保证用户模型隐私、降低通信成本。不同于传统的联邦学习,用户将本地模型针对相同公共数据的输出上传给中心服务器,并利用中心服务器计算得到的所有用户模型的平均输出进行知识蒸馏,即最小化每个用户的本地模型输出与所有用户全局平均输出之间的距离,实现其他用户模型向本地模型的知识迁移,最终得到在全局数据上具有高精确度的用户模型。然而,现实生活中用户设备收集到的数据通常是非独立同分布的,单纯地对所有用户本地模型输出进行平均难以反应真实的全局数据特征,以所有用户本地模型输出平均作为依据进行知识蒸馏严重限制了联邦知识蒸馏的最终模型性能。因此,针对用户异构的模型和数据特征,亟需设计新的联邦知识蒸馏方法实现灵活高效的用户间知识迁移。

技术实现思路

[0005]有鉴于此,本申请的目的在于提供一种联邦知识蒸馏方法、装置、设备及存储介质,能够实现多用户端间灵活高效的知识迁移,并降低通信成本。其具体方案如下:
[0006]第一方面,本申请公开了一种联邦知识蒸馏方法,包括:
[0007]分别利用每个用户端的私有数据集对所述用户端的本地模型进行训练,得到多个训练后本地模型;
[0008]分别将相同公共数据输入至各个所述训练后本地模型,得到相应的多个模型输出结果,并将所有所述模型输出结果上传至中心服务器,以便所述中心服务器利用所述模型输出结果训练判别器模型;
[0009]通过所述中心服务器计算所有所述模型输出结果的平均值得到平均模型输出,并将所述平均模型输出和所述判别器模型的损失函数相对于所述训练后本地模型输出的梯度下发至对应的所述用户端,以便所述用户端利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移。
[0010]可选的,所述利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移,包括:
[0011]利用所述平均模型输出和所述梯度对训练后本地模型进行最小化损失函数训练。
[0012]可选的,所述利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移之后,还包括:
[0013]判断所有当前用户端本地模型是否收敛,若收敛则停止执行所述知识迁移的过程,若不收敛则周期性的执行所述知识迁移的过程直到所有用户端本地模型收敛。
[0014]可选的,所述私有数据集中包括输入特征和对应的数据标签。
[0015]可选的,所述判别器模型用于识别用户端本地模型所属的用户端,并与所述用户端本地模型训练形成对抗学习,以使所有所述用户端本地模型具有相同的模型输出概率分布,以及对所述用户端本地模型输出所对应的所述数据标签进行分类,以扩大不同数据类别之间的决策边界。
[0016]第二方面,本申请公开了一种联邦知识蒸馏方法,应用于用户端,包括:
[0017]利用本地私有数据集对本地模型进行训练,得到训练后本地模型;
[0018]将相同公共数据输入至所述训练后本地模型,得到模型输出结果;
[0019]将所述模型输出结果上传至中心服务器,以便所述中心服务器利用所述模型输出结果及其他用户端上传的模型输出结果训练判别器模型,并计算所有模型输出结果的平均值得到平均模型输出;
[0020]接收所述中心服务器返回的所述平均模型输出和所述判别器模型的损失函数相对于所述训练后本地模型输出的梯度,并基于所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移。
[0021]第三方面,本申请公开了一种联邦知识蒸馏方法,应用于中心服务器,包括:
[0022]获取所有用户端上传的将相同公共数据输入至位于所述用户端的训练后本地模型得到的模型输出结果;其中,所述训练后本地模型为基于私有数据集对所述用户端的本地模型进行训练后得到的模型;
[0023]利用所述模型输出结果训练判别器模型;
[0024]计算所有所述模型输出结果的平均值得到平均模型输出,并将所述平均模型输出和所述判别器模型的损失函数相对于所述训练后本地模型输出的梯度下发至对应的所述用户端,以便所述用户端利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移。
[0025]第四方面,本申请公开了一种联邦知识蒸馏装置,包括:
[0026]模型训练模块,用于分别利用每个用户端的私有数据集对所述用户端的本地模型进行训练,得到多个训练后本地模型;
[0027]公共数据输入模块,用于分别将相同公共数据输入至各个所述训练后本地模型,得到相应的多个模型输出结果;
[0028]模型输出结果上传模块,用于将所有所述模型输出结果上传至中心服务器,以便所述中心服务器利用所述模型输出结果训练判别器模型;
[0029]均值计算模块,用于通过所述中心服务器计算所有所述模型输出结果的平均值得到平均模型输出;
[0030]数据下发模块,用于将所述平均模型输出和所述判别器模型的损失函数相对于所述训练后本地模型输出的梯度下发至对应的所述用户端;
[0031]知识迁移模块,用于所述用户端利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移。
[0032]第五方面,本申请公开了一种电子设备,包括处理器和存储器;其中,所述处理器执行所述存储器中保存的计算机程序时实现前述的联邦知识蒸馏方法。
[0033]第六方面,本申请公开了一种计算机可读存储介质,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现前述的联邦知识蒸馏方法。
[0034]可见,本申请先分别利用每个用户端的私有数据集对所述用户端的本地模型进行训练,得到多个训练后本地模型,然后分别将相同公共数据输入至各个所述训练后本地模型,得到相应的多个模型输出结果,并将所有所述模型输出结果上传至中心服务本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种联邦知识蒸馏方法,其特征在于,包括:分别利用每个用户端的私有数据集对所述用户端的本地模型进行训练,得到多个训练后本地模型;分别将相同公共数据输入至各个所述训练后本地模型,得到相应的多个模型输出结果,并将所有所述模型输出结果上传至中心服务器,以便所述中心服务器利用所述模型输出结果训练判别器模型;通过所述中心服务器计算所有所述模型输出结果的平均值得到平均模型输出,并将所述平均模型输出和所述判别器模型的损失函数相对于所述训练后本地模型输出的梯度下发至对应的所述用户端,以便所述用户端利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移。2.根据权利要求1所述的联邦知识蒸馏方法,其特征在于,所述利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移,包括:利用所述平均模型输出和所述梯度对训练后本地模型进行最小化损失函数训练。3.根据权利要求1所述的联邦知识蒸馏方法,其特征在于,所述利用所述平均模型输出和所述梯度对所述训练后本地模型进行知识迁移之后,还包括:判断所有当前用户端本地模型是否收敛,若收敛则停止执行所述知识迁移的过程,若不收敛则周期性的执行所述知识迁移的过程直到所有用户端本地模型收敛。4.根据权利要求1至3任一项所述的联邦知识蒸馏方法,其特征在于,所述私有数据集中包括输入特征和对应的数据标签。5.根据权利要求4所述的联邦知识蒸馏方法,其特征在于,所述判别器模型用于识别用户端本地模型所属的用户端,并与所述用户端本地模型训练形成对抗学习,以使所有所述用户端本地模型具有相同的模型输出概率分布,以及对所述用户端本地模型输出所对应的所述数据标签进行分类,以扩大不同数据类别之间的决策边界。6.一种联邦知识蒸馏方法,其特征在于,应用于用户端,包括:利用本地私有数据集对本地模型进行训练,得到训练后本地模型;将相同公共数据输入至所述训练后本地模型,得到模型输出结果;将所述模型输出结果上传至中心服务器,以便所述中心服务器利用所述模型输出结果及其他用户端上传的模型输出结果训练判别器...

【专利技术属性】
技术研发人员:汉鹏超黄建伟
申请(专利权)人:香港中文大学深圳
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1