当前位置: 首页 > 专利查询>南京大学专利>正文

一种融合元学习的多终端协同训练算法及系统技术方案

技术编号:33014534 阅读:12 留言:0更新日期:2022-04-15 08:46
本申请公开了一种融合元学习的多终端协同训练算法及系统,包括客户端加载位于本地的训练模型并初始化网络的权重参数;客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。本申请提供的方法在联邦学习的基础上,在各个客户端引入针对小样本情境(即少量训练数据)的元学习算法,在训练中可以高效获取少量样本中的元信息,所训练出的模型对于新数据也有较好的迁移性,采用该方法训练出的客户端模型在服务器端进行融合后对于其它客户端的数据集也具有较高的处理精度。端的数据集也具有较高的处理精度。端的数据集也具有较高的处理精度。

【技术实现步骤摘要】
一种融合元学习的多终端协同训练算法及系统


[0001]本专利技术涉及人工智能
,特别涉及一种融合元学习的多终端协同训练算法及系统。

技术介绍

[0002]现如今,任何一个人工智能(AI)项目都可能涉及多个领域,因此需要对来自各个公司、各个部门的数据进行整合。然而,实际应用中,由于各方对数据所有权和隐私性的关注越来越多,对用户隐私及数据的安全管理日趋严格,想要将分散在各地、各个机构的数据进行整合几乎变得不可能。在这种前提下,基于大数据的训练对于某个AI项目来说是高精度的必要保障,因此要求在满足隐私监管要求的前提下,设计一个机器学习框架,联邦学习算法应运而生。
[0003]联邦学习中,一种常见的算法由图1所示,在每个客户端分别使用本地数据训练各自的模型,之后将各自训练好的模型传输到服务器上进行融合,再将融合后的模型传回各客户端做继续训练。由于各客户端中的本地数据经常数目十分有限,导致其采用训练算法获得的模型经常会对本地数据产生过度适配,这样在服务器端将来自不同客户端的模型融合时,各模型无法很快地适配其它客户端上的数据处理,导致整体精度有限,且需要更多轮通信才能获得较为彻底的模型融合。
[0004]现有技术中,为了尽可能减少通信次数,通常采用的方式是通过限制需要和服务器进行通信的本地客户端数目来实现,或者,采用SGD(联邦平均算法)在本地客户端得到测试损失,和服务器进行通信来达到共同训练的效果。然而,SGD的计算效率虽高,但该方法需要大量的训练才能产生较为精确的模型,对于大多数客户端来说,其本地数据量远远不能达到SGD所需的标准,因此,需要在客户端本地的训练过程中引入更高效的算法,以更好地利用本地较少的数据量训练出具备迁移性的改进算法。

技术实现思路

[0005]本申请提供了一种融合元学习的多终端协同训练算法及系统,以解决现有技术中,客户端采用少量数据训练出的模型迁移性较差,融合精确度低的问题。
[0006]第一方面,本申请提供了一种融合元学习的多终端协同训练算法,包括:
[0007]客户端加载位于本地的训练模型并初始化网络的权重参数;
[0008]客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;
[0009]服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。
[0010]在一些实施例中,得到平均模型后,所述算法还包括:
[0011]服务器获取包含所有客户端存储的数据样本的测试数据集,根据所述测试数据集评估所述平均模型的精度,得到评估结果;
[0012]若所述评估结果为满足要求,则停止数据通信与训练;
[0013]若所述评估结果为未满足要求,则重新执行所述客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型的步骤。
[0014]在一些实施例中,所述测试数据集中的数据样本根据类别不同分为多个数据包,其中,每个数据包采用N-way K-shot表示,N为每个数据包中随机抽取的类别数,way为类别,K为每个类别包含的数据样本数量,shot为数据单位。
[0015]在一些实施例中,所述采用元学习算法调整所述训练模型的步骤包括:
[0016]客户端从本地存储的数据样本中随机抽取一份数据包;
[0017]利用内循环和外循环更新所述训练模型的模型参数。
[0018]在一些实施例中,利用内循环更新所述训练模型的模型参数包括:
[0019]建立多个任务,每个任务采用梯度下降的规则,基于训练模型的原始参数θ得到更新参数θ
i

;其中i表示第i个任务;
[0020]根据更新参数θ
i

计算交叉熵损失L
Ti
,所述交叉熵损失L
Ti
由所有任务下得到的更新参数θ
i

相加得到。
[0021]在一些实施例中,所述外循环更新所述训练模型的模型参数采用下列公式得到:
[0022][0023]其中,θn为调整后模型的模型参数,β为学习率,T
i
指i个任务,ΣTi(*)是指对任务求和,指采用参数θ
i

的模型。
[0024]第二方面,本申请还提供了一种对应与第一方面提供方法的系统。
[0025]本申请提供的方法在联邦学习的基础上,在各个客户端引入针对小样本情境(即少量训练数据)的元学习算法,在训练中可以高效获取少量样本中的元信息,所训练出的模型对于新数据也有较好的迁移性,采用该方法训练出的客户端模型在服务器端进行融合后对于其它客户端的数据集也具有较高的处理精度。
[0026]由于客户端训练出来的模型迁移性较好,所需的模型融合的通信次数显著降低,对于每个客户端来说,采用更少的训练次数、更短的训练时间以及更低的能量消耗即能获取同样的模型精度。
附图说明
[0027]为了更清楚地说明本申请的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0028]图1为现有技术中一种常见的联邦学习算法的原理图;
[0029]图2为本申请一种融合元学习的多终端协同训练算法的流程图;
[0030]图3为图2所示算法中步骤S200的分解步骤图;
[0031]图4为本申请一种融合元学习的多终端协同训练算法在另一种实施例下的流程图;
[0032]图5为本申请提供的方法的其中一种实施例的流程图。
具体实施方式
[0033]鉴于小样本学习和联邦学习目标有重叠,即为了实现客户端侧设备数据隐私保护前提下,训练出一个高精度的集成模型,同时小样本中的元学习训练方案可以帮助模型在未见过的数据上泛化性增强,因此本申请考虑将二者结合,将元学习引入联邦学习的客户端侧训练,提升联邦学习多终端协同训练的性能。提升性能分为三方面:一、在保证学习性能前提下,降低整体通信次数;二、在保证学习性能前提下,降低端侧训练次数;三、相同训练消耗下,集成模型精度提升。该专利技术提出的解决方法是目前首次提出将端侧元学习引入联邦学习、且有效的方案。
[0034]在本申请提供的方案中,所提到的联邦学习是指一种学习技术,它允许用户在不需要集中存储数据的情况下,从这些丰富的数据中取获得共享模型的好处。这种方法还允许我们利用网络边缘可用的廉价计算来扩展学习任务。联邦学习适合任务的特点有:一、对来自移动设备的真实数据的训练比对数据中心通常可用的代理数据的训练具有明显优势;二、所处理的数据是隐私敏感性的或者规模较大的,因此它不适合将其记录到数据中心来进行模型训练;三、对于监督型任务,数据集上的标签可以从用户与他们设备的交互过程中自然推理出来。基于联邦学习不能独立解决当客户端侧数据较少的问题时,本申请提供了一种基于联邦学习的改进型算法。
[0035]参见图2,本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种融合元学习的多终端协同训练算法,其特征在于,包括:客户端加载位于本地的训练模型并初始化网络的权重参数;客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。2.根据权利要求1所述的一种融合元学习的多终端协同训练算法,其特征在于,得到平均模型后,所述算法还包括:服务器获取包含所有客户端存储的数据样本的测试数据集,根据所述测试数据集评估所述平均模型的精度,得到评估结果;若所述评估结果为满足要求,则停止数据通信与训练;若所述评估结果为未满足要求,则重新执行所述客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型的步骤。3.根据权利要求2所述的一种融合元学习的多终端协同训练算法,其特征在于,所述测试数据集中的数据样本根据类别不同分为多个数据包,其中,每个数据包采用N-way K-shot表示,N为每个数据包中随机抽取的类别数,way为类别,K为每个类别包含的数据样本数量,shot为数据单位。4.根据权利要求3所述的一种融合元学习的多终端协同训练算法,其特征在于,所述采用元学习算法调整所述训练模型的步骤包括:客户端从本地存储的数据样本中随机抽取一份数据包;利用内循环和外循环更新所述训练模型的模型参数。5.根据权利要求4所述的一种融合元学习的多终端协同训练算法,其特征在于,利用内循环更新所述训练模型的模型参数包括:建立多个任务,每个任务采用梯度下降的规则,基于训练模型的原始参数θ得到更新参数θ
i

;其中i表示第i个任务;根据更新参数θ
i

计算交叉熵损失L
Ti
,所述交叉熵损失L
Ti
由所有任务下得到的更新参数θ
i

相加得到。6.根据权利要求5所述的一种融合元学习的多终端协同训练算法,其特征在于,所述外循环更新所述训练模型的模型参数采用下列公式得到:其中,θn为调整后模型的模型参数,β为学习率,T
i
指i个任务...

【专利技术属性】
技术研发人员:王中风王美琪鲁安卓薛瑞鑫林军
申请(专利权)人:南京大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1