当前位置: 首页 > 专利查询>深圳大学专利>正文

一种模型性能提升方法、装置、计算机设备及存储介质制造方法及图纸

技术编号:38749530 阅读:11 留言:0更新日期:2023-09-09 11:16
本申请实施例属于模型性能优化技术领域,涉及一种模型性能提升方法、装置、计算机设备及存储介质,该方法包括:获取初始客户端模型;将原始样本数据输入至特征提取器进行重新采样操作,得到原始样本特征数据;根据随机采样法对原始样本特征数据进行筛选操作,得到目标样本特征数据;获取初始服务端模型,根据目标样本特征数据对初始服务端模型进行模型训练操作,得到训练好的目标服务端模型;对目标服务端模型以及初始客户端模型进行模型集成操作,得到目标客户端模型。本申请可以在现有联邦算法的基础上大幅度提高准确率,使联邦学习的性能得到了明显提升。的性能得到了明显提升。的性能得到了明显提升。

【技术实现步骤摘要】
一种模型性能提升方法、装置、计算机设备及存储介质


[0001]本申请涉及模型性能优化
,尤其涉及一种模型性能提升方法、装置、计算机设备及存储介质。

技术介绍

[0002]联邦学习等分布式协作机器学习由于其设计的数据隐私保护和资源协作优势而大受欢迎。与数据集中,资源集中的传统方法不同,这类分布式协作方法如联邦学习可以训练具有多个数据源的模型,同时保持计算资源和数据去中心化,不共享原始数据,已成为访问大规模数据以训练强大的深度学习模型的一种有前景的替代方案。但是联邦学习也面临着许多挑战。其中特别重要的是研究发现这类方法在面对客户端数据异质时,会导致客户端模型产生分歧,这极大损害了模型的性能和收敛速度。此外联邦学习范式会带来大量的通信代价,增加训练成本和训练时间。具体来说,联邦学习中周期性的模型参数交换带来的通信开销与模型尺寸呈线性关系。因此对于能得到高精度的大型模型,难以部署在资源受限的边缘设备中。
[0003]目前已有的联邦学习方法主要包括:FedAVG、FedDyn为代表的传统联邦学习算法,其基本思想是提高加权或者正则化等方式减少数据异质性的影响;FedProto,Ditto等个性化联邦学习算法,其主要思想是将联邦学习视作多任务学习因此为每个客户端的数据分布生成其个性化的模型,而不追求传统方法中的全局模型;FedGKT、SpiltFed等模型异构联邦学习方法,这类方法应用在资源受限的边缘设备中训练联邦模型,通过使用知识蒸馏、分裂学习等技术,可以为资源受限的边缘设备训练超过其负担的高精度的大型模型,从而提高联邦学习的使用范围。
[0004]然而,申请人发现,传统的联邦学习方法通过引入聚合权重、添加惩罚项等方法来遏制模型聚合时的发散问题,虽然一定程度缓解了数据异质性问题,但是大部分方法在高度异质的数据分布下性能还是较低,同时增强了客户端本身的负担,并且在大型模型中的表现不佳,由此可见,传统的联邦学习方法存在准确率较低的问题。

技术实现思路

[0005]本申请实施例的目的在于提出一种模型性能提升方法、装置、计算机设备及存储介质,以解决传统的联邦学习方法存在准确率较低的问题。
[0006]为了解决上述技术问题,本申请实施例提供一种模型性能提升方法,采用了如下所述的技术方案:
[0007]获取初始客户端模型,其中,所述初始客户端模型由特征提取器、神经网络结构以及分类器组成,所述初始客户端模型包括本地数据库,所述本地数据库存储有原始样本数据;
[0008]将所述原始样本数据输入至所述特征提取器进行重新采样操作,得到原始样本特征数据;
[0009]根据随机采样法对所述原始样本特征数据进行筛选操作,得到目标样本特征数据;
[0010]获取初始服务端模型,根据所述目标样本特征数据对所述初始服务端模型进行模型训练操作,得到训练好的目标服务端模型;
[0011]对所述目标服务端模型以及所述初始客户端模型进行模型集成操作,得到目标客户端模型。
[0012]进一步的,在所述根据随机采样法对所述原始样本特征数据进行筛选操作,得到目标样本特征数据的步骤之后,还包括下述步骤:
[0013]对所述目标样本特征数据进行噪声添加操作,得到噪声样本特征数据;
[0014]所述获取初始服务端模型,根据所述目标样本特征数据对所述初始服务端模型进行模型训练操作,得到训练好的目标服务端模型的步骤,具体包括下述步骤:
[0015]获取所述初始服务端模型,根据所述噪声样本特征数据对所述初始服务端模型进行模型训练操作,得到所述训练好的目标服务端模型。
[0016]进一步的,所述模型训练操作为基于SGD的模型参数训练及优化操作。
[0017]进一步的,所述模型训练操作为基于Adam算法的模型参数训练及优化操作。
[0018]进一步的,所述模型集成操作的集成方式为投票集成法和/或平均集成法。
[0019]为了解决上述技术问题,本申请实施例还提供一种模型性能提升装置,采用了如下所述的技术方案:
[0020]客户端模型获取模块,用于获取初始客户端模型,其中,所述初始客户端模型由特征提取器、神经网络结构以及分类器组成,所述初始客户端模型包括本地数据库,所述本地数据库存储有原始样本数据;
[0021]重新采样模块,用于将所述原始样本数据输入至所述特征提取器进行重新采样操作,得到原始样本特征数据;
[0022]特征数据筛选模块,用于根据随机采样法对所述原始样本特征数据进行筛选操作,得到目标样本特征数据;
[0023]模型训练模块,用于获取初始服务端模型,根据所述目标样本特征数据对所述初始服务端模型进行模型训练操作,得到训练好的目标服务端模型;
[0024]模型集成模块,用于对所述目标服务端模型以及所述初始客户端模型进行集成操作,得到目标客户端模型。
[0025]进一步的,所述装置还包括:噪声添加模块,所述特征数据筛选模块包括:特征数据筛选子模块,其中:
[0026]所述噪声添加模块,用于对所述目标样本特征数据进行噪声添加操作,得到噪声样本特征数据;
[0027]所述特征数据筛选子模块,用于获取所述初始服务端模型,根据所述噪声样本特征数据对所述初始服务端模型进行模型训练操作,得到所述训练好的目标服务端模型。
[0028]进一步的,所述模型训练操作为基于SGD的模型参数训练及优化操作。
[0029]为了解决上述技术问题,本申请实施例还提供一种计算机设备,采用了如下所述的技术方案:
[0030]包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所
述计算机可读指令时实现如上所述的模型性能提升方法的步骤。
[0031]为了解决上述技术问题,本申请实施例还提供一种计算机可读存储介质,采用了如下所述的技术方案:
[0032]所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如上所述的模型性能提升方法的步骤。
[0033]本申请提供了一种模型性能提升方法,包括:获取初始客户端模型,其中,所述初始客户端模型由特征提取器、神经网络结构以及分类器组成,所述初始客户端模型包括本地数据库,所述本地数据库存储有原始样本数据;将所述原始样本数据输入至所述特征提取器进行重新采样操作,得到原始样本特征数据;根据随机采样法对所述原始样本特征数据进行筛选操作,得到目标样本特征数据;获取初始服务端模型,根据所述目标样本特征数据对所述初始服务端模型进行模型训练操作,得到训练好的目标服务端模型;对所述目标服务端模型以及所述初始客户端模型进行模型集成操作,得到目标客户端模型。与现有技术相比,本申请可以在现有联邦算法的基础上大幅度提高准确率,使联邦学习的性能得到了明显提升。
附图说明
[0034]为了更清楚地说明本申请中的方案,下面将对本申请实施例描述中所需要使用的附图作一个简单介绍,显而易见地,下面描述中的附图是本申请的一些实施本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型性能提升方法,其特征在于,包括下述步骤:获取初始客户端模型,其中,所述初始客户端模型由特征提取器、神经网络结构以及分类器组成,所述初始客户端模型包括本地数据库,所述本地数据库存储有原始样本数据;将所述原始样本数据输入至所述特征提取器进行重新采样操作,得到原始样本特征数据;根据随机采样法对所述原始样本特征数据进行筛选操作,得到目标样本特征数据;获取初始服务端模型,根据所述目标样本特征数据对所述初始服务端模型进行模型训练操作,得到训练好的目标服务端模型;对所述目标服务端模型以及所述初始客户端模型进行模型集成操作,得到目标客户端模型。2.根据权利要求1所述的模型性能提升方法,其特征在于,在所述根据随机采样法对所述原始样本特征数据进行筛选操作,得到目标样本特征数据的步骤之后,还包括下述步骤:对所述目标样本特征数据进行噪声添加操作,得到噪声样本特征数据;所述获取初始服务端模型,根据所述目标样本特征数据对所述初始服务端模型进行模型训练操作,得到训练好的目标服务端模型的步骤,具体包括下述步骤:获取所述初始服务端模型,根据所述噪声样本特征数据对所述初始服务端模型进行模型训练操作,得到所述训练好的目标服务端模型。3.根据权利要求1所述的模型性能提升方法,其特征在于,所述模型训练操作为基于SGD的模型参数训练及优化操作。4.根据权利要求1所述的模型性能提升方法,其特征在于,所述模型训练操作为基于Adam算法的模型参数训练及优化操作。5.根据权利要求1所述的模型性能提升方法,其特征在于,所述模型集成操作的集成方式为投票集成法和/或平均集成法。6.一种模型性能提升装置,其特征在于,包括:客户端模型获取模块...

【专利技术属性】
技术研发人员:刘刚邓尤幸助毛睿陈倩婷
申请(专利权)人:深圳大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1