一种模型训练方法以及相关装置制造方法及图纸

技术编号:33782130 阅读:13 留言:0更新日期:2022-06-12 14:36
本申请实施例公开了一种模型训练方法以及相关装置,可应用于云技术、人工智能、智慧交通、辅助驾驶等各种场景中。本申请实施例通过联邦学习,不仅在所有攻击的防御率上得到提高,而且也提高了对于未遭到攻击的原始样本的分类准确率。前述的方法包括:获取训练样本,训练样本包括原始样本和对抗样本,对抗样本是基于预设对抗训练模型对所述原始样本进行对抗训练处理得到的;基于目标权重对训练样本进行重加权处理,并对重加权处理后的训练样本进行训练,得到局部节点的训练模型;确定局部节点的训练模型的梯度信息;向中心节点发送梯度信息,梯度信息用于中心节点更新中心节点的训练模型。模型。模型。

【技术实现步骤摘要】
一种模型训练方法以及相关装置


[0001]本申请实施例涉及计算机
,具体涉及一种模型训练方法以及相关装置。

技术介绍

[0002]随着智能设备的日益普及,其带来的安全问题以及对目标对象的隐私的威胁也逐渐受到人们的关注。当目标对象出于隐私保护的目的而不愿意将个人数据上传至后台服务器时,如何利用分散在各个目标对象的智能设备中的数据完成模型的训练和迭代更新,就成为了亟待解决的问题。
[0003]在联邦学习的场景下,由于存在通信与模型分发的步骤,中心节点的模型很容易被不法对象获取到,从而遭受到“白盒”对抗攻击,带来巨大的安全隐患,如何提高联邦学习场景下的对抗鲁棒性成为了业界的热点问题。在对抗攻防领域中,对抗训练已经被证明是提高对抗鲁棒性的最有效方式。而将对抗训练直接与联邦学习结合是目前业界最常采用的方案,该方案与普通的联邦学习唯一的区别在于将局部节点的普通模型训练改为对抗训练。
[0004]然而,在实际使用过程中,数据往往不符合独立同分布,而且数据在不同类型之间存在严重的不均衡现象。这就使得基于对抗训练直接与联邦学习技术结合的模型训练的方式,会大大地降低模型对于未遭到攻击的原始样本的分类准确率,而且以大幅度牺牲性能的代价来保障模型的安全性,在正常的业务场景中是无法被接纳的。

技术实现思路

[0005]本申请实施例提供了一种模型训练方法以及相关装置,不仅在所有攻击的防御率上得到提高,而且也提高了对于未遭到攻击的原始样本的分类准确率,无需牺牲较大的性能代价来保障模型的安全性。
[0006]第一方面,本申请实施例提供了第一种模型训练方法。该模型训练方法应用于局部节点。在该模型训练方法中,局部节点获取训练样本,训练样本包括原始样本和对抗样本,对抗样本是基于预设对抗训练模型对原始样本进行对抗训练处理得到的。然后局部节点基于目标权重对所述训练样本进行重加权处理,并且对进行所述重加权处理后的训练样本进行对抗训练,得到所述局部节点的训练模型。局部节点确定所述局部节点的训练模型的梯度信息,并向中心节点发送所述梯度信息。所述梯度信息用于所述中心节点更新所述中心节点的训练模型。
[0007]第二方面,本申请实施例提供了第二种模型训练方法。该第二种模型训练方法可以应用在中心节点。在该模型训练方法中,中心节点获取N个局部节点的训练模型的梯度信息,其中,每个所述局部节点的训练模型是由对应的局部节点基于目标权重对训练样本进行重加权处理、并对重加权处理后的训练样本进行训练得到的,所述训练样本包括原始样本和对抗样本,所述对抗样本是基于预设对抗训练模型对所述原始样本进行对抗训练处理得到的,N≥1、且N为整数。然后,中心节点根据所述N个局部节点的训练模型的梯度信息更
新所述中心节点的训练模型。
[0008]第三方面,本申请实施例提供了一种局部节点。该局部节点包括获取单元、处理单元以及发送单元。其中,获取单元用于获取训练样本,训练样本包括原始样本和对抗样本,对抗样本是基于预设对抗训练模型对所述原始样本进行对抗训练处理得到的。处理单元用于根据目标权重对所述训练样本进行重加权处理,并对进行所述重加权处理后的训练样本进行对抗训练,得到所述局部节点的训练模型。所述处理单元用于确定所述局部节点的训练模型的梯度信息。发送单元,用于向中心节点发送所述梯度信息,所述梯度信息用于所述中心节点更新所述中心节点的训练模型。
[0009]在一些可能的实施方式中,所述处理单元用于:通过预设交叉熵损失函数对所述重加权处理后的训练样本进行对抗训练,得到所述局部节点的训练模型,所述局部节点的训练模型受到所述中心节点的训练模型的KL散度值的约束。
[0010]在另一些可能的实施方式中,所述目标权重是基于所述训练样本与所述训练样本的分类边界之间的距离得到。
[0011]在另一些可能的实施方式中,所述训练样本与所述训练样本的分类边界之间的距离,是基于投影梯度下降PGD算法进行迭代攻击成功时的迭代次数得到。
[0012]在另一些可能的实施方式中,所述获取单元还用于接收所述中心节点发送的更新后的所述中心节点的训练模型。
[0013]在另一些可能的实施方式中,所述处理单元还用于基于更新后的中心节点的训练模型对测试样本进行处理,得到每个测试样本的类型。
[0014]在另一些可能的实施方式中,所述处理单元用于根据所述训练样本对所述局部节点的训练模型进行训练,得到所述局部节点的训练模型的梯度信息。
[0015]第四方面,本申请实施例提供了一种中心节点。该中心节点包括获取单元和处理单元。其中,所述获取单元用于获取N个局部节点的训练模型的梯度信息,其中,每个所述局部节点的训练模型是由对应的局部节点基于目标权重对训练样本进行重加权处理、并对重加权处理后的训练样本进行训练得到的,所述训练样本包括原始样本和对抗样本,所述对抗样本是基于预设对抗训练模型对所述原始样本进行对抗训练处理得到的,N≥1、且N为整数。所述处理单元用于根据所述N个局部节点的训练模型的梯度信息更新所述中心节点的训练模型。
[0016]在一些可能的实施方式中,所述中心节点还包括发送单元。该发送单元用于向所述N个局部节点分别发送更新后的所述中心节点的训练模型,其中,更新后的中心节点的训练模型用于每个局部节点对测试样本进行识别处理,得到每个测试样本的类型。
[0017]在另一些可能的实施方式中,处理节点用于根据N个局部节点的梯度信息生成全局信息,并基于全局信息更新中心节点的训练模型。
[0018]在另一些可能的实施方式中,处理节点还用于对梯度信息进行解密处理。
[0019]本申请实施例第五方面提供了一种模型处理装置,包括:存储器、输入/输出(I/O)接口和存储器。存储器用于存储程序指令。处理器用于执行存储器中的程序指令,以执行上述第一方面或第二方面的实施方式对应的模型训练方法。
[0020]本申请实施例第六方面提供了一种计算机可读存储介质,计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行以执行上述第一方面或第二方面的
实施方式对应的模型训练方法。
[0021]本申请实施例第七方面提供了一种包含指令的计算机程序产品,当其在计算机或者处理器上运行时,使得计算机或者处理器执行上述以执行上述第一方面或第二方面的实施方式对应的模型训练方法。
[0022]从以上技术方案可以看出,本申请实施例具有以下优点:
[0023]本申请实施例中,由于训练样本包括原始样本和对抗样本,对抗样本是基于预设对抗训练模型对原始样本进行对抗训练处理得到的,因此通过对各个局部节点对应的训练样本进行重加权处理后,再对重加权处理后的训练样本进行训练,得到局部节点对应的训练模型。然后在这些局部节点的训练模型之间进行联邦学习,可以有效地在训练过程中防止局部节点的数据泄露,提高局部节点的对抗攻击的防御能力,而且能够提升模型对于未遭到攻击的原始样本的分类准确率,无需牺牲较大的性能代价来保障模型的安全性。
附图说明...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,应用于局部节点,所述模型训练方法包括:获取训练样本,所述训练样本包括原始样本和对抗样本,所述对抗样本是基于预设对抗训练模型对所述原始样本进行对抗训练处理得到的;基于目标权重对所述训练样本进行重加权处理,并对所述重加权处理后的训练样本进行训练,得到所述局部节点的训练模型;确定所述局部节点的训练模型的梯度信息;向中心节点发送所述梯度信息,所述梯度信息用于所述中心节点更新所述中心节点的训练模型。2.根据权利要求1所述的模型训练方法,其特征在于,所述对进行所述重加权处理后的训练样本进行对抗训练,得到所述局部节点的训练模型,包括:通过预设交叉熵损失函数对所述重加权处理后的训练样本进行训练,得到所述局部节点的训练模型,所述局部节点的训练模型受到所述中心节点的训练模型的KL散度值的约束。3.根据权利要求1或2所述的模型训练方法,其特征在于,所述目标权重是基于所述训练样本与所述训练样本的分类边界之间的距离得到。4.根据权利要求3所述的模型训练方法,其特征在于,所述训练样本与所述训练样本的分类边界之间的距离,是基于投影梯度下降PGD算法进行迭代攻击成功时的迭代次数得到。5.根据权利要求1

4中任一项所述的模型训练方法,其特征在于,所述模型训练方法还包括:接收所述中心节点发送的更新后的所述中心节点的训练模型。6.一种模型训练方法,其特征在于,应用于中心节点,所述模型训练方法包括:获取N个局部节点的训练模型的梯度信息,其中,每个所述局部节点的训练模型是由对应的局部节点基于目标权重对训练样本进行重加权处理、并对重加权处理后的训练样本进行训...

【专利技术属性】
技术研发人员:李博张杰徐江河刘世策吴双丁守鸿
申请(专利权)人:腾讯科技上海有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1