基于记忆重放和差分隐私的联邦持续学习训练方法技术

技术编号:35014600 阅读:17 留言:0更新日期:2022-09-21 15:14
一种基于记忆重放和差分隐私的联邦持续学习训练方法,其步骤为:中央服务器对生成对抗网络进行训练,生成一组伪图像组成伪样本集,将伪样本集与联邦学习全局网络下发给每个客户端,客户端使用记忆重放方法将下发的伪样本集与本地样本集混合训练,将训练好的联邦学习全局网络与添加差分隐私噪声的本地样本集异步地上传至中央服务器,中央服务器对上传的本地网络参数进行加权聚合并更新联邦学习全局网络,当没有新任务到达时结束训练。本发明专利技术的方法在不加重客户端计算负担的前提下,降低了客户端对旧任务知识的遗忘,保护了客户端的隐私,提高了联邦学习全局模型的训练效率。提高了联邦学习全局模型的训练效率。提高了联邦学习全局模型的训练效率。

【技术实现步骤摘要】
基于记忆重放和差分隐私的联邦持续学习训练方法


[0001]本专利技术属于数据处理
,更进一步涉及机器学习模型梯度数据处理
中的一种基于记忆重放和差分隐私的联邦持续学习训练方法。本专利技术可用于客户端协同持续训练机器学习模型。

技术介绍

[0002]通常,联邦学习利用分布在客户端本地的隐私数据通过服务器与客户端交互式通信训练的方式获得一个具有良好预测能力的机器学习模型。具体来说,中央服务器通过聚合本地客户端经过本地训练获得的本地模型梯度,获得一个新的全局模型。然后,中央服务器将该全局模型作为下一全局训练回合的初始模型下发给各个客户端,客户端在本地数据集上使用该模型进行本地训练。该联邦学习全局模型更新过程迭代进行,直到满足确定的训练终止条件。通常作为联邦学习中客户端的智能边缘设备面临着持续不断采集的大量任务数据流。但是,联邦学习中的客户端本地模型对旧任务的性能通常会在新任务到来时急剧下降,这被称为灾难性遗忘问题。因此,在任务数据流中的持续学习能力成为制约联邦学习走向实际应用的主要因素。如何减轻联邦学习本地模型在任务数据流中的灾难性遗忘问题成为了发展联邦学习的关键问题。
[0003]中山大学在其申请的专利文献“一种基于联邦学习的在线学习方法、系统、计算机设备及介质”(申请号:202110865716.8,申请公布号:CN 113743616 A,公布日期2021.01.15)中提出了一种基于联邦学习的在线学习方法。该方法实现的步骤是:(1)服务器初始化全局模型的参数,统一将全局模型派发给每个用户终端;(2)每个用户终端接收全局模型;(3)每个用户终端持续采集用户行为产生的数据,并保存于用户终端本地;(4)每个用户终端利用持续采集的数据按到来轮次加权计算损失函数F
k
,对接收到的全局模型进行在线学习训练;(5)每个用户终端将训练好的全局模型参数上传至服务器;(6)服务器进行参数聚合,生成新的全局模型;(7)判断是否达到在线学习终止条件,若是,则在线学习训练结束;否则,返回步骤(3)。该方法存在的不足之处是,联邦在线学习对新旧数据进行按到来轮次加权计算每个客户端的损失函数F
k
,虽然可以处理实时收集的相同任务数据以更新本地模型,但当收集到新任务数据时客户端本地模型就会被新任务知识覆盖,从而逐渐忘记了旧任务知识,发生严重的灾难性遗忘问题,从而导致联邦本地模型在遇到训练过的旧任务时需要重新训练。
[0004]Yoon等人在其发表的论文“Federated Continual Learning with Weighted Inter

client Transfer”(International Conference on Machine Learning 2021.)中提出了一种基于联邦加权客户间传输的联邦持续学习方法(FedWeIT)。该方法的主要步骤是:(1)服务器初始化全局模型和共享参数B;(2)随机采样知识库kb;(3)对每个可以通信的本地客户端下发全局模型与知识库kb;(4)本地客户端使用APD持续学习算法进行本地训练,并将客户端的模型参数计算分解为自适应任务特定参数A和全局共享参数B;(5)客户端将训练后得到的共享参数B与自适应任务特定参数A上传给服务器;(6)服务器使用收到的
共享参数B进行聚合更新全局模型;(7)服务器使用自适应参数A更新知识库kb;(8)判断是否有新任务到达,若有进行步骤(3),否则训练结束。该方法存在的不足之处是:在本地客户端进行模型参数计算分解与训练任务会给客户端带来沉重的计算负担,降低了模型预测精度,且该方法没有考虑服务器与客户端间传输数据时可能会造成的隐私安全问题,因此该方法不适用于联邦学习中计算资源(例如CPU,内存,电池等)受限的客户端(例如手机,智能穿戴设备,IoT设备等)。

技术实现思路

[0005]本专利技术的目的在于针对上述已有技术的不足,提出一种基于记忆重放和差分隐私的联邦持续学习方法,用于解决联邦学习中计算资源受限的客户端本地模型在任务数据流中的灾难性遗忘,导致的联邦本地模型在遇到训练过的旧任务时需要重新训练的问题,以及提高联邦学习中通信数据的隐私保护。
[0006]实现本专利技术目的的技术思路是:本专利技术通过在中央服务器中维护一个由生成对抗网络GAN组成的记忆生成器模型,利用客户端上传的数据持续积累客户端任务知识,并将生成的任务伪数据下发给客户端,客户端将接收到的伪数据与当前任务数据按重要性比例混合训练,从而能实现在不加重客户端计算负担的情况下有效地恢复旧任务知识。此外,本专利技术通过向服务器与客户端通信的数据信息中添加差分隐私噪声并异步传输采样数据,从而降低了通信过程中用户隐私泄露的风险,以此来提高联邦学习中通信数据的隐私保护与通信效率。
[0007]为实现上述目的,专利技术采用的技术方案包括如下步骤:
[0008]步骤1,生成客户端本地样本集:
[0009]步骤1.1,选取至少55000张图像组成样本集,该样本集至少包括五个种类;
[0010]步骤1.2,从每个种类样本集中各随机抽取至少1400张图像组成一个客户端的本地样本集;
[0011]步骤1.3,采用与步骤1.2相同的方法,得到至少10个客户端的本地样本集设置每个客户端的任务样本集按先后顺序到达,当新任务样本集到达时旧任务的样本集立即被丢弃;
[0012]步骤2,生成服务器中生成对抗网络的训练集;
[0013]步骤3,构建联邦学习全局网络;
[0014]步骤4,构建生成对抗网络:
[0015]步骤4.1,构建生成对抗网络中的生成网络(Generator);
[0016]步骤4.2,构建生成对抗网络中的判别网络(Discriminator);
[0017]步骤5,对生成对抗网络进行训练:
[0018]步骤5.1,生成对抗网络生成一张伪图像并标注为负样本;
[0019]步骤5.2,使用训练集中的图像与伪图像共同对生成对抗网络进行训练;
[0020]步骤5.3,利用梯度下降方法,迭代更新生成对抗网络的参数,判断损失函数是否收敛,若收敛,则执行步骤5.4,否则,执行步骤5.1;
[0021]步骤5.4,将一组随机生成的100维高斯噪声向量输入训练好的生成对抗网络中,输出一组生成的伪图像作为伪样本集;
[0022]步骤6,使用记忆重放方法训练联邦学习全局网络:
[0023]步骤6.1,选取最多5个客户端,并将构建的联邦学习全局网络和伪样本集下发给每个参加本轮训练的客户端;
[0024]步骤6.2,按每个任务的重要性比率组合部分伪样本集与部分本地样本集;
[0025]步骤6.3,客户端使用混合后的本地样本集,迭代训练更新网络参数,直到总损失函数收敛为止,将训练好的联邦学习全局网络参数异步地上传给服务器;
[0026]步骤7,对选取的本地图像添加本地差分隐私噪声;
[0027]步骤8,对所有联邦学习全局网本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于记忆重放和差分隐私的联邦持续学习训练方法,其特征在于,使用记忆重放方法训练联邦学习全局网络,对选取的本地图像添加本地差分隐私噪声;该训练方法的具体步骤包括如下:步骤1,生成客户端本地样本集:步骤1.1,选取至少55000张图像组成样本集,该样本集中至少包括五个种类,将每个类至少包含11000张图像样本组成任务样本集;步骤1.2,从每个任务样本集中各随机抽取至少1400张图像组成一个客户端的本地样本集;步骤1.3,采用与步骤1.2相同的方法,得到至少10个客户端的本地样本集,设置每个客户端的任务样本集按先后顺序到达,当新任务样本集到达时旧任务的样本集立即被丢弃;步骤2,生成服务器中生成对抗网络的训练集:从第一个任务样本集中选取至少100张图像组成训练集,该训练集中至少包括一个种类;将所有图像标注为正样本;步骤3,构建联邦学习全局网络:在服务器中搭建一个七层的卷积神经网络作为联邦学习全局网络,其结构依次为:第一卷积层,第一池化层,第二卷积层,第二池化层,第三卷积层,第一全连接层,第二全连接层,将第一至第三卷积层的维度分别设置为28*28,13*13,5*5,卷积核的大小均设置为3*3,第一、第二池化层的池化窗口尺寸均设置为2*2,第一、第二全连接层的维度分别设置为64,10;步骤4,构建生成对抗网络:步骤4.1,生成对抗网络中的生成网络由编码器、卷积长短期记忆网络LSTM、解码器组成,卷积长短期记忆网络LSTM是一个五层的反卷积神经网络,其结构依次为:第一反卷积层,第二反卷积层,第三反卷积层,池化层,归一化层;将第一至第三反卷积层的维度分别设置为5*5,13*13,28*28,卷积核的大小均设置为3*3;将池化层的池化窗口尺寸设置为2*2;将归一化层的维度设置为10;步骤4.2,生成对抗网络中的判别网络是一个五层的卷积神经网络,其结构依次为:第一卷积层,第二卷积层,第三卷积层,池化层,归一化层,将第一至第三卷积层的维度分别设置为28*28,13*13,5*5,卷积核的大小均设置为3*3,卷积核的大小均设置为3*3,池化层的池化窗口尺寸设置为2*2,归一化层的维度设置为10;步骤5,对生成对抗网络进行训练:步骤5.1,将一个随机生成的100维高斯噪声向量输入到生成对抗网络中,将输出的图像作为伪图像,并将该伪图像标注为负样本;步骤5.2,从训练集中随机选取一张图像,将所选图像与伪图像输入到生成对抗网络中进行非线性映射,输出每个样本对应的各自的正、负预测标签与特征向量;步骤5.3,利用梯度下降方法,用损失函数迭代更新生成对抗网络的参数,判断损失函数是否收敛,若收敛,则执行步骤5.4,否则,执行步骤5.1;步骤5.4,将一组随机生成的100维高斯噪声向量输入训练好的生成对抗网络中,通过生成网络,对高斯噪声向量和类别标签向量进行非线性映射,输出一组生成的伪图像,将该组伪图像作为伪样本集;
步骤6,使用记忆重放方法训练联邦学习全局网络:步骤6.1,从所有客户端中随机选取最多5个客户端,作为参与本轮训练联邦学习全局网络的客户端;同时将构建的联邦学习全局网络和伪样本集下发给每个参加本轮训练的客户端;步骤6.2,按每个任务的重要性比率组合部分伪样本集与部分本地样本集,得到混合后的本地样本集;步骤6.3,参与本轮训练的所有客户端使用混合后的本地样本集,利用本地随机梯度下降法,迭代更新网络参数,直到联邦学习全局网络训练的总损失函数收敛为止,得到训练好的联邦...

【专利技术属性】
技术研发人员:张琛白航鱼滨解宇
申请(专利权)人:西安电子科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1