一种多样性增强的无数据知识蒸馏方法技术

技术编号:37334175 阅读:25 留言:0更新日期:2023-04-21 23:12
本发明专利技术提供了一种多样性增强的无数据知识蒸馏方法,包括以下步骤:步骤1:初始化大型教师模型,使用原始数据集训练教师模型;步骤2:构建深度卷积条件生成网络,与噪声向量一起输入到条件生成网络;步骤3:基于度量学习构建一个多样性正则化项,同时根据教师模型反馈知识、内部知识和多样性正则化项的损失函数联合训练条件生成网络训练,直至条件生成网络收敛,使用训练好的条件生成网络合成数据集;步骤4:利用在步骤3中获得的合成数据集通过知识蒸馏训练学生模型,直至学生模型被训练至收敛。应用本技术方案不仅能够在缺乏原始数据的情况下,生成逼近于原始数据的合成数据,而且还能保证合成数据的多样性,从而提高学生模型训练的有效性。训练的有效性。训练的有效性。

【技术实现步骤摘要】
一种多样性增强的无数据知识蒸馏方法


[0001]本专利技术涉及知识蒸馏
,特别是一种多样性增强的无数据知识蒸馏方法。

技术介绍

[0002]近年来,深度学习凭借强大的自动特征提取能力,广泛应用在智能家居、可穿戴设备和自动驾驶等领域。为了增强模型的学习能力,深度学习的模型架构被设计的更加复杂,算法需要较高计算成本和内存消耗,模型难以部署和应用在资源受限的终端设备中。知识蒸馏是获得高效小规模模型的一种新兴技术,其主要思想是将高性能大规模网络作为教师模型来指导小规模学生模型进行训练。
[0003]传统的知识蒸馏技术都假设教师模型的原始数据集在学生模型训练阶段仍然可用,这在实际应用中是非常困难的。一方面,在数据传输过程中,用户隐私、数据安全及数据所有权面临严峻挑战,一个典型例子是人脸、语音和指纹等生物特征数据的重利用会侵害用户的隐私信息。另一方面,呈指数增长的数据量难以存储、传输和管理,如包含数百万张图像的原始数据集很难传输和重用。
[0004]为了解决这一问题,人们提出一种无数据知识蒸馏方法,利用合成数据集来代替原始数据集进行知识蒸馏过程。但现有研究主要集中在如何使用教师网络的先验知识来合成数据,忽略合成数据多样性匮乏的问题。多样性匮乏的数据集会影响学生模型学习数据集的本质特征,学生模型无法得到有效的训练,模型鲁棒性不佳。

技术实现思路

[0005]有鉴于此,本专利技术的目的在于提供一种多样性增强的无数据知识蒸馏方法,不仅能够在缺乏原始数据的情况下,生成逼近于原始数据的合成数据,而且还能保证合成数据的多样性,从而提高学生模型训练的有效性。
[0006]为实现上述目的,本专利技术采用如下技术方案:一种多样性增强的无数据知识蒸馏方法,包括以下步骤:
[0007]步骤1:初始化大型教师模型,使用原始数据集训练教师模型,并存储教师模型对应的内部知识;
[0008]步骤2:构建深度卷积条件生成网络,基于条件生成的思想将预设的标签信息作为条件数据,与噪声向量一起输入到条件生成网络;
[0009]步骤3:基于度量学习构建一个多样性正则化项,同时根据教师模型反馈知识、内部知识和多样性正则化项的损失函数联合训练条件生成网络训练,直至条件生成网络收敛,使用训练好的条件生成网络合成数据集;
[0010]步骤4:利用在步骤3中获得的合成数据集通过知识蒸馏训练学生模型,直至学生模型被训练至收敛。
[0011]在一较佳的实施例中,所述步骤1具体包括:
[0012]步骤11:初始化参数为θ
T
的大型教师模型T;
[0013]步骤12:设定最大迭代次数为N
T
、停止迭代阈值为k
T
,用原始数据集(X
ori
,Y
ori
)训练步骤11中搭建好的教师模型,直到当前迭代的函数值小于阈值k
T
或达到最大迭代次数N
T
,保存训练好的教师模型T;
[0014]步骤13:存储教师模型的内部知识,即原始数据在教师模型第l层的运行均值μ
l
和方差
[0015]在一较佳的实施例中,所述步骤2具体包括:
[0016]步骤21:构建参数为θ
G
的条件生成网络G,由卷积层,BatchNorm层,Tanh层和LeakyReLU激活函数组成;
[0017]步骤22:假设噪声向量z符合高斯分布,给定噪声向量z和其类别信息y作为条件生成网络G的输入,得到合成数据x
G
=G(z|y,θ
G
)。
[0018]在一较佳的实施例中,所述步骤3具体包括:
[0019]步骤31:所述步骤3中同时根据教师模型反馈知识、内部知识和多样性正则化项的损失函数联合训练条件生成网络训练,总体损失函数为:
[0020]L
G
(x
G
,y)=L
TG
(x
G
,y)+L
S
(x
G
)+L
DIV
[0021]其中,L
TG
是教师模型反馈知识的损失函数,L
S
是教师模型内部知识的损失函数,L
DIV
是多样性正则化项的损失函数;
[0022]步骤32:教师模型反馈知识的损失函数具体包括:
[0023]L
TG
(x
G
,y)=L
OH
(x
G
,y)+R
L2
(x
G
)+R
TV
(x
G
)
[0024]其中,L
OH
是教师模型T对合成数据x
G
的预测输出和预设类别信息y之间的交叉熵损失,其公式表示为:
[0025][0026]R
L2
是合成图像的L2范数,其鼓励合成图像的范围保持在目标间隔内,公式表示为:R
L2
(x
G
)=||x
G
||2,R
TV
是全变分正则化项,用于噪声消除和保留图像边界信息,其计算公式表示为:
[0027][0028]其中,表示图像x
G
在坐标(i,j)上的像素值;
[0029]步骤33:教师模型内部知识的损失函数具体包括:
[0030][0031]L
S
用于最小化所有层合成数据x
G
和原始数据的统计数据之间的距离;其中,μ
l
和是在步骤13中存储的教师模型内部知识;μ
l
(x
G
)和表示合成数据在教师模型第l层的运行均值和方差。N(
·
,
·
)表示正态分布,D
KL
(
·
||
·
)表示KL散度;
[0032]步骤34:基于度量学习尽可能放大同类样本之间的距离,主要使通过计算同类别合成数据相对于噪声向量之间的比率来实现,多样性正则化项的损失函数具体包括:
[0033][0034]其中,z和z1是标签信息都为y的不同噪声向量,x
G
和x
G1
是条件生成网络G合成的数据。
[0035]步骤35:设定最大迭代次数为N
G
、停止迭代阈值为k
G
,根据损失函数L
G
训练更新条件生成网络G的参数,直到当前迭代的函数值小于阈值k或达到最大迭代次数N
G
,输出合成数据集(X
G
,Y)。
[0036]在一较佳的实施例中,步骤4具体包括:
[0037]步骤41:初始化参数为θ
S
的轻量级学生模型S;
[0038]步骤42:用合成数据集(X本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种多样性增强的无数据知识蒸馏方法,其特征在于,包括以下步骤:步骤1:初始化大型教师模型,使用原始数据集训练教师模型,并存储教师模型对应的内部知识;步骤2:构建深度卷积条件生成网络,基于条件生成的思想将预设的标签信息作为条件数据,与噪声向量一起输入到条件生成网络;步骤3:基于度量学习构建一个多样性正则化项,同时根据教师模型反馈知识、内部知识和多样性正则化项的损失函数联合训练条件生成网络训练,直至条件生成网络收敛,使用训练好的条件生成网络合成数据集;步骤4:利用在步骤3中获得的合成数据集通过知识蒸馏训练学生模型,直至学生模型被训练至收敛。2.根据权利要求1所述的一种多样性增强的无数据知识蒸馏方法,其特征在于,所述步骤1具体包括:步骤11:初始化参数为θ
T
的大型教师模型T;步骤12:设定最大迭代次数为N
T
、停止迭代阈值为k
T
,用原始数据集(X
ori
,Y
ori
)训练步骤11中搭建好的教师模型,直到当前迭代的函数值小于阈值k
T
或达到最大迭代次数N
T
,保存训练好的教师模型T;步骤13:存储教师模型的内部知识,即原始数据在教师模型第l层的运行均值μ
l
和方差3.根据权利要求1所述的一种多样性增强的无数据知识蒸馏方法,其特征在于,所述步骤2具体包括:步骤21:构建参数为θ
G
的条件生成网络G,由卷积层,BatchNorm层,Tanh层和LeakyReLU激活函数组成;步骤22:假设噪声向量z符合高斯分布,给定噪声向量z和其类别信息y作为条件生成网络G的输入,得到合成数据x
G
=G(z|y,θ
G
)。4.根据权利要求2所述的一种多样性增强的无数据知识蒸馏方法,其特征在于,所述步骤3具体包括:步骤31:所述步骤3中同时根据教师模型反馈知识、内部知识和多样性正则化项的损失函数联合训练条件生成网络训练,总体损失函数为:L
G
(x
G
,y)=L
TG
(x
G
,y)+L
S
(x
G
)+L
DIV
其中,L
TG
是教师模型反馈知识的损失函数,L
S
是教师模型内部知识的损失函数,L
DIV
是多样性正则化项的损失函数;步骤32:教师模型反馈知识的损失函数具体包括:L
TG
(x
G
,y)=L
OH
(x
G
,y)+R
L2
(x
G
)+R
TV
(x
G
)其中,L
OH
是教师模型T对合成数据x
G
的预测输出和预设类别信息y之间的交叉熵损失,其公式表示为:R
L2
是合成图...

【专利技术属性】
技术研发人员:叶阿勇刘燕妮
申请(专利权)人:福建师范大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1