一种基于多方3D打印数据库联合训练的方法技术

技术编号:33210089 阅读:25 留言:0更新日期:2022-04-24 01:03
本发明专利技术公开了一种基于多方3D打印数据库联合训练的方法,在第j次训练迭代过程中,得到训练成员i的梯度矩阵G

【技术实现步骤摘要】
一种基于多方3D打印数据库联合训练的方法


[0001]本专利技术属于打印数据联合处理的
,具体涉及一种基于多方3D打印数据库联合训练的方法。

技术介绍

[0002]上世纪八十年代,3D打印技术诞生了,3D打印并不仅限于传统的“去除”加工方法,而且3D打印是一种自下而上的制造方式,也称为增材制造技术,其实现了数学模型的建立。3D打印技术自诞生之日起就受到人们的广泛关注,因此获得了快速发展。近几十年来,3D打印技术已成为人们关注的焦点。工业设计,建筑,汽车,航空航天,牙科,教育领域等都被应用,但是其应用和开发仍然受到因素的限制。
[0003]在3D打印实施过程中,由于3D打印相关参数太多,在实验过程中无法穷尽所有3D打印参数,并判断这些参数是否能够成型合适的零件,因此需要一种3D打印参数学习和预测的方式实现3D打印参数的预测。
[0004]由于3D打印实验成本高昂,由一家企业或单位完成所有实验无太大可能,可以基于多个数据库共同训练得到更加精准的模型参数,这里,就涉及到多个数据库之间的保密问题。例如,A公司拥有n个数据,B公司拥有m个数据,双方均不想让对方知道自己的工艺参数,但又希望联合进行模型训练。因此,需要一种基于多方3D打印数据库联合训练的方法。

技术实现思路

[0005]本专利技术的目的在于提供一种基于多方3D打印数据库联合训练的方法,旨在解决上述问题。
[0006]本专利技术主要通过以下技术方案实现:一种基于多方3D打印数据库联合训练的方法,包括多个训练成员以及服务器,所述训练成员的模型为Wi,每个训练成员的数据为Xi,,标签为y
i
;所述服务器的模型为W,且服务器的模型W与训练成员的模型Wi的网络结构一致;包括以下步骤:步骤S100:在第j次训练迭代过程中,训练成员i读取Xi中一个batch的数据bi,并进行模型Wi的前向传播,得到预测标签,进而根据实际标签y
i
,计算得到模型Wi的损失函数,进而利用反向传播算法得到梯度矩阵G
i
;步骤S200:训练成员i对梯度矩阵G
i
中的元素按照绝对值大小进行从大至小排序,并选择前m个元素得到对应是稀疏矩阵,填充元素为0;步骤S300:计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;步骤S400:使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型W
i

[0007]为了更好地实现本专利技术,进一步地,在迭代训练之前,进行模型初始化:服务器对模型W进行初始化,并将初始化结果下发至所有的训练成员,对模型W
i
进行初始化,确定梯度上传比例系数α、衰减系数ρ、学习率γ。
[0008]为了更好地实现本专利技术,进一步地,所述步骤S200中,统计得到模型W
i
的中元素总个数为M,计算得到本次需要上传的梯度元素个数。
[0009]为了更好地实现本专利技术,进一步地,所述步骤S300中相关度计算如下:其中:其中:D
KL
为KL散度,P表示各训练成员的自身数据质量,Q表示服务器所有样本的数据质量。
[0010]为了更好地实现本专利技术,进一步地,所述步骤S300中,训练成员i计算得到加权模型参数,并利用秘密共享算法对进行加密得到,并上传至服务器;其中:为加权模型参数;为加密的加权模型参数。
[0011]为了更好地实现本专利技术,进一步地,所述步骤S400中,服务器更新模型,服务器将更新的模型下发至本地,并更新训练成员的模型W
i
;其中:t为模型更新次数,γ为学习率,K为上传数据的数训练成员数。
[0012]本专利技术的有益效果:1、本专利技术可以应用于在保证各方数据安全的情况下,各方协同训练机器学习模型供多方使用的场景。在这个场景中,多个数据方拥有自己的数据,他们想共同使用彼此的数据来统一建模(例如,分类模型、线性回归模型、逻辑回归模型等),并通过梯度稀疏矩阵的方式保证各自的数据不被泄露,具有较好的实用性;2、本专利技术还可以基于相关度确定当前轮迭代的训练成员,从而实现在训练过程中仅有部分训练成员需要进行数据上传,降低了联合训练过程中数据的传输量,降低了数据传输带宽的需求和投入成本,具有较好的实用性;3、本专利技术通过加权模型参数矩阵的设计使得不同数据质量的训练样本具有不同的权重,这样的设置使得更高质量的训练样本可以对模型的训练方向起到更大的作用,从
而使得整个多轮训练过程更容易收敛,提升了联合训练的效率,减小了总体训练的轮数。
具体实施方式
[0013]实施例1:一种基于多方3D打印数据库联合训练的方法,包括多个训练成员以及服务器,所述训练成员的模型为Wi,每个训练成员的数据为Xi,,标签为y
i
;所述服务器的模型为W,且服务器的模型W与训练成员的模型Wi的网络结构一致;包括以下步骤:步骤S100:在第j次训练迭代过程中,训练成员i读取Xi中一个batch的数据bi,并进行模型Wi的前向传播,得到预测标签,进而根据实际标签y
i
,计算得到模型Wi的损失函数,进而利用反向传播算法得到梯度矩阵G
i
;步骤S200:训练成员i对梯度矩阵G
i
中的元素按照绝对值大小进行从大至小排序,并选择前m个元素得到对应是稀疏矩阵,填充元素为0;步骤S300:计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;步骤S400:使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型W
i

[0014]实施例2:本实施例是在实施例1的基础上进行优化,在迭代训练之前,进行模型初始化:服务器对模型W进行初始化,并将初始化结果下发至所有的训练成员,对模型W
i
进行初始化,确定梯度上传比例系数α、衰减系数ρ、学习率γ。
[0015]进一步地,所述步骤S200中,统计得到模型W
i
的中元素总个数为M,计算得到本次需要上传的梯度元素个数。
[0016]本实施例的其他部分与实施例1相同,故不再赘述。
[0017]实施例3:本实施例是在实施例1或2的基础上进行优化,所述步骤S300中相关度计算如下:其中:其中:D
KL
为KL散度,P表示各训练成员的自身数据质量,Q表示服务器所有样本的数据质量。
[0018]本实施例的其他部分与上述实施例1或2相同,故不再赘述。
[0019]实施例4:
一种基于多方3D打印数据库联合训练的方法,以水平切分的分类任务为例,假设共有k个训练成员,每个训练成员的数据集为Xi,,标签为y
i
,训练成员的模型为Wi,训练过程中对应的模型梯度为G
i
,服务器的模型W与训练成员模本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于多方3D打印数据库联合训练的方法,其特征在于,包括多个训练成员以及服务器,所述训练成员的模型为Wi,每个训练成员的数据为Xi,,标签为y
i
;所述服务器的模型为W,且服务器的模型W与训练成员的模型Wi的网络结构一致;包括以下步骤:步骤S100:在第j次训练迭代过程中,训练成员i读取Xi中一个batch的数据bi,并进行模型Wi的前向传播,得到预测标签,进而根据实际标签y
i
,计算得到模型Wi的损失函数,进而利用反向传播算法得到梯度矩阵G
i
;步骤S200:训练成员i对梯度矩阵G
i
中的元素按照绝对值大小进行从大至小排序,并选择前m个元素得到对应是稀疏矩阵,填充元素为0;步骤S300:计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;步骤S400:使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型W
i
。2.根据权利要求1所述的一种基于多方3D打印数据库联合训练的方法,其特征在于,在迭代训练之前,进行模型初始化:服务器对模型W进...

【专利技术属性】
技术研发人员:荣鹏高鹏高川云杜娟
申请(专利权)人:成都飞机工业集团有限责任公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1