一种模型训练方法及相关装置制造方法及图纸

技术编号:42118419 阅读:24 留言:0更新日期:2024-07-25 00:37
一种模型训练方法,应用于人工智能技术领域。在该方法中,在执行对抗对比学习方法的过程中,采用了动态的数据增强策略,即数据增强强度随着训练迭代次数的增加而发生变化,进而有效地平衡了对比学习和对抗训练之间针对于数据增强的矛盾,提高训练得到的模型的鲁棒性。

【技术实现步骤摘要】

本申请涉及人工智能(artificial intelligence,ai),尤其涉及一种模型训练方法及相关装置


技术介绍

1、在ai
中,深度神经网络的鲁棒性问题是业界广泛研究的问题。具体而言,攻击者可以通过向输入数据添加人眼不可见的噪声,使得深度神经网络输出错误的结果。为了解决深度神经网络的鲁棒性问题,有监督的对抗训练方法被提出。这类方法的基本思路是利用对抗样本进行训练,而对抗样本通常是通过最大化交叉熵损失函数或通过替代损失最小化的对抗防御(tradeoff-inspired adversarial defense via surrogate-lossminimization,trades)损失函数得到。但是,有监督的对抗训练方法往往需要大量的有标注数据。

2、对比学习是一种无监督训练方法。该方法的主要思想是对输入批次中的每个数据样本进行两种不同的数据增强方法。对于一个数据样本,该数据样本增强后的两个数据称为正样本对,其余样本称为负样本对,通过基于正样本对和负样本对来构建损失函数无监督地训练模型。

3、通过对比学习无监督地训练鲁棒的模型,即对抗对比学习,是近年来研究的热点问题。这类方法的主要思想是基于对比学习框架,生成对抗样本,并利用对抗样本训练模型。然而,目前的对抗对比学习方法所训练得到的模型的鲁棒性较低。


技术实现思路

1、本申请提供了一种模型训练方法,能够提高训练得到的模型的鲁棒性。

2、本申请第一方面提供一种模型训练方法,应用于人工智能
该方法具体包括:在对第一模型执行第n轮迭代训练的过程中,对第一训练数据执行数据增强,得到数据增强样本。其中,第一训练数据为对第一模型执行第n轮迭代训练所采用的数据,且对第一训练数据执行数据增强的强度与n的大小相关,n为大于或等于1的整数。即,在对第一模型执行不同轮次的迭代训练时,对当前迭代轮次训练所使用的训练数据所执行的数据增强的强度是不同的。

3、然后,基于所得到的数据增强样本和第一模型,生成对抗样本。该对抗样本是在数据增强样本的基础上增加噪声得到的。

4、其次,将数据增强样本和对抗样本输入第一模型,得到第一模型提取的第一特征和第二特征,第一特征为数据增强样本的特征,第二特征为对抗样本的特征。也就是说,通过第一模型分别提取数据增强样本和对抗样本的特征。

5、最后,基于损失函数更新第一模型,损失函数是基于第一特征和第二特征得到的。示例性地,该损失函数可以为基于第一特征和第二特征所构建的info-nce损失函数。

6、本方案中,在执行对抗对比学习方法的过程中,采用了动态的数据增强策略,即数据增强强度随着训练迭代次数的增加而发生变化,进而有效地平衡了对比学习和对抗训练之间针对于数据增强的矛盾,提高训练得到的模型的鲁棒性。

7、在一种可能的实现方式中,对第一训练数据执行数据增强的强度与n的大小具有负相关的关系。即,n的数值越大,代表训练迭代轮次越大,对第一训练数据执行数据增强的强度越小;n的数值越小,代表训练迭代轮次越小,对第一训练数据执行数据增强的强度越大。也就是说,在第一模型的训练过程中,随着训练迭代轮次逐渐增加(即随着训练的不断进行),对训练数据所执行的数据增强的强度逐渐变小。

8、本方案中,通过设定模型的训练迭代轮次与训练数据的数据增强强度呈负相关关系,能够使得在模型在训练过程中,对训练数据所执行的数据增强强度不断下降,进而很好地平衡对比学习和对抗训练之间针对于数据增强的矛盾,提高训练得到的模型的鲁棒性。

9、具体而言,在第一模型的训练初期,随机初始化的第一模型还没有学习到很好的表征,因此需要对训练数据执行强度较高的数据增强,以帮助第一模型迅速地学习到好的表征;在第一模型的训练后期,第一模型已逐渐学习到较好的表征,因此逐渐降低对训练数据所执行的数据增强的强度,能够有利于提升对抗训练的效果,进而同时兼顾对比学习和对抗训练针对于数据增强的需求,提高训练得到的模型的鲁棒性。

10、在一种可能的实现方式中,第一训练数据为图像,对第一训练数据执行数据增强的方式包括以下的一种或多种方式:旋转、翻转、平移变换、转换为灰度图、颜色抖动、调整亮度、调整对比度、调整饱和度以及调整色调。

11、其中,对训练数据执行不同的数据增强强度的方式也可以有多种。例如,通过对不同的训练数据采用不同数量的数据增强方式,来实现对训练数据执行不同的数据增强强度。又例如,针对于相同的数据增强方式,通过调整相同的数据增强方式中的数据增强范围,来实现对训练数据执行不同的数据增强强度。

12、在一种可能的实现方式中,第一模型包括具有相同结构的第一网络支路和第二网络支路,且第一网络支路和第二网络支路中的批归一化层的参数不同。

13、将数据增强样本和对抗样本输入第一模型,得到第一模型提取的第一特征和第二特征,具体包括:将数据增强样本输入第一网络支路以及将对抗样本输入第二网络支路,得到第一网络支路提取的第一特征和第二网络支路提取的第二特征。

14、具体地,第一网络支路和第二网络支路可以是共享相同的结构以及除批归一化层以外的参数,即第一网络支路和第二网络支路除了最后的批归一化层的参数不同之外,其他的结构和参数均相同。例如,第一网络支路和第二网络支路均包括卷积层、全连接层、池化层和批归一化层,其中第一网络支路和第二网络支路中的卷积层、全连接层和池化层的结构和参数均相同,但第一网络支路中的批归一化层的参数和第二网络支路批归一化层的参数不同。

15、本方案中,针对于不同的网络支路,采用不同的批归一化层参数,是考虑到了正常的数据增强样本与对抗样本所属的分布不同,而批归一化层中的参数往往能刻画数据的分布,因此针对于提取不同数据的特征的网络支路采用了不同的批归一化层。

16、在一种可能的实现方式中,在基于损失函数更新第一模型后,该方法还包括:将第二训练数据输入第一网络支路,得到第一网络支路提取的特征;通过聚类算法对第一网络支路提取的特征进行聚类,得到第二训练数据的伪标签;基于第二训练数据和伪标签,对第二网络支路进行对抗训练,得到训练后的第二网络支路。其中,聚类算法例如可以为k-means算法。伪标签是指通过模型对训练数据进行预测所得到的标签,用于代替人工标注的标签。

17、本方案中,针对于模型中第一网络支路(即处理数据增强样本的正常支路)预测伪标签准确率很高的特点,采用第一网络支路生成可靠的伪标签,然后将伪标签用于实现第二网络支路(即处理对抗样本的对抗支路)的训练,能够有效地提高模型中第二网络支路的鲁棒性。

18、在一种可能的实现方式中,数据增强样本包括第一增强样本和第二增强样本,第一增强样本和第二增强样本是分别对第一训练数据执行不同的数据增强得到的,对抗样本包括第一对抗样本和第二对抗样本。

19、在一种可能的实现方式中,将数据增强样本和对抗样本输入第一模型,得到第一模型提取的第本文档来自技高网...

【技术保护点】

1.一种模型训练方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,对所述第一训练数据执行数据增强的强度与所述N的大小具有负相关的关系。

3.根据权利要求1或2所述的方法,其特征在于,所述第一训练数据为图像,所述对所述第一训练数据执行数据增强的方式包括以下的一种或多种方式:旋转、翻转、平移变换、转换为灰度图、颜色抖动、调整亮度、调整对比度、调整饱和度以及调整色调。

4.根据权利要求1-3任意一项所述的方法,其特征在于,所述第一模型包括具有相同结构的第一网络支路和第二网络支路,且所述第一网络支路和所述第二网络支路中的批归一化层的参数不同;

5.根据权利要求4所述的方法,其特征在于,在基于所述损失函数更新所述第一模型后,所述方法还包括:

6.根据权利要求1-5任意一项所述的方法,其特征在于,所述数据增强样本包括第一增强样本和第二增强样本,所述第一增强样本和所述第二增强样本是分别对所述第一训练数据执行不同的数据增强得到的,所述对抗样本包括第一对抗样本和第二对抗样本。

7.根据权利要求6所述的方法,其特征在于,将所述数据增强样本和所述对抗样本输入第一模型,得到所述第一模型提取的第一特征和第二特征,包括:

8.根据权利要求6所述的方法,其特征在于,所述第一特征包括对应于所述第一增强样本的第一子特征和对应于所述第二增强样本的第二子特征,所述第二特征包括对应于所述第一对抗样本的第三子特征和对应于所述第二对抗样本的第四子特征;

9.一种模型训练装置,其特征在于,包括:

10.根据权利要求9所述的装置,其特征在于,对所述第一训练数据执行数据增强的强度与所述N的大小具有负相关的关系。

11.根据权利要求9或10所述的装置,其特征在于,所述第一训练数据为图像,所述对所述第一训练数据执行数据增强的方式包括以下的一种或多种方式:旋转、翻转、平移变换、转换为灰度图、颜色抖动、调整亮度、调整对比度、调整饱和度以及调整色调。

12.根据权利要求9-11任意一项所述的装置,其特征在于,所述第一模型包括具有相同结构的第一网络支路和第二网络支路,且所述第一网络支路和所述第二网络支路中的批归一化层的参数不同;

13.根据权利要求12所述的装置,其特征在于,所述处理模块,还用于:

14.根据权利要求9-13任意一项所述的装置,其特征在于,所述数据增强样本包括第一增强样本和第二增强样本,所述第一增强样本和所述第二增强样本是分别对所述第一训练数据执行不同的数据增强得到的,所述对抗样本包括第一对抗样本和第二对抗样本。

15.根据权利要求14所述的装置,其特征在于,所述处理模块,还用于:

16.根据权利要求14所述的装置,其特征在于,所述第一特征包括对应于所述第一增强样本的第一子特征和对应于所述第二增强样本的第二子特征,所述第二特征包括对应于所述第一对抗样本的第三子特征和对应于所述第二对抗样本的第四子特征;

17.一种模型训练装置,其特征在于,包括存储器和处理器;所述存储器存储有代码,所述处理器被配置为执行所述代码,当所述代码被执行时,所述装置执行如权利要求1至8任意一项所述的方法。

18.一种计算机存储介质,其特征在于,所述计算机存储介质存储有指令,所述指令在由计算机执行时使得所述计算机实施权利要求1至8任意一项所述的方法。

19.一种计算机程序产品,其特征在于,所述计算机程序产品存储有指令,所述指令在由计算机执行时使得所述计算机实施权利要求1至8任意一项所述的方法。

...

【技术特征摘要】

1.一种模型训练方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,对所述第一训练数据执行数据增强的强度与所述n的大小具有负相关的关系。

3.根据权利要求1或2所述的方法,其特征在于,所述第一训练数据为图像,所述对所述第一训练数据执行数据增强的方式包括以下的一种或多种方式:旋转、翻转、平移变换、转换为灰度图、颜色抖动、调整亮度、调整对比度、调整饱和度以及调整色调。

4.根据权利要求1-3任意一项所述的方法,其特征在于,所述第一模型包括具有相同结构的第一网络支路和第二网络支路,且所述第一网络支路和所述第二网络支路中的批归一化层的参数不同;

5.根据权利要求4所述的方法,其特征在于,在基于所述损失函数更新所述第一模型后,所述方法还包括:

6.根据权利要求1-5任意一项所述的方法,其特征在于,所述数据增强样本包括第一增强样本和第二增强样本,所述第一增强样本和所述第二增强样本是分别对所述第一训练数据执行不同的数据增强得到的,所述对抗样本包括第一对抗样本和第二对抗样本。

7.根据权利要求6所述的方法,其特征在于,将所述数据增强样本和所述对抗样本输入第一模型,得到所述第一模型提取的第一特征和第二特征,包括:

8.根据权利要求6所述的方法,其特征在于,所述第一特征包括对应于所述第一增强样本的第一子特征和对应于所述第二增强样本的第二子特征,所述第二特征包括对应于所述第一对抗样本的第三子特征和对应于所述第二对抗样本的第四子特征;

9.一种模型训练装置,其特征在于,包括:

10.根据权利要求9所述的装置,其特征在于,对所述第一训练数据执行数据增强的强度与所述n的大小具有负相关的关系。

11.根据权利要求9或10所...

【专利技术属性】
技术研发人员:王奕森罗润冬王一飞黄维然姚骏
申请(专利权)人:华为技术有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1