一种鲁棒数据集蒸馏方法及系统技术方案

技术编号:36792381 阅读:20 留言:0更新日期:2023-03-08 22:45
本发明专利技术提出一种鲁棒数据集蒸馏方法,包括:在原训练数据集中进行随机采样,作为初始化的蒸馏数据集;对该原训练数据集进行随机采样,得到原样本采样集;为该原样本采样集添加对抗扰动,得到扰动样本采样集;分别计算该扰动样本采样集与该蒸馏数据集的特征表示或网络梯度,以该特征表示或网络梯度的差异作为损失函数;根据该损失函数反向传播对该蒸馏数据集进行优化更新,不断迭代优化过程,直至损失函数收敛,保存当前蒸馏数据集作为最终的鲁棒蒸馏数据集。以该鲁棒蒸馏数据集代替该原训练数据集,训练目标深度神经网络模型完成图像识别分类。本发明专利技术将对抗鲁棒特性蒸馏到鲁棒蒸馏数据集,从而提升蒸馏数据集训练得到模型的鲁棒性。棒性。棒性。

【技术实现步骤摘要】
一种鲁棒数据集蒸馏方法及系统


[0001]本专利技术涉及数据集蒸馏
,尤其设计一种训练对抗鲁棒模型的数据集蒸馏方法和装置。

技术介绍

[0002]数据集蒸馏,又称数据集压缩,是指对于一个规模较大的训练数据集,利用数据集蒸馏算法将其压缩得到一个规模较小的数据集,模型在小数据集上训练也能达到在原始数据集上训练接近的效果。现有的数据集蒸馏技术普遍通过为生成数据集(小规模数据集)设计损失函数,迭代优化生成样本。其中,生成数据集的损失函数通常会定义为生成数据集与原始数据集的表示分布差异或是训练模型参数的差异等方式,从而衡量生成数据集的质量,通过迭代优化使得通过生成数据集训练得到的模型能够达到与原始数据集训练得到模型类似的效果。
[0003]随着训练数据集规模被大幅压缩,模型在少量数据集上进行训练更容易过拟合,训练得到的模型具有更加明显的对抗脆弱性。然而,目前的数据集蒸馏算法只考虑了蒸馏数据集训练模型在测试样本上的准确率与原始数据集类似,未考虑蒸馏数据集训练模型的对抗鲁棒性。因此虽然蒸馏数据集能够利用其规模小、训练快的优势,但是如果不考虑蒸馏数据集训练得到模型的鲁棒性,那么在例如自动驾驶等对安全敏感的任务上,蒸馏数据集很难拥有鲁棒性保障从而得到进一步应用。
[0004]当训练数据集被压缩时,即便能够训练模型得到类似的准确率,但是模型鲁棒性会发生明显的下降。并且即便是利用目前较为有效的对抗训练在蒸馏数据集上训练模型,也未能得到鲁棒性较好的模型,甚至实验结果表明蒸馏数据集经过对抗训练得到的模型反而表现出更严重的对抗脆弱性,换而言之,蒸馏数据集无法很好地适配到对抗训练算法上。
[0005]模型的对抗脆弱性是指训练完成的模型在测试阶段,会对攻击者精心设计的对抗样本产生高置信度误判。例如在自动驾驶任务中,攻击者通过在STOP路牌标志上粘贴精心设计的图案,模型则会将该STOP路标错误识别为通行标志。以往的数据集蒸馏任务没有只考虑了模型针对干净样本的精度,而没有考虑训练得到的模型对抗脆弱性(或者说是对抗鲁棒性),针对这一问题急需提出一种鲁棒数据集蒸馏方法。

技术实现思路

[0006]本专利技术的目的是解决上述现有数据集蒸馏技术未考虑训练得到模型鲁棒性的问题,提出了一种鲁棒的数据集蒸馏框架,在蒸馏过程中添加了对抗扰动,即蒸馏的数据具有对抗特性。经过本申请压缩的训练集和压缩前相比,训练得到的模型会更加鲁棒。
[0007]具体来说,本专利技术提出了一种鲁棒数据集蒸馏方法,其中包括:
[0008]步骤1、在原训练数据集中进行随机采样,作为初始化蒸馏数据集;
[0009]步骤2、对该原训练数据集进行随机采样,得到原样本采样集;
[0010]步骤3、为该原样本采样集添加对抗扰动,得到扰动样本采样集;
[0011]步骤4、分别计算该扰动样本采样集与该蒸馏数据集的特征表示或网络梯度,以该特征表示或网络梯度的差异作为损失函数;
[0012]步骤5、根据该损失函数反向传播对该蒸馏数据集进行优化更新,并且不断迭代优化过程,直至损失函数收敛,保存当前蒸馏数据集作为最终的鲁棒蒸馏数据集。
[0013]步骤6、以该鲁棒蒸馏数据集代替该原训练数据集,训练目标深度神经网络模型完成图像识别分类。
[0014]所述的鲁棒数据集蒸馏方法,其中该原训练数据集由多张图像构成,该优化更新包括:根据损失函数函数计算出损失值并进行梯度反向传播,得到针对该蒸馏数据集的每幅图像中像素值更新的梯度,依据该梯度乘以一定的权重,也就是学习率和步长,修改其对应像素的色彩,以实现对该蒸馏数据集的优化更新。
[0015]所述的鲁棒数据集蒸馏方法,其中添加针对当前目标深度神经网络模型参数的对抗扰动。
[0016]本专利技术还提出了一种鲁棒数据集蒸馏系统,其中包括:
[0017]初始模块,用于将待蒸馏的原训练数据集蒸馏压缩为蒸馏数据集;
[0018]采样模块,用于对该原训练数据集进行随机采样,得到原样本采样集;
[0019]扰动模块,用于为该原样本采样集添加对抗扰动,得到扰动样本采样集;
[0020]训练模块,用于分别计算该扰动样本采样集与该蒸馏数据集的特征表示或网络梯度,以该特征表示或网络梯度的差异作为损失函数根据该损失函数反向传播对该蒸馏数据集进行优化更新,并且不断迭代优化过程,直至损失函数收敛,保存当前蒸馏数据集作为最终的鲁棒蒸馏数据集。
[0021]所述的鲁棒数据集蒸馏系统,其中还包括:
[0022]图像分类模块,用于以该鲁棒蒸馏数据集代替该原训练数据集,训练目标深度神经网络模型完成图像识别分类。
[0023]所述的鲁棒数据集蒸馏系统,其中该原训练数据集由多张图像构成,该优化更新包括:根据损失函数函数计算出损失值并进行梯度反向传播,得到针对该蒸馏数据集的每幅图像中像素值更新的梯度,以该梯度修改其对应像素的色彩,以实现对该蒸馏数据集的优化更新。
[0024]所述的鲁棒数据集蒸馏系统,其中添加针对当前目标深度神经网络模型参数的对抗扰动。
[0025]本专利技术还提出了一种存储介质,用于存储执行如所述任意一种鲁棒数据集蒸馏方法的程序。
[0026]本专利技术还提出了一种客户端,用于任意一种鲁棒数据集蒸馏系统。
[0027]由以上方案可知,本专利技术的优点在于:
[0028]本专利技术提出的鲁棒数据集蒸馏方法,可以在不改变深度神经网络模型训练方法的情况下,将鲁棒性属性蒸馏到小规模生成数据集中,相比原始未添加对抗扰动模块的普通数据集蒸馏算法,以本专利技术提出的鲁棒数据集蒸馏算法蒸馏得到的数据集作为训练数据集,训练得到的模型能够取得更好的对抗鲁棒性。
[0029]下述表1是本专利技术中,分别在公开数据集CIFAR10与MNIST上针对不同数据集规模(图像个数/类别,Img/Cls)的对抗鲁棒性(Robustness)的仿真实验结果,表2是对应的准确
率(Accuracy)的仿真实验结果。当本申请用于图像分类时,输入为原始数据集,输出为小规模蒸馏数据集,输出的小规模数据集是综合考虑了训练模型的准确率与鲁棒性的小规模数据集,也就是蒸馏得到的小而精的训练数据集。
[0030]实验任务为图像分类任务,数据集CIFAR10是10分类任务,数据集MNIST是2分类任务。对比的基线模型是基于特征匹配的蒸馏方法(Feature Matching,FM)与基于梯度匹配的蒸馏方法(Gradient Matching,GM),这两类未考虑鲁棒性的数据集蒸馏算法统称为Dataset Condensation,简称为DC。本专利技术(Adversarial Robust Dataset Condensation,AR

DC)将鲁棒数据集蒸馏算法分别适配到FM与GM算法上。根据表1可以看出,添加对抗扰动模块后的鲁棒数据集蒸馏方法相比普通数据集蒸馏算法,蒸馏出的数据集训练得到的模型在CIFAR10数据集上具有更好的鲁棒性,在MNIST数据集上,在某些设置下本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种鲁棒数据集蒸馏方法,其特征在于,包括:步骤1、在原训练数据集中进行随机采样,作为初始化的蒸馏数据集;步骤2、对该原训练数据集进行随机采样,得到原样本采样集;步骤3、为该原样本采样集添加对抗扰动,得到扰动样本采样集;步骤4、分别计算该扰动样本采样集与该蒸馏数据集的特征表示或网络梯度,以该特征表示或网络梯度的差异作为损失函数;步骤5、根据该损失函数反向传播对该蒸馏数据集进行优化更新,并重复该步骤2到5不断迭代优化过程,直至损失函数收敛,保存当前蒸馏数据集作为最终的鲁棒蒸馏数据集。2.如权利要求1所述的鲁棒数据集蒸馏方法,其特征在于,还包括:步骤6、以该鲁棒蒸馏数据集代替该原训练数据集,训练目标深度神经网络模型完成图像识别分类。3.如权利要求1所述的鲁棒数据集蒸馏方法,其特征在于,该原训练数据集由多张图像构成,该优化更新包括:根据损失函数函数计算出损失值并进行梯度反向传播,得到针对该蒸馏数据集的每幅图像中像素值更新的梯度,依据该梯度修改其对应像素的色彩,以实现对该蒸馏数据集的优化更新。4.如权利要求2所述的鲁棒数据集蒸馏方法,其特征在于,添加的该对抗扰动具体包括:为原样本采样集中样本添加针对当前该目标深度神经网络模型参数的对抗扰动。5.一种鲁棒数据集蒸馏系统,其特征在于,包括:初始模块,用于将待蒸馏的原训练数据...

【专利技术属性】
技术研发人员:程学旗郭嘉丰陈薇李家宁张明坤
申请(专利权)人:中国科学院计算技术研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1