模型检查点参数域平均方法、装置、电子设备及存储介质制造方法及图纸

技术编号:35000240 阅读:17 留言:0更新日期:2022-09-21 14:49
本发明专利技术提供一种模型检查点参数域平均方法、装置、电子设备及存储介质,该方法涉及人工智能技术领域,包括:在对待训练深度神经网络模型训练过程中,保存多个检查点和所述多个检查点分别对应的多个模型参数;所述检查点用于表示在训练过程中不同阶段的模型;确定在各所述检查点分别对应的所述模型的损失函数的总损失值;基于所述总损失值,从所述多个检查点中确定待平均检查点;对所述待平均检查点对应的多个模型参数进行参数平均,得到所述模型的参数平均值。本发明专利技术提供的方法,结合了在训练集和验证集上模型的损失函数的总损失值,同时考虑了模型的偏差和方差,提升了模型的性能。提升了模型的性能。提升了模型的性能。

【技术实现步骤摘要】
模型检查点参数域平均方法、装置、电子设备及存储介质


[0001]本专利技术涉及人工智能
,尤其涉及一种模型检查点参数域平均方法、装置、电子设备及存储介质。

技术介绍

[0002]目前,深度学习已经成为推动人工智能技术进步的动源之一。然而,基于深度学习的深度神经网络模型的非凸属性,使得模型最终仅能收敛到某个局部最优点;而且,由于模型初始化时的随机性,以及模型训练过程中梯度下降引入的随机性,使得模型在训练过程中的损失值不断波动;尤其是当模型已经趋近于收敛时,模型的损失值会出现大量的局部最小值点。因此,如何利用这些局部最优点提升模型的性能,是深度学习领域的一项重要任务。
[0003]深度学习模型的训练过程需要迭代多轮训练数据(Epoch),而在训练过程中,可以根据需要设置检查点(Checkpoint)来保存当前阶段的模型训练参数,以对该检查点时刻模型的性能进行评估。
[0004]相关技术中,基于检查点参数域的模型平均已经成为提升深度学习模型性能的基础配置方法;其中,主流的方法包括:最后K个检查点进行参数平均(Last K

Checkpoint Averaging,LKCA)和K个最优检点进行参数平均(K

Best Checkpoint Averaging,KBCA);其中,LKCA策略一般配合早停(Early Stop)策略同时使用,选择的检查点为训练集上接近收敛的K个连续检查点模型进行平均,而BKCA策略则选择验证集上损失函数分值最小的K个检查点模型进行平均。由于损失函数在训练集上的分值代表的是模型对于训练数据的拟合程度,分值越小表示拟合的越好,通常意味着模型的偏差(Bias)越小,而损失函数在验证集上的分值代表的是模型的泛化(Generalization)能力,分值越小表示返还能力越好,通常意味着模型的方差(Variance)越小。根据LKCA方法的定义,其本质上是倾向选择具有更小偏差的K个检查点进行平均,而KBCA方法本质上倾向选择具有更小方差的K个检查点进行平均。根据偏差

方差分解理论,模型最终的性能由方差和偏差共同决定。
[0005]然而,LKCA主要考虑了偏差,而KBCA主要考虑了方差,均没有同时考虑模型的偏差和方差,使得模型的性能较差。

技术实现思路

[0006]本专利技术提供一种模型检查点参数域平均方法、装置、电子设备及存储介质,用以解决现有技术中模型的性能较差的缺陷,实现提升模型的性能。
[0007]本专利技术提供一种模型检查点参数域平均方法,包括:
[0008]在对待训练深度神经网络模型训练过程中,保存多个检查点和所述多个检查点分别对应的多个模型参数;所述检查点用于表示在训练过程中不同阶段的模型;
[0009]确定在各所述检查点分别对应的所述模型的损失函数的总损失值;
[0010]基于所述总损失值,从所述多个检查点中确定待平均检查点;
[0011]对所述待平均检查点对应的多个模型参数进行参数平均,得到所述模型的参数平均值。
[0012]根据本专利技术提供的一种模型检查点参数域平均方法,所述确定在各所述检查点分别对应的所述模型的损失函数的总损失值,包括:
[0013]分别计算各所述检查点在训练集和验证集上损失函数的损失值;
[0014]基于各所述检查点在训练集和验证集上损失函数的损失值,确定各所述检查点分别对应的所述模型的损失函数的总损失值。
[0015]根据本专利技术提供的一种模型检查点参数域平均方法,所述基于各所述检查点在训练集和验证集上损失函数的损失值,确定各所述检查点分别对应的所述模型的损失函数的总损失值,包括:
[0016]基于各所述检查点在验证集上损失函数的损失值,确定连续K个检查点对应的损失函数的损失值之和;所述K为正整数;
[0017]基于所述损失函数的损失值之和,确定所述损失函数的损失值之和最小的所述连续K个检查点;
[0018]基于所述连续K个检查点,确定各所述检查点分别对应的所述模型的损失函数的总损失值。
[0019]根据本专利技术提供的一种模型检查点参数域平均方法,所述基于所述连续K个检查点,确定各所述检查点分别对应的所述模型的损失函数的总损失值,包括:
[0020]基于所述连续K个检查点中第一个检查点在验证集上损失函数的损失值和所述第一个检查点在训练集上损失函数的损失值,计算所述第一个检查点的贡献率;
[0021]根据各所述检查点的当前位置,计算各所述检查点的惩罚因子;
[0022]基于所述贡献率和所述惩罚因子,计算各所述检查点分别对应的所述模型的损失函数的总损失值。
[0023]根据本专利技术提供的一种模型检查点参数域平均方法,所述基于所述总损失值,从所述多个检查点中确定待平均检查点,包括:
[0024]基于所述总损失值,确定连续W个检查点分别对应的所述总损失值之和;所述W为正整数;
[0025]基于所述总损失值之和,确定所述总损失值之和最小的所述连续W个检查点;
[0026]基于所述连续W个检查点,从所述多个检查点中确定待平均检查点。
[0027]根据本专利技术提供的一种模型检查点参数域平均方法,所述对所述待平均检查点对应的多个模型参数进行参数平均,得到所述模型的参数平均值,包括:
[0028]确定所述待平均检查点分别对应的多个模型参数;
[0029]基于所述多个模型参数,计算所述多个模型参数的平均值,得到所述模型的参数平均值。
[0030]本专利技术还提供一种模型检查点参数域平均装置,包括:
[0031]存储模块,用于在对待训练深度神经网络模型训练过程中,保存多个检查点和所述多个检查点分别对应的多个模型参数;所述检查点用于表示在训练过程中不同阶段的模型;
[0032]第一确定模块,用于确定在各所述检查点分别对应的所述模型的损失函数的总损
失值;
[0033]第二确定模块,用于基于所述总损失值,从所述多个检查点中确定待平均检查点;
[0034]参数平均模块,用于对所述待平均检查点对应的多个模型参数进行参数平均,得到所述模型的参数平均值。
[0035]本专利技术还提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述任一种所述模型检查点参数域平均方法。
[0036]本专利技术还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现如上述任一种所述模型检查点参数域平均方法。
[0037]本专利技术还提供一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上述任一种所述模型检查点参数域平均方法。
[0038]本专利技术提供的模型检查点参数域平均方法、装置、电子设备及存储介质,通过在待训练深度神经网络模型进行训练的过程中,保存多个检查点和多个检查点分别对应的多个模型参数,接着根据各个检查点,本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型检查点参数域平均方法,其特征在于,包括:在对待训练深度神经网络模型训练过程中,保存多个检查点和所述多个检查点分别对应的多个模型参数;所述检查点用于表示在训练过程中不同阶段的模型;确定在各所述检查点分别对应的所述模型的损失函数的总损失值;基于所述总损失值,从所述多个检查点中确定待平均检查点;对所述待平均检查点对应的多个模型参数进行参数平均,得到所述模型的参数平均值。2.根据权利要求1所述的模型检查点参数域平均方法,其特征在于,所述确定在各所述检查点分别对应的所述模型的损失函数的总损失值,包括:分别计算各所述检查点在训练集和验证集上损失函数的损失值;基于各所述检查点在训练集和验证集上损失函数的损失值,确定各所述检查点分别对应的所述模型的损失函数的总损失值。3.根据权利要求2所述的模型检查点参数域平均方法,其特征在于,所述基于各所述检查点在训练集和验证集上损失函数的损失值,确定各所述检查点分别对应的所述模型的损失函数的总损失值,包括:基于各所述检查点在验证集上损失函数的损失值,确定连续K个检查点对应的损失函数的损失值之和;所述K为正整数;基于所述损失函数的损失值之和,确定所述损失函数的损失值之和最小的所述连续K个检查点;基于所述连续K个检查点,确定各所述检查点分别对应的所述模型的损失函数的总损失值。4.根据权利要求3所述的模型检查点参数域平均方法,其特征在于,所述基于所述连续K个检查点,确定各所述检查点分别对应的所述模型的损失函数的总损失值,包括:基于所述连续K个检查点中第一个检查点在验证集上损失函数的损失值和所述第一个检查点在训练集上损失函数的损失值,计算所述第一个检查点的贡献率;根据各所述检查点的当前位置,计算各所述检查点的惩罚因子;基于所述贡献率和所述惩罚因子,计算各所述检查点分别对应的所述模型的损失函数的总损失值...

【专利技术属性】
技术研发人员:王方圆徐波
申请(专利权)人:中国科学院自动化研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1