跨样本联邦分类建模方法及装置、存储介质、电子设备制造方法及图纸

技术编号:30094144 阅读:12 留言:0更新日期:2021-09-18 08:57
本公开属于联邦学习技术领域,涉及一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法及装置、存储介质、电子设备。该方法包括:获取待联邦分类建模任务制定的标签标准信息,并对联邦分类建模任务的本地神经网络模型进行结构自定义和参数初始化处理;对本地神经网络模型进行训练得到预测标签向量,并对预测标签向量进行知识蒸馏处理得到软标签向量;获取与联邦分类建模任务对应的联邦建模参数,并将本地每个类别的软标签向量发送至协调方;接收协调方返回的联邦标签向量,并对本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。本公开确保数据隐私的情况下,利用所有联邦参与方的数据提供的“知识”得到更优的联邦分类建模模型。联邦分类建模模型。联邦分类建模模型。

【技术实现步骤摘要】
跨样本联邦分类建模方法及装置、存储介质、电子设备


[0001]本公开涉及联邦学习
,尤其涉及一种基于神经网络和知识蒸馏的跨样本联邦建模方法及装置、存储介质、电子设备。

技术介绍

[0002]随着深度学习研究的深入以及计算机设备的发展,人工神经网络被广泛应用于计算机人工智能领域。而为了保证训练出的人工神经网络具备良好的性能,通常需要大量数据投入训练。
[0003]但是,在某些场景下训练数据散落在不同的组织或机构中,出于数据隐私考虑,无法通过数据共享的方式满足人工神经网络的训练需求。即使能够通过共享的数据训练人工神经网络,那么对训练数据的通信和传输复杂度等要求也难以满足,更无法保障人工神经网络的训练效果。
[0004]鉴于此,本专利技术提出了一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法及装置。

技术实现思路

[0005]本公开的目的在于提供一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法、装置、计算机可读存储介质及电子设备,有利于在确保各个联邦参与方数据隐私的情况下,即通过共享的数据训练人工神经网络,在满足训练数据的通信和传输复杂度要求的基础上,同时保障人工神经网络的训练效果。
[0006]本公开的其他特性和优点将通过下面的详细描述变得显然,或部分地通过本公开的实践而习得。
[0007]根据本专利技术实施例的第一个方面,提供一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法,所述方法包括:
[0008]获取待联邦分类建模任务制定的标签标准信息,并根据所述标签标准信息以及本地训练数据对所述联邦分类建模任务对本地神经网络模型进行结构自定义和参数初始化处理;
[0009]对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,并对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量;
[0010]获取与所述联邦分类建模任务对应的联邦建模参数,并根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方;
[0011]接收所述协调方根据所述软标签向量返回的联邦标签向量,并根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。
[0012]在本专利技术的一种示例性实施例中,
[0013]所述对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,
包括:
[0014]获取用于训练所述联邦分类建模任务的原始训练数据,并对所述原始训练数据进行标签对齐处理和数据过滤处理,得到本地目标训练数据;
[0015]利用所述本地目标训练数据对所述本地神经网络模型进行训练,得到所述本地目标训练数据的预测标签向量。
[0016]在本专利技术的一种示例性实施例中,所述利用所述本地目标训练数据对所述本地神经网络模型进行训练得到所述本地目标训练数据的预测标签向量,包括:
[0017]对所述本地目标训练数据进行数据划分得到本地训练数据集,并对所述本地训练数据集进行数据划分得到多组待训练数据;
[0018]利用多组所述待训练数据对所述本地神经网络模型进行迭代训练,得到所述本地目标训练数据的预测标签向量。
[0019]在本专利技术的一种示例性实施例中,
[0020]所述对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量,包括:
[0021]获取与知识蒸馏相关的温度参数,并对所述本地训练数据的预测标签向量和所述温度参数进行知识蒸馏计算,得到所述本地训练数据的预测标签的蒸馏向量;
[0022]对同类别的所述本地训练数据的预测标签的蒸馏向量进行平均值计算得到本地每个类别的软标签向量。
[0023]在本专利技术的一种示例性实施例中,
[0024]所述联邦建模参数包括联邦训练轮数,所述根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型,包括:
[0025]获取与所述本地训练数据对应的标签数据,并对所述预测标签向量和所述标签数据进行损失计算,得到第一损失值;
[0026]对所述预测标签向量和所述联邦标签向量进行损失计算,得到第二损失值,并根据所述第一损失值和所述第二损失值对所述本地神经网络模型进行更新;
[0027]对更新后的所述本地神经网络模型继续进行训练,直至所述神经网络模型的训练次数达到所述联邦训练轮数时,得到训练好的联邦分类建模模型。
[0028]在本专利技术的一种示例性实施例中,
[0029]所述联邦建模参数包括通信频率条件;
[0030]所述根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方,包括:
[0031]当对所述本地神经网络模型的训练过程满足所述通信频率条件时,将所述本地每个类别的软标签向量发送至协调方。
[0032]在本专利技术的一种示例性实施例中,
[0033]所述根据所述标签标准信息以及本地训练数据对所述本地神经网络模型进行结构自定义处理和参数初始化处理,包括:
[0034]确定所述本地神经网络模型的标准结构信息,并按照所述标准结构信息对所述本地神经网络模型进行结构自定义,得到所述本地神经网络模型的网络结构信息;
[0035]对所述本地神经网络模型进行参数初始化处理。
[0036]根据本专利技术实施例的第二个方面,提供一种基于神经网络和知识蒸馏的联邦分类建模装置,所述装置包括:
[0037]模型定义模块,被配置为获取待联邦分类建模任务制定的标签标准信息,并根据所述标签标准信息以及本地训练数据对所述联邦分类建模任务对本地神经网络模型进行结构自定义和参数初始化处理;
[0038]模型训练模块,被配置为对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,并对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量;
[0039]向量发送模块,被配置为获取与所述联邦分类建模任务对应的联邦建模参数,并根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方;
[0040]训练完成模块,被配置为接收所述协调方根据所述软标签向量返回的联邦标签向量,并根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。
[0041]根据本专利技术实施例的第三个方面,提供一种电子设备,包括:处理器和存储器;其中,存储器上存储有计算机可读指令,所述计算机可读指令被所述处理器执行时实现上述任意示例性实施例中的基于神经网络和知识蒸馏的跨样本联邦分类建模方法。
[0042]根据本专利技术实施例的第四个方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任意示例性实施例中的基于神经网络和知识蒸馏的跨样本联邦分类建模方法。
[0043]由上述技术方案可知,本公开示例性实施例中的一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法、装置、计算机存本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于神经网络和知识蒸馏的跨样本联邦分类建模方法,其特征在于,所述方法包括:获取待联邦分类建模任务制定的标签标准信息,并根据所述标签标准信息以及本地训练数据对所述联邦分类建模任务的本地神经网络模型进行结构自定义和参数初始化处理;对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,并对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量;获取与所述联邦分类建模任务对应的联邦建模参数,并根据所述联邦建模参数将所述本地每个类别的软标签向量发送至协调方;接收所述协调方根据所述软标签向量返回的联邦标签向量,并根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型。2.根据权利要求1所述的基于神经网络和知识蒸馏的跨样本联邦分类建模方法,其特征在于,所述对所述本地神经网络模型进行训练得到所述本地训练数据的预测标签向量,包括:获取用于训练所述联邦分类建模任务的原始训练数据,并对所述原始训练数据进行标签对齐处理和数据过滤处理,得到本地目标训练数据;利用所述本地目标训练数据对所述本地神经网络模型进行训练,得到所述本地目标训练数据的预测标签向量。3.根据权利要求2所述的基于神经网络和知识蒸馏的跨样本联邦分类建模方法,其特征在于,所述利用所述本地目标训练数据对所述本地神经网络模型进行训练得到所述本地目标训练数据的预测标签向量,包括:对所述本地目标训练数据进行数据划分得到本地训练数据集,并对所述本地训练数据集进行数据划分得到多组待训练数据;利用多组所述待训练数据对所述本地神经网络模型进行迭代训练,得到所述本地目标训练数据的预测标签向量。4.根据权利要求1所述的基于神经网络和知识蒸馏的跨样本联邦分类建模方法,其特征在于,所述对所述预测标签向量进行知识蒸馏处理得到本地每个类别的软标签向量,包括:获取与知识蒸馏相关的温度参数,并对所述本地训练数据的预测标签向量和所述温度参数进行知识蒸馏计算,得到所述本地训练数据的预测标签的蒸馏向量;对同类别的所述本地训练数据的预测标签的蒸馏向量进行平均值计算,得到本地每个类别的软标签向量。5.根据权利要求1所述的基于神经网络和知识蒸馏的跨样本联邦分类建模方法,其特征在于,所述联邦建模参数包括联邦训练轮数,所述根据所述联邦标签向量和所述本地训练数据对所述本地神经网络模型继续进行训练,以完成训练得到联邦分类建模模型,包括:获取与所述本地训练数据对应的标签数据,并对所述预测标签向量和所述标签数据进行损失计算,得到第一损失值;对所述预测...

【专利技术属性】
技术研发人员:朱帆孟丹李宏宇李晓林
申请(专利权)人:淮安集略科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1