【技术实现步骤摘要】
基于深度确定性策略梯度的联邦学习分类模型训练方法
[0001]本专利技术属于机器学习领域,涉及一种联邦学习分类模型训练方法,具体涉及一种基于深度确定性策略梯度选择本地训练学习率和最大本地迭代次数的联邦学习分类模型训练方法。
技术介绍
[0002]分类任务是机器学习领域的一种基础任务,其目的是根据目标样本携带的特征,把不同类别的目标样本区分开来。分类任务的样本数据类型包括图像、文本、音频等,实现分类任务的机器学习分类模型包括卷积神经网络模型、逻辑回归模型、支持向量机模型等。训练一个分类精度值高的机器学习分类模型需要中央服务器从客户端处收集大量的训练样本,而收集数据的过程必然会带来通信开销过高的问题。
[0003]基于上述原因,需要一种通信高效的方法来支持训练样本数据共享以实现分类模型训练,于是谷歌提出了联邦学习。联邦学习是一种分布式的机器学习模式,一个联邦学习系统由中央服务器和多个客户端组成。联邦学习的主要执行流程为:中央服务器为每个客户端构建模型后由客户端初始化、客户端在本地更新模型权重参数,并将结果发送给中央服务器、中央服务器进行全局聚合后将结果发送给各客户端,然后进行新的一轮迭代,直到损失函数收敛或者达到设定的最大迭代次数。在“客户端在本地更新模型权重参数”的步骤中,由中央服务器为所有客户端规定相同的联邦学习本地训练超参数,包括控制联邦学习模型的权重参数更新幅度的学习率,以及控制联邦学习模型的本地训练程度的最大本地迭代次数。使用不同的“学习率
‑
最大本地迭代次数”超参数组合最终得到的联邦
【技术保护点】
【技术特征摘要】
1.一种基于深度确定性策略梯度的联邦学习分类模型训练方法,其特征在于,包括如下步骤:(1)构建联邦学习分类系统:构建包括中央服务器以及可与其通信的N个客户端C={c
n
|1≤n≤N}的联邦学习分类系统,其中,N≥2,c
n
表示第n个客户端;(2)客户端获取联邦学习训练数据集和测试数据集:每个客户端c
n
获取包含E个目标类别的U个数据,并对每个数据的目标进行标注后,将其中半数以上的数据及其标签组成训练数据集将剩余的数据及其标签组成测试数据集其中,E≥2,U≥1000;(3)中央服务器为每个客户端c
n
构建联邦学习分类模型H
n
和深度确定性策略梯度模型I
n
:中央服务器为每个客户端c
n
构建包括依次层叠的输入层、隐藏层、输出层的联邦学习分类模型H
n
,同时构建包括并行排布的主网络与目标网络的深度确定性策略梯度模型I
n
,其中,主网络包括顺次连接的actor模块μ
n
和critic模块Q
n
;μ
n
包括依次层叠的输入层、隐藏层和输出层,Q
n
包括依次层叠的输入层、隐藏层和输出层;目标网络包括顺次连接的actor
′
模块μ
′
n
和critic
′
模块Q
′
n
,μ
′
n
的结构和μ
n
的结构相同,Q
′
n
的结构和Q
n
的结构相同;(4)客户端初始化联邦学习训练参数:每个客户端c
n
设置联邦学习全局迭代次数为t,最大全局迭代次数为T,T≥1500,设置联邦学习本地迭代次数为α,第t次全局迭代中客户端c
n
的最大本地迭代次数为每个联邦学习分类模型H
n
的权重参数为令t=1,α=1;(5)客户端c
n
得到联邦学习分类模型H
n
的状态每个客户端c
n
将从测试数据集中无放回地随机选取的D1个测试数据作为联邦学习分类模型H
n
的输入进行前向传播,H
n
的隐藏层对每个测试数据进行特征提取,输出层根据特征对每个测试数据进行分类,得到D1个测试数据的预测标签集合然后通过预测标签值及其对应的真实标签值计算联邦学习分类模型H
n
的损失值精确度值以及F
‑
1值客户端c
n
然后将这三个值顺次排列,得到第t次全局迭代中联邦学习分类模型H
n
的状态(6)每个客户端判断t>1是否成立,若是,利用状态对深度确定性策略梯度模型I
n
的权重参数进行本地更新后,执行步骤(7),否则,执行步骤(7);(7)客户端获取联邦学习本地训练超参数并记录动作每个客户端c
n
将步骤(5)中得到的状态作为μ
n
的输入,μ
n
的隐藏层对状态进行特征提取,输出层根据特征进行预测,得到联邦学习分类模型本地训练超参数:学习率和最大本地迭代次数然后将与顺次排列,记录第t次全局迭代中联邦学习分类模型H
n
的的动作
(8)客户端对联邦学习分类模型的权重参数进行本地更新:(8a)每个客户端c
n
将从训练数据集中无放回地随机选取的D2个训练数据作为联邦学习分类模型H
n
的输入进行前向传播,H
n
的隐藏层对每个训练数据进行特征提取,输出层根据特征对每个训练数据进行分类,得到D2个训练数据的预测标签集合(8b)每个客户端c
n
通过预测标签及其对应的真实标签计算联邦学习分类模型H
n
在训练数据集上的损失值然后...
【专利技术属性】
技术研发人员:王子龙,陈嘉伟,陈谦,柴政,胡嘉琪,
申请(专利权)人:西安电子科技大学,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。