【技术实现步骤摘要】
神经网络的训练方法、装置、存储介质及电子设备
本专利技术涉及人工智能
,具体而言,涉及一种神经网络的训练方法、装置、存储介质及电子设备。
技术介绍
目前在深度学习领域广泛使用小批量随机梯度下降(Mini-batchStochasticGradientDescent)算法进行神经网络的训练。该算法将训练数据分批次输入至神经网络中进行前向传播以提取数据的特征,然后将提取到的特征带入到损失函数中进行损失计算,最后基于计算得到的损失利用反向传播算法更新神经网络的参数。研究表明,增加每个批次中的数据量,有利于改善训练好的模型的性能。然而,增加一个批次中的数据量,必然导致训练过程所需的计算资源大幅上升,在硬件条件的制约下,这一目标难以实现。
技术实现思路
本申请实施例的目的在于提供一种神经网络的训练方法、装置、存储介质及电子设备,以改善上述技术问题。为实现上述目的,本申请提供如下技术方案:第一方面,本申请实施例提供一种神经网络的训练方法,包括:利用待训练的神经网络对当前批次的训练数据进行 ...
【技术保护点】
1.一种神经网络的训练方法,其特征在于,包括:/n利用待训练的神经网络对当前批次的训练数据进行特征提取,获得所述当前批次的训练数据的特征;/n基于所述当前批次的训练数据的特征计算损失函数的函数值,获得当前损失;/n获取在之前的训练过程中缓存的历史批次的训练数据的特征,所述历史批次的训练数据的特征是在之前的训练过程中利用所述神经网络对所述历史批次的训练数据进行特征提取后得到的;/n基于所述当前批次的训练数据的特征以及所述历史批次的训练数据的特征计算所述损失函数的函数值,获得历史损失;/n基于所述当前损失以及所述历史损失计算总损失,并基于所述总损失利用反向传播算法更新所述神经网络的参数。/n
【技术特征摘要】
1.一种神经网络的训练方法,其特征在于,包括:
利用待训练的神经网络对当前批次的训练数据进行特征提取,获得所述当前批次的训练数据的特征;
基于所述当前批次的训练数据的特征计算损失函数的函数值,获得当前损失;
获取在之前的训练过程中缓存的历史批次的训练数据的特征,所述历史批次的训练数据的特征是在之前的训练过程中利用所述神经网络对所述历史批次的训练数据进行特征提取后得到的;
基于所述当前批次的训练数据的特征以及所述历史批次的训练数据的特征计算所述损失函数的函数值,获得历史损失;
基于所述当前损失以及所述历史损失计算总损失,并基于所述总损失利用反向传播算法更新所述神经网络的参数。
2.根据权利要求1所述的训练方法,其特征在于,所述基于所述当前批次的训练数据的特征计算损失函数的函数值,获得当前损失,包括:
基于所述当前批次的训练数据的特征和标签计算所述损失函数的函数值,获得所述当前损失;
所述获取在之前的训练过程中缓存的历史批次的训练数据的特征,包括:
获取在之前的训练过程中缓存的所述历史批次的训练数据的特征和标签;
所述基于所述当前批次的训练数据的特征以及所述历史批次的训练数据的特征计算所述损失函数的函数值,获得历史损失,包括:
基于所述当前批次的训练数据的特征和标签以及所述历史批次的训练数据的特征和标签计算所述损失函数的函数值,获得所述历史损失。
3.根据权利要求2所述的训练方法,其特征在于,所述训练数据为图像,所述神经网络为用于执行图像识别任务的网络,所述损失函数满足如下条件:
利用所述损失函数进行训练后,所述训练数据中具有相同标签的样本的特征之间的距离尽可能减小,且具有不同标签的样本的特征之间的距离尽可能增大。
4.根据权利要求3所述的训练方法,其特征在于,所述损失函数的函数值基于多个特征组计算出的损失累加得到,每个特征组包括一个正样本的特征、一个负样本的特征以及一个锚样本的特征,每个特征组对应的损失根据该组中的正样本的特征与锚样本的特征之间的第一距离以及负样本的特征与锚样本的特征之间的第二距离计算得到;其中,所述正样本和所述锚样本为所述训练数据中具有相同标签的样本,所述负样本和所述锚样本为所述训练数据中具有不同标签的样本,所述特征组对应的损失与所述第一距离正相关且与所述第二距离负相关;
所述基于所述当前批次的训练数据的特征和标签计算所述损失函数的函数值,获得所述当前损失,包括:
根据所述当前批次的训练数据的标签确定多个当前特征组,所述当前特征组中的特征对应的样本均选择自所述当前批次的训练数据;
基于所述多个当前特征组计算所述损失函数的函数值,获得所述当前损失;
所述基于所述当前批次的训练数据的特征和标签以及所述历史批次的训练数据的特征和标签计算所述损失函数的函数值,获得所述历史损失,包括:
根据所述当前批次的训练数据的标签以及所述历史批次的训练数据的标签确定多个历史特征组,所述历史特征组中的特征对应的样本分别选择自所述当前批次的训练数据以及所述历史批次的训练数据;
基于所述多个历史特征组计算所述损失函数的函数值,获得所述历史损失。
5.根据权利要求4所述的训练方法,其特征在于,每批次的训练数据及其对应的标签从全部训练数据及其对应的标签中通过采样确定,每批次的训练数据中均包括正样本、负样本以及锚样本;
根据所述当前批次的训练数据的标签以及所述历史批次的训练数据的标签确定一个历史特征组,包括:
根据所述当前批次的训练数据的标签,从所述当前批次的训练数据的特征中确定一个历史特征组中正样本的特征以及锚样本的特征,以及,根据所述历史批次的训练数据的标签,从一个历史批...
【专利技术属性】
技术研发人员:肖少然,
申请(专利权)人:北京迈格威科技有限公司,内蒙古旷视金智科技有限公司,
类型:发明
国别省市:北京;11
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。