一种深度学习模型的训练方法及装置制造方法及图纸

技术编号:19321346 阅读:35 留言:0更新日期:2018-11-03 11:22
本发明专利技术实施例提供了一种深度学习模型的训练方法及装置,方法为:首先训练得到中间深度学习模型;计算训练样本集中各个样本的特征向量以及计算中心点距离目标函数的中间参数的初始值;将中心点距离目标函数添加到中间深度学习模型中并加载中间参数的初始值,得到目标深度学习模型;利用当前目标深度学习模型,计算当前批次数据中各个样本的特征向量,并更新中间参数的参数值;计算中心点距离目标函数的函数值和分类目标函数的函数值,并判断是否符合结束训练的条件;如果否,调整当前目标深度学习模型的参数,导入下一批次数据并进行模型训练;如果是,结束训练。应用本发明专利技术实施例提供的方案可以提升深度学习模型的分类准确率。

Training method and device for deep learning model

The embodiment of the present invention provides a training method and device for a deep learning model. The method is as follows: firstly, an intermediate deep learning model is obtained by training; the initial values of the eigenvectors of each sample in the training sample set and the intermediate parameters of the center point distance objective function are calculated; and the center point distance objective function is added to the training sample set. In the inter-depth learning model, the initial values of intermediate parameters are loaded to get the target depth learning model; using the current target depth learning model, the feature vectors of each sample in the current batch data are calculated, and the parameters of intermediate parameters are updated; the function values of the distance from the center point to the objective function and the function of the classification objective function are calculated. If not, adjust the parameters of the current target depth learning model, import the next batch of data and conduct model training; if so, finish training. The scheme provided by the embodiment of the invention can improve the classification accuracy of the deep learning model.

【技术实现步骤摘要】
一种深度学习模型的训练方法及装置
本专利技术涉及机器学习
,特别是涉及一种深度学习模型的训练方法及装置。
技术介绍
包含分类目标函数的深度学习模型,即带有分类任务的深度学习模型(深度学习分类模型),其通用架构是:首先采用若干组“卷积-非线性激励-池化”模块提取出数据分布的特征,然后采用“全连接”或者“1×1卷积”将特征缩减到指定数量个类别,最后通过softmax等分类目标函数实现类别分值的推断。而将包含分类目标函数的深度学习模型与中心点距离目标函数相结合,能够使得深度学习模型学习到的数据分布更加紧凑。对于包含分类目标函数的深度学习模型,根据每一类别所包含的各个样本的特征向量可以计算得到该类别的中心点,该类别中各个样本到该类别的中心点的距离为中心点距离。在深度学习模型的训练过程中,可以通过减小中心点距离,来提升深度学习模型的分类准确率。目前,在深度学习模型的训练过程中,均采用批次数据训练的形式对模型参数进行更新,但是现有方法是采用批次数据对中心点距离目标函数的函数值进行近似的方式来确定各个类别的中心点距离,也就是说,在训练过程中计算出的中心点距离不准确,从而导致深度学习模型的分类准确率不高。
技术实现思路
本专利技术实施例的目的在于提供一种深度学习模型的训练方法、装置、电子设备及计算机可读存储介质,以提升深度学习模型的分类准确率。具体技术方案如下:第一方面,本专利技术实施例提供了一种深度学习模型的训练方法,所述方法包括:利用训练样本集中的各个样本,对预先构建的初始深度学习模型进行训练,得到中间深度学习模型;其中,所述初始深度学习模型为:加载有分类目标函数的深度学习模型;利用所述中间深度学习模型,计算所述训练样本集中各个样本的特征向量,并根据所述训练样本集中各个样本的特征向量,计算中心点距离目标函数的中间参数的初始值;将所述中心点距离目标函数添加到所述中间深度学习模型中并加载所述中间参数的初始值,得到目标深度学习模型;导入所述训练样本集中的预设数量个样本作为批次数据;利用当前目标深度学习模型,计算当前批次数据中各个样本的特征向量,并根据当前批次数据中各个样本的特征向量,更新所述中间参数的参数值;基于当前批次数据中各个样本的特征向量,计算所述中心点距离目标函数的函数值和所述分类目标函数的函数值,并判断计算得到的中心点距离目标函数的函数值是否收敛到第一预定区间,且计算得到的分类目标函数的函数值是否收敛到第二预定区间;如果否,利用所述中心点距离目标函数的反向传播梯度和所述分类目标函数的反向传播梯度,调整当前目标深度学习模型的参数,并返回执行所述导入所述训练样本集中的预设数量个样本作为批次数据的步骤;如果是,结束对所述当前目标深度学习模型的训练。可选的,所述中心点距离目标函数的中间参数包括:各个类别的中心点、每一类别中的各个样本与该类别的中心点的偏差和、每一类别中的各个样本与该类别的中心点的距离的平方和;所述根据所述训练样本集中各个样本的特征向量,计算中心点距离目标函数的中间参数的初始值的步骤,包括:根据以下公式,计算各个类别的中心点:其中,cj表示类别j的中心点,xi表示利用所述中间深度学习模型计算的所述训练样本集中第i个样本的特征向量,yi表示第i个样本的类别标签,N表示所述训练样本集中样本的数量,nj表示在所述训练样本集中类别j所包含的样本的数量,σ(yi,j)是类别指示函数,且根据以下公式,计算每一类别中的各个样本与该类别的中心点的偏差和:其中,βj表示类别j中各个样本与该类别的中心点的偏差和,cyi表示第i个样本所属类别的中心点;根据以下公式,计算每一类别中的各个样本与该类别的中心点的距离的平方和:其中,δj表示类别j中各个样本与该类别的中心点的距离的平方和。可选的,所述根据当前批次数据中各个样本的特征向量,更新所述中间参数的参数值的步骤,包括:根据以下公式,更新各个类别的中心点:其中,cj、分别表示更新前、后的类别j的中心点,表示利用当前目标深度学习模型计算的所述训练样本集中第i个样本的特征向量,p表示所述当前批次数据中的样本;根据以下公式,更新每一类别中的各个样本与该类别的中心点的偏差和:其中,βj、分别表示更新前、后的类别j中的各个样本与该类别的中心点的偏差和,cyi、分别表示更新前、后第i个样本所属类别的中心点;根据以下公式,更新每一类别中的各个样本与该类别的中心点的距离的平方和:其中,δj、分别表示更新前、后的类别j中的各个样本与该类别的中心点的距离的平方和,Δcyi表示第i个样本所属类别的中心点的偏移量,且ΔTcj表示Δcj的转置。可选的,所述中心点距离目标函数为:其中,K表示所述训练样本集中样本的类别数量,δj表示类别j中各个样本与该类别的中心点的距离的平方和。可选的,所述中心点距离目标函数的反向传播梯度为:其中,表示所述中心点距离目标函数的反向传播梯度,nyi表示在所述训练样本集中第i个样本所属类别所包含的样本的数量,表示利用所述当前目标深度学习模型计算的所述训练样本集中第i个样本的特征向量,表示更新后第i个样本所属类别的中心点,表示更新后第i个样本所属类别中的各个样本与该类别的中心点的偏差和,N表示所述训练样本集中样本的数量。第二方面,本专利技术实施例提供了一种深度学习模型的训练装置,所述装置包括:训练模块,用于利用训练样本集中的各个样本,对预先构建的初始深度学习模型进行训练,得到中间深度学习模型;其中,所述初始深度学习模型为:加载有分类目标函数的深度学习模型;计算模块,用于利用所述中间深度学习模型,计算所述训练样本集中各个样本的特征向量,并根据所述训练样本集中各个样本的特征向量,计算中心点距离目标函数的中间参数的初始值;加载模块,用于将所述中心点距离目标函数添加到所述中间深度学习模型中并加载所述中间参数的初始值,得到目标深度学习模型;导入模块,用于导入所述训练样本集中的预设数量个样本作为批次数据;更新模块,用于利用当前目标深度学习模型,计算当前批次数据中各个样本的特征向量,并根据当前批次数据中各个样本的特征向量,更新所述中间参数的参数值;处理模块,用于基于当前批次数据中各个样本的特征向量,计算所述中心点距离目标函数的函数值和所述分类目标函数的函数值,并判断计算得到的中心点距离目标函数的函数值是否收敛到第一预定区间,且计算得到的分类目标函数的函数值是否收敛到第二预定区间;如果是,结束对所述当前目标深度学习模型的训练;如果否,利用所述中心点距离目标函数的反向传播梯度和所述分类目标函数的反向传播梯度,调整当前目标深度学习模型的参数,并触发所述导入模块。可选的,所述中心点距离目标函数的中间参数包括:各个类别的中心点、每一类别中的各个样本与该类别的中心点的偏差和、每一类别中的各个样本与该类别的中心点的距离的平方和;所述计算模块,具体用于:根据以下公式,计算各个类别的中心点:其中,cj表示类别j的中心点,xi表示利用所述中间深度学习模型计算的所述训练样本集中第i个样本的特征向量,yi表示第i个样本的类别标签,N表示所述训练样本集中样本的数量,nj表示在所述训练样本集中类别j所包含的样本的数量,σ(yi,j)是类别指示函数,且根据以下公式,计算每一类别中的各个样本与该类别的中心点的偏差和:其中,β本文档来自技高网...

【技术保护点】
1.一种深度学习模型的训练方法,其特征在于,所述方法包括:利用训练样本集中的各个样本,对预先构建的初始深度学习模型进行训练,得到中间深度学习模型;其中,所述初始深度学习模型为:加载有分类目标函数的深度学习模型;利用所述中间深度学习模型,计算所述训练样本集中各个样本的特征向量,并根据所述训练样本集中各个样本的特征向量,计算中心点距离目标函数的中间参数的初始值;将所述中心点距离目标函数添加到所述中间深度学习模型中并加载所述中间参数的初始值,得到目标深度学习模型;导入所述训练样本集中的预设数量个样本作为批次数据;利用当前目标深度学习模型,计算当前批次数据中各个样本的特征向量,并根据当前批次数据中各个样本的特征向量,更新所述中间参数的参数值;基于当前批次数据中各个样本的特征向量,计算所述中心点距离目标函数的函数值和所述分类目标函数的函数值,并判断计算得到的中心点距离目标函数的函数值是否收敛到第一预定区间,且计算得到的分类目标函数的函数值是否收敛到第二预定区间;如果否,利用所述中心点距离目标函数的反向传播梯度和所述分类目标函数的反向传播梯度,调整当前目标深度学习模型的参数,并返回执行所述导入所述训练样本集中的预设数量个样本作为批次数据的步骤;如果是,结束对所述当前目标深度学习模型的训练。...

【技术特征摘要】
1.一种深度学习模型的训练方法,其特征在于,所述方法包括:利用训练样本集中的各个样本,对预先构建的初始深度学习模型进行训练,得到中间深度学习模型;其中,所述初始深度学习模型为:加载有分类目标函数的深度学习模型;利用所述中间深度学习模型,计算所述训练样本集中各个样本的特征向量,并根据所述训练样本集中各个样本的特征向量,计算中心点距离目标函数的中间参数的初始值;将所述中心点距离目标函数添加到所述中间深度学习模型中并加载所述中间参数的初始值,得到目标深度学习模型;导入所述训练样本集中的预设数量个样本作为批次数据;利用当前目标深度学习模型,计算当前批次数据中各个样本的特征向量,并根据当前批次数据中各个样本的特征向量,更新所述中间参数的参数值;基于当前批次数据中各个样本的特征向量,计算所述中心点距离目标函数的函数值和所述分类目标函数的函数值,并判断计算得到的中心点距离目标函数的函数值是否收敛到第一预定区间,且计算得到的分类目标函数的函数值是否收敛到第二预定区间;如果否,利用所述中心点距离目标函数的反向传播梯度和所述分类目标函数的反向传播梯度,调整当前目标深度学习模型的参数,并返回执行所述导入所述训练样本集中的预设数量个样本作为批次数据的步骤;如果是,结束对所述当前目标深度学习模型的训练。2.根据权利要求1所述的方法,其特征在于,所述中心点距离目标函数的中间参数包括:各个类别的中心点、每一类别中的各个样本与该类别的中心点的偏差和、每一类别中的各个样本与该类别的中心点的距离的平方和;所述根据所述训练样本集中各个样本的特征向量,计算中心点距离目标函数的中间参数的初始值的步骤,包括:根据以下公式,计算各个类别的中心点:其中,cj表示类别j的中心点,xi表示利用所述中间深度学习模型计算的所述训练样本集中第i个样本的特征向量,yi表示第i个样本的类别标签,N表示所述训练样本集中样本的数量,nj表示在所述训练样本集中类别j所包含的样本的数量,σ(yi,j)是类别指示函数,且根据以下公式,计算每一类别中的各个样本与该类别的中心点的偏差和:其中,βj表示类别j中各个样本与该类别的中心点的偏差和,表示第i个样本所属类别的中心点;根据以下公式,计算每一类别中的各个样本与该类别的中心点的距离的平方和:其中,δj表示类别j中各个样本与该类别的中心点的距离的平方和。3.根据权利要求2所述的方法,其特征在于,所述根据当前批次数据中各个样本的特征向量,更新所述中间参数的参数值的步骤,包括:根据以下公式,更新各个类别的中心点:其中,cj、分别表示更新前、后的类别j的中心点,表示利用当前目标深度学习模型计算的所述训练样本集中第i个样本的特征向量,p表示所述当前批次数据中的样本;根据以下公式,更新每一类别中的各个样本与该类别的中心点的偏差和:其中,βj、分别表示更新前、后的类别j中的各个样本与该类别的中心点的偏差和,分别表示更新前、后第i个样本所属类别的中心点;根据以下公式,更新每一类别中的各个样本与该类别的中心点的距离的平方和:其中,δj、分别表示更新前、后的类别j中的各个样本与该类别的中心点的距离的平方和,表示第i个样本所属类别的中心点的偏移量,且ΔTcj表示Δcj的转置。4.根据权利要求1所述的方法,其特征在于,所述中心点距离目标函数为:其中,K表示所述训练样本集中样本的类别数量,δj表示类别j中各个样本与该类别的中心点的距离的平方和。5.根据权利要求4所述的方法,其特征在于,所述中心点距离目标函数的反向传播梯度为:其中,表示所述中心点距离目标函数的反向传播梯度,表示在所述训练样本集中第i个样本所属类别所包含的样本的数量,表示利用所述当前目标深度学习模型计算的所述训练样本集中第i个样本的特征向量,表示更...

【专利技术属性】
技术研发人员:李诚周晓朱才志
申请(专利权)人:合肥麟图信息科技有限公司
类型:发明
国别省市:安徽,34

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1