一种神经网络模型交互训练方法、装置及存储介质制造方法及图纸

技术编号:32784381 阅读:24 留言:0更新日期:2022-03-23 19:43
本发明专利技术涉及一种神经网络模型交互训练方法及装置,其方法包括:确定参与交互训练的一个主神经网络,以及至少一个次神经网络;根据所述主神经网络和次神经网络之间的分布差异,确定参与交互训练的目标函数;根据所述目标函数训练所述主神经网络和次神经网络,直至目标函数值达到阈值且趋于稳定,得到训练完成的主神经网络。本发明专利技术提出一种神经网络交互训练方法,同时采用KL散度来度量主网络与次网络的预测概率分布差异,实现主次网络交互学习经验来引导主网络学习,从而获得与次网络相近或略高的性能,也缓解了主网络单独训练时收敛缓慢,容易陷入局部最优,特别在训练样本量限制条件下,网络模型泛化性较弱,检出率偏低等问题。检出率偏低等问题。检出率偏低等问题。

【技术实现步骤摘要】
一种神经网络模型交互训练方法、装置及存储介质


[0001]本专利技术属于深度学习
,具体涉及一种神经网络模型交互训练方法、装置及存储介质。

技术介绍

[0002]近几年来深度学习神经网络在计算机视觉,自然语言处理与智能语音识别等领域取得了令人瞩目的成就,各种应用场景越来越成熟,但是在比较复杂的环境下,算法模型往往会表现出不稳定的预测结果,降低应用体验。研究人员发现造成上述问题的主要原因是模型在迭代训练时未充分学习到复杂场景信息,导致预测结果不够精确。
[0003]为了提升网络模型的性能,当前的主流解决方案是通过大量收集各个场景下的有效样本以及数据增广等方法来支撑模型迭代训练,显然足够的样本量可以稳定地提升模型的性能,但同时带来了训练成本的大幅增加,训练样本的制作是一个缓慢的过程,当前主要依靠人工标注获取,效率较低。特别地,在进行网络模型验证和评估时,训练样本有限的情况下,模型输出精度往往不理想,某种程度上影响模型的性能指标,所以需要一种网络模型训练策略缓解上述问题。

技术实现思路

[0004]为解决神经网络在训练样样本有限情况下精度不高,以及训练过程中由于训练环境复杂的情况下神经网络不稳定的问题,在本专利技术的第一方面提供了一种神经网络模型交互训练方法,包括:确定参与交互训练的一个主神经网络,以及至少一个次神经网络;根据所述主神经网络和次神经网络之间的分布差异,确定参与交互训练的目标函数;根据所述目标函数训练所述主神经网络和次神经网络,直至目标函数值达到阈值且趋于稳定,得到训练完成的主神经网络。
[0005]在本专利技术的一些实施例中,所述根据所述主神经网络和次神经网络之间的分布差异,确定参与交互训练的目标函数包括:确定所述主神经网络的监督损失函数;确定主神经网络与每个次神经网络的交互训练的损失函数;根据所述监督损失函数和每个交互训练的损失函数,确定参与交互训练的目标函数。
[0006]进一步的,所述参与交互训练的目标函数通过如下方法确定:
[0007]L
01
=αl
01
+(1

α)D,
[0008]其中,L
01
表示主神经网络的监督损失函数,α为权重因子,D表示神经网络与每个次网络的交互训练的损失函数。
[0009]优选的,所述神经网络与每个次网络的交互训练的损失函数通过KL散度度量。
[0010]进一步的,所述主神经网络的监督损失函数为Focal loss函数。
[0011]在上述的实施例中,所述确定参与交互训练的一个主神经网络,以及至少一个次神经网络包括:根据需求将匹配度较高的网络模型作为主神经网络;将一个或多个性能较或泛化能力强于所述主神经网络的神经网络模型作为次神经网络。
[0012]本专利技术的第二方面,提供了一种神经网络模型交互训练装置,包括:第一确定模块,用于确定参与交互训练的一个主神经网络,以及至少一个次神经网络;第二确定模块,用于根据所述主神经网络和次神经网络之间的分布差异,确定参与交互训练的目标函数;训练模块,用于根据所述目标函数训练所述主神经网络和次神经网络,直至目标函数值达到阈值且趋于稳定,得到训练完成的主神经网络。
[0013]在本专利技术的一些实施例中,所述第二确定模块包括:第一确定单元,用于确定所述主神经网络的监督损失函数;第二确定单元,用于确定主神经网络与每个次神经网络的交互训练的损失函数;第三确定单元,用于根据所述监督损失函数和每个交互训练的损失函数确定参与交互训练的目标函数。
[0014]本专利技术的第三方面,提供了一种电子设备,包括:一个或多个处理器;存储装置,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理器实现本专利技术在第一方面提供的神经网络模型交互训练方法。
[0015]本专利技术的第四方面,提供了一种计算机可读介质,其上存储有计算机程序,其中,所述计算机程序被处理器执行时实现本专利技术在第一方面提供的神经网络模型交互训练方法。
[0016]本专利技术的有益效果是:
[0017]1.本专利技术提出一种神经网络交互训练策略,利用监督损失函数更新主网络模型权重,同时采用KL散度来度量主网络与次网络的预测概率分布差异作为交互学习损失函数,实现充分利用次网络高性能的最优估计以及与主次网络交互学习经验来引导主网络学习,从而获得与次网络相近或略高的性能。该策略旨在缓解主网络单独训练时收敛缓慢,容易陷入局部最优,特别在训练样本量限制条件下,网络模型泛化性较弱,检出率偏低等问题;
[0018]2.本专利技术主要采用主次网络模型之间交互训练来辅助主网络对样本信息的有效学习的过程。单网络模型训练,表现为模型迭代缓慢,空间特征信息捕获能力有限,尽管采用预训练权重做权重初始化,测试表现出较好的预测精度,但在应对不同的复杂的场景时,网络的预测估计能力可能产生断崖式下降,原因是网络陷入局部最优解,而我们期望的是全局最优解,本专利技术旨在缓解上述问题,主要优势体现为为凭借次网络强大的性能输出最优预测概率分布引导主网络迭代训练,更快接近极小值点。在训练样本量有限情况下,交互训练能够让主网络获得与次网络相近甚至略高的性能,原因是次网络对数据的学习与理解在一个较高的水平,那么次网络的最优估计传递给主网络来更新迭代,使得主网络获得与次网络相持的学习能力,并提升了模型的泛化性;
[0019]3.对交互训练目标函数做加权和处理,交互网络训练初期次网络预测主导整个网络交互训练,加快主网络收敛,至训练中后期,主次网络在自身的模型中学习到不同的知识,次网络对复杂场景良好拟合能力通过KL散度方式共享学习经验,共同提升,实现交互训练过程。当然可以挖掘更多的信息,比如主次网络对不同目标的特征空间学习,在某种程度上表达了的目标之间的关联性,为应对复杂场景表现出良好泛化能力;
[0020]4.本专利技术旨在主次网络模型之间交互训练学习,同样可以拓展至多网络模型之间的交互训练,利用n-1个模型辅助主网络学习,模型在精度,泛化性和稳定性上会表现出更高效的性能。
附图说明
[0021]图1为本专利技术的一些实施例中的神经网络模型交互训练方法的基本流程示意图;
[0022]图2为本专利技术的一些实施例中的神经网络模型交互训练方法的具体流程示意图之一;
[0023]图3为本专利技术的一些实施例中的神经网络模型交互训练方法的具体流程示意图之二;
[0024]图4为本专利技术的一些实施例中的神经网络模型交互训练装置的结构示意图;
[0025]图5为本专利技术的一些实施例中的电子设备的结构示意图。
具体实施方式
[0026]以下结合附图对本专利技术的原理和特征进行描述,所举实例只用于解释本专利技术,并非用于限定本专利技术的范围。
[0027]参考图1,在本专利技术的第一方面,提供了一种神经网络模型交互训练方法,包括:S100.确定参与交互训练的一个主神经网络,以及至少一个次神经网络;S200.根据所述主神经网络和次神经本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种神经网络模型交互训练方法,其特征在于,包括:确定参与交互训练的一个主神经网络,以及至少一个次神经网络;根据所述主神经网络和次神经网络之间的分布差异,确定参与交互训练的目标函数;根据所述目标函数训练所述主神经网络和次神经网络,直至目标函数值达到阈值且趋于稳定,得到训练完成的主神经网络。2.根据权利要求1所述的神经网络模型交互训练方法,其特征在于,所述根据所述主神经网络和次神经网络之间的分布差异,确定参与交互训练的目标函数包括:确定所述主神经网络的监督损失函数;确定主神经网络与每个次神经网络的交互训练的损失函数;根据所述监督损失函数和每个交互训练的损失函数,确定参与交互训练的目标函数。3.根据权利要求2所述的的神经网络模型交互训练方法,其特征在于,所述参与交互训练的目标函数通过如下方法确定:L
01
=αl
01
+(1

α)D,其中,L
01
表示主神经网络的监督损失函数,α为权重因子,D表示神经网络与每个次网络的交互训练的损失函数。4.根据权利要求3所述的神经网络模型交互训练方法,其特征在于,所述神经网络与每个次网络的交互训练的损失函数通过KL散度度量。5.根据权利要求2所述的神经网络模型交互训练方法,其特征在于,所述主神经网络的监督损失函数为Focal loss函数。6.根据权利要求1至5任一项所述的神经...

【专利技术属性】
技术研发人员:乔少华
申请(专利权)人:武汉中海庭数据技术有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1