当前位置: 首页 > 专利查询>南京大学专利>正文

一种融合元学习的多终端协同训练算法及系统技术方案

技术编号:33014534 阅读:26 留言:0更新日期:2022-04-15 08:46
本申请公开了一种融合元学习的多终端协同训练算法及系统,包括客户端加载位于本地的训练模型并初始化网络的权重参数;客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。本申请提供的方法在联邦学习的基础上,在各个客户端引入针对小样本情境(即少量训练数据)的元学习算法,在训练中可以高效获取少量样本中的元信息,所训练出的模型对于新数据也有较好的迁移性,采用该方法训练出的客户端模型在服务器端进行融合后对于其它客户端的数据集也具有较高的处理精度。端的数据集也具有较高的处理精度。端的数据集也具有较高的处理精度。

【技术实现步骤摘要】
一种融合元学习的多终端协同训练算法及系统


[0001]本专利技术涉及人工智能
,特别涉及一种融合元学习的多终端协同训练算法及系统。

技术介绍

[0002]现如今,任何一个人工智能(AI)项目都可能涉及多个领域,因此需要对来自各个公司、各个部门的数据进行整合。然而,实际应用中,由于各方对数据所有权和隐私性的关注越来越多,对用户隐私及数据的安全管理日趋严格,想要将分散在各地、各个机构的数据进行整合几乎变得不可能。在这种前提下,基于大数据的训练对于某个AI项目来说是高精度的必要保障,因此要求在满足隐私监管要求的前提下,设计一个机器学习框架,联邦学习算法应运而生。
[0003]联邦学习中,一种常见的算法由图1所示,在每个客户端分别使用本地数据训练各自的模型,之后将各自训练好的模型传输到服务器上进行融合,再将融合后的模型传回各客户端做继续训练。由于各客户端中的本地数据经常数目十分有限,导致其采用训练算法获得的模型经常会对本地数据产生过度适配,这样在服务器端将来自不同客户端的模型融合时,各模型无法很快地适配其它客户端上的数据处理,导致整体精度本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种融合元学习的多终端协同训练算法,其特征在于,包括:客户端加载位于本地的训练模型并初始化网络的权重参数;客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。2.根据权利要求1所述的一种融合元学习的多终端协同训练算法,其特征在于,得到平均模型后,所述算法还包括:服务器获取包含所有客户端存储的数据样本的测试数据集,根据所述测试数据集评估所述平均模型的精度,得到评估结果;若所述评估结果为满足要求,则停止数据通信与训练;若所述评估结果为未满足要求,则重新执行所述客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型的步骤。3.根据权利要求2所述的一种融合元学习的多终端协同训练算法,其特征在于,所述测试数据集中的数据样本根据类别不同分为多个数据包,其中,每个数据包采用N-way K-shot表示,N为每个数据包中随机抽取的类别数,way为类别,K为每个类别包含的数据样本数量,shot为数据单位。4.根据权利要求3所述的一种融合元学习的多终端协同训练算法,其特征在于,所述采用元学习算法调整所述训练模型的步骤包括:客户端从本地存储的数据样本中随机抽取一份数据包;利用内循环和外循环更新所述训练模型的模型参数。5.根据权利要求4所述的一种融合元学习的多终端协同训练算法,其特征在于,利用内循环更新所述训练模型的模型参数包括:建立多个任务,每个任务采用梯度下降的规则,基于训练模型的原始参数θ得到更新参数θ
i

;其中i表示第i个任务;根据更新参数θ
i

计算交叉熵损失L
Ti
,所述交叉熵损失L
Ti
由所有任务下得到的更新参数θ
i

相加得到。6.根据权利要求5所述的一种融合元学习的多终端协同训练算法,其特征在于,所述外循环更新所述训练模型的模型参数采用下列公式得到:其中,θn为调整后模型的模型参数,β为学习率,T
i
指i个任务...

【专利技术属性】
技术研发人员:王中风王美琪鲁安卓薛瑞鑫林军
申请(专利权)人:南京大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1