System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 一种联邦学习模型训练方法、装置、电子设备和存储介质制造方法及图纸_技高网

一种联邦学习模型训练方法、装置、电子设备和存储介质制造方法及图纸

技术编号:41225658 阅读:3 留言:0更新日期:2024-05-09 23:43
本发明专利技术涉及一种联邦学习模型训练方法、装置、电子设备和存储介质,联邦学习模型包括学生模型和预训练的教师模型;包括:对联邦学习模型的嵌入层进行采样,生成伪样本数据;基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;依据所述目标模型参数更新所述联邦学习模型;本发明专利技术实施例可以在保护隐私的同时,通过嵌入结构的样本提取和学习一组伪数据来提高联邦学习模型的性能。

【技术实现步骤摘要】

本专利技术涉及神经网络模型训练,具体涉及一种联邦学习模型训练方法、一种联邦学习模型训练装置、一种电子设备和一种存储介质。


技术介绍

1、随着技术的发展,大型模型在车端的应用正在快速扩展。如说利用大型模型来分析驾驶者和乘客的行为和偏好,从而提供个性化的娱乐、导航和车内服务。包括语音助手、智能推荐系统、驾驶习惯分析等。而且随着技术的进步,特别是在人工智能、机器学习、传感器技术和数据处理方面,大型模型在车端的应用将继续扩展,为驾驶安全、效率、舒适性和车辆维护带来革命性的变化。

2、然而,目前训练大模型需要用到大量真实数据才能提升大模型泛化能力,且目前大部分预训练的大模型已经涵盖了开源的数据,因此要得到满足使用要求的大型模型,需为该大型模型提供更多的领域数据。但目前训练大模型的数据因涉及数据安全及成本等问题,往往很难得到大量的优质数据供大模型进行训练导致数据隐私的差。而如常用的基于差分的方式进行基于知识蒸馏的过程,使用梯度裁剪和加噪声等步骤,模型训练的过程复杂,效率低下。


技术实现思路

1、本专利技术的目的之一在于提供一种联邦学习模型训练方法,以解决现有技术中的模型训练的过程复杂,效率低下和隐私性差的问题;目的之二在于提供一种联邦学习模型训练装置;目的之三在于提供一种电子设备;目的之四在于提供一种存储介质。

2、为了实现上述目的,本专利技术采用的技术方案如下:

3、一种联邦学习模型训练方法,联邦学习模型包括学生模型和预训练的教师模型;所述方法包括:>

4、对联邦学习模型的嵌入层进行采样,生成伪样本数据;

5、基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;

6、采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;

7、依据所述目标模型参数更新所述联邦学习模型。

8、可选的,所述方法还包括:

9、获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;

10、依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;

11、依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。

12、可选的,所述方法还包括:

13、响应于所述联邦学习模型的更新,生成迭代计数值;

14、基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。

15、可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:

16、对联邦学习模型的嵌入层进行随机采样,生成初始样本集;

17、对所述初始样本集进行目标采样,生成伪样本数据。

18、可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:

19、依据所述教师模型与所述学生模型,确定对抗效应强度;

20、依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;

21、对所述对抗损失函数进行梯度下降优化,得到目标伪样本;

22、采用所述目标伪样本更新所述伪样本数据。

23、可选的,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:

24、基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;

25、对所述目标损失函数进行梯度下降优化,生成伪样本数据。

26、可选的,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:

27、依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;

28、对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。

29、可选的,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:

30、采用所述伪样本数据训练所述学生模型,生成输出数据;

31、依据所述输出数据和所述软标签建立交叉熵损失函数;

32、循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。

33、可选的,所述方法还包括:

34、依据所述通信指标参数,确定压缩策略;

35、在所述教师模型与所述学生模型通信时,执行所述压缩策略。

36、可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:

37、依据所述通信效率函数值确定通讯轮次。

38、可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:

39、依据所述通信效率函数值确定异步通讯策略。

40、一种联邦学习模型训练装置,联邦学习模型包括学生模型和预训练的教师模型;所述装置包括:

41、采样模块,用于对联邦学习模型的嵌入层进行采样,生成伪样本数据;

42、分类模块,用于基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;

43、训练模块,用于采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;

44、更新模块,用于依据所述目标模型参数更新所述联邦学习模型。

45、一种电子设备,包括处理器、存储器及存储在所述存储器上并能够在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如上所述的联邦学习模型训练方法的步骤。

46、一种计算机可读存储介质,所述计算机可读存储介质上存储计算机程序,所述计算机程序被处理器执行时实现如上所述的联邦学习模型训练方法的步骤。

47、本专利技术的有益效果:

48、(1)本专利技术通过在对联邦学习模型的嵌入层进行采样,生成用于知识蒸馏的伪样本数据,避免了对真实数据的依赖和使用;知识蒸馏过程不需要使用教师模型的实际数据,强化了整个联邦学习模型在联邦学习过程中的数据隐私保护;进一步地,通过多样性随机采样方法得到对伪样本数据,通过这些伪样本数进行训练可以增强模型的泛化能力。

49、(2)本专利技术在蒸馏的过程中,无需依赖生成对抗网络或辅助数据,减轻了计算和通信的负担;提高了训练的效率;不依赖于多个中间模型的协同训练和蒸馏,从而可能减少了模型训练和管理的复杂性;并且不依赖于可用的公共数据集进行蒸馏,在缺乏大量公共数据集的应用场景中使用该方法训练得到的模型可以更具备实用性和准确性。

50、(3)本专利技术对通讯协议进行优化,基于通信效率函数值确定不同的通讯策略,以减少在模型训练过程中的通讯开销,并确保在保护隐私的同时实现高效的模型更新。

本文档来自技高网...

【技术保护点】

1.一种联邦学习模型训练方法,其特征在于,联邦学习模型包括学生模型和预训练的教师模型;所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述方法还包括:

3.根据权利要求1所述的方法,其特征在于,所述方法还包括:

4.根据权利要求1至3任一项所述的方法,其特征在于,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:

5.根据权利要求4所述的方法,其特征在于,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:

6.根据权利要求4所述的方法,其特征在于,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:

7.根据权利要求1至3任一项所述的方法,其特征在于,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:

8.根据权利要求1至3任一项所述的方法,其特征在于,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:

9.根据权利要求2所述的方法,其特征在于,所述方法还包括:

10.根据权利要求2或9所述的方法,其特征在于,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:

11.根据权利要求10所述的方法,其特征在于,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:

12.一种联邦学习模型训练装置,其特征在于,联邦学习模型包括学生模型和预训练的教师模型;所述装置包括:

13.一种电子设备,其特征在于,包括处理器、存储器及存储在所述存储器上并能够在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如权利要求1至11任一项所述的联邦学习模型训练方法的步骤。

14.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储计算机程序,所述计算机程序被处理器执行时实现如权利要求1至11任一项所述的联邦学习模型训练方法的步骤。

...

【技术特征摘要】

1.一种联邦学习模型训练方法,其特征在于,联邦学习模型包括学生模型和预训练的教师模型;所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述方法还包括:

3.根据权利要求1所述的方法,其特征在于,所述方法还包括:

4.根据权利要求1至3任一项所述的方法,其特征在于,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:

5.根据权利要求4所述的方法,其特征在于,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:

6.根据权利要求4所述的方法,其特征在于,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:

7.根据权利要求1至3任一项所述的方法,其特征在于,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:

8.根据权利要求1至3任一项所述的方法,其特征在于,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目...

【专利技术属性】
技术研发人员:谢乐成孟艺凝谭瑞
申请(专利权)人:重庆长安科技有限责任公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1