面向类别不平衡的联邦学习图像分类方法及系统技术方案

技术编号:39330896 阅读:30 留言:0更新日期:2023-11-12 16:07
本发明专利技术公开了面向类别不平衡的联邦学习图像分类方法及系统,多个客户端按照标准的联邦学习方法联合训练一个全局模型;每个客户端将本地训练数据划分为头类和尾类,计算得到尾类所对应的混淆类,基于类激活图对本地训练数据中的尾类和对应的混淆类进行特征提取,得到尾类特有特征和混淆类通用特征;每个客户端将尾类特有特征与对应的混淆类通用特征进行特征融合,得到尾类的增广样本,以增强本地训练数据;每个客户端使用增强的本地训练数据对全局模型进行微调,并将其上传到服务器来进一步更新全局模型。此外,本发明专利技术设计了一个新的损失函数TailDistillation Loss,能够减轻全局类不平衡的影响。类不平衡的影响。类不平衡的影响。

【技术实现步骤摘要】
面向类别不平衡的联邦学习图像分类方法及系统


[0001]本专利技术涉及联邦学习
,特别是涉及面向类别不平衡的联邦学习图像分类方法及系统。

技术介绍

[0002]本部分的陈述仅仅是提到了与本专利技术相关的
技术介绍
,并不必然构成现有技术。
[0003]随着大数据、云计算和人工智能等新一代信息技术的快速发展,现代社会对隐私保护和信息安全提出了新的要求。目前,机器学习,特别是深度学习,在计算机视觉、自然语言处理等领域取得了巨大成功,这些成功都建立在大量数据的基础之上。然而,在许多应用领域,数据通常以分布式的形式存在,受限于法律、法规和版权要求,数据难以进行有效的流通,人们不得不面对难以桥接的数据孤岛问题。在此背景下,联邦学习应运而生,成为打通数据孤岛、避免隐私泄漏,并在更高效地共享数据价值的同时更好地保护数据隐私的关键技术。
[0004]联邦学习作为一种新兴的分布式机器学习范式,利用来自多个客户端的去中心化数据,在中央服务器的协调下联合训练一个共享的全局模型。然而,联邦学习的一个主要的实际挑战是各个客户端本地数据的非独立同分布所导本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.面向类别不平衡的联邦学习图像分类方法,其特征是,包括:服务器向客户端发送当前的全局模型参数,客户端使用所接收到的模型参数和本地训练数据进行模型更新,并将更新后的本地模型上传到服务器,服务器对所有客户端上传的本地模型进行聚合,得到新的全局模型参数,服务器基于新的全局模型参数进行下一轮通信,直至全局模型收敛;每个客户端将本地训练数据划分为头类和尾类,计算得到尾类所对应的混淆类,基于类激活图对本地训练数据中的尾类和其对应的混淆类进行特征提取,得到尾类特有特征和混淆类通用特征;每个客户端将尾类特有特征与对应的混淆类通用特征进行特征融合,得到融合特征,将融合特征作为尾类的增广样本;将尾类的增广样本补充到本地训练数据中,得到增强的本地训练数据;每个客户端使用增强的本地训练数据,对训练得到的全局模型进行再训练,将再训练后的模型上传到服务器,服务器对所有客户端上传的模型进行聚合,进一步更新全局模型,并进行下一轮通信,直至全局模型重新收敛;每个客户端使用最终收敛后的全局模型,对待识别图像进行图像分类。2.如权利要求1所述的面向类别不平衡的联邦学习图像分类方法,其特征是,服务器向客户端发送当前的全局模型参数,客户端使用所接收到的模型参数和本地训练数据进行模型更新,并将更新后的本地模型上传到服务器,服务器对所有客户端上传的本地模型进行聚合,得到新的全局模型参数,服务器基于新的全局模型参数进行下一轮通信,直至全局模型收敛,具体包括:服务器分别与若干个客户端通信,所述服务器和客户端内均设置有相同的图像分类模型;服务器的图像分类模型的初始模型参数为设定值;服务器将初始模型参数发送给所有客户端;每个客户端接收服务器发送过来的初始模型参数,并将初始模型参数设置到图像分类模型中,每个客户端采用本地训练数据对自身的图像分类模型进行训练,得到训练后的图像分类模型,每个客户端将训练后的图像分类模型的参数上传给服务器;服务器接收所有客户端上传的模型参数,根据各个客户端训练数据的数量,对各个客户端的上传的模型参数进行加权求和,得到新的全局模型参数,并将新的全局模型参数下发给所有的客户端,进行下一轮通信,直至服务器的模型参数收敛。3.如权利要求1所述的面向类别不平衡的联邦学习图像分类方法,其特征是,每个客户端将本地训练数据划分为头类和尾类,具体包括:本地训练数据,包括:已知分类标签的图像;构建直角坐标系,横坐标为已知的分类标签,横坐标的分类标签按照图像数量由多到少的顺序从左到右依次排列;纵坐标为每种标签对应图像的数量,纵坐标靠近坐标系原点一端为零,在直角坐标系中,画出图像数量的分布曲线;选择图像数量最少的分类标签为尾类标签,剩余的分类标签均为头类标签,头类标签对应的本地训练数据为头类训练数据,尾类标签对应的本地训练数据为尾类训练数据。4.如权利要求1所述的面向类别不平衡的联邦学习图像分类方法,其特征是,所述计算
得到尾类所对应的混淆类,具体包括:第p个客户端中标签为a1的尾类训练数据中一共有M幅图像;将标签为a1的尾类训练数据中第q幅图像,输入到第p个客户端的图像分类模型中,输出第q幅图像被预测子标签为a1、a2、
……
、a
n
的概率值;以此类推,将标签为a1的尾类训练数据中M幅图像,均依次输入到第p个客户端的图像分类模型中,输出M幅图像被预测子标签的概率值;计算M幅图像子标签为a
r
概率值的平均值,对平均值按照由高到低的顺序进行排序,选择平均值最大的前Q个图像子标签,作为a1类所对应的混淆类,其中,r的取值范围是2~n。5.如权利要求1所述的面向类别不平衡的联邦学习图像分类方法,其特征是,所述基于类激活图对本地训练数据中的尾类和其对应的混淆类进行特征提取,得到尾类特有特征和混淆类通用特征,具体包括:将尾类训练数据,输入到当前客户端的图像分类模型中,计算类激活映射,根据所设定的阈值分离出尾类特有特征;将混淆类训练数据,输入到当前客户端的图像分类模型中,计算类激活映射,根据所设定的阈值分离出混淆类通用特征。6.如权利要求1所述的面向类别不平衡的联邦学习图像分类方法,其特征是,所述基于类激活图对本地训练数据中的尾类和其对应的混淆类进行特征提取,得到尾类特有特征和混淆类通用特征,具体包括:首先,计算类c的分数y
c
相对于图像分类模型中最后一个...

【专利技术属性】
技术研发人员:杨美红刘国正张玮史慧玲谭立状郝昊丁伟
申请(专利权)人:齐鲁工业大学山东省科学院
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1