知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质制造方法及图纸

技术编号:38730128 阅读:13 留言:0更新日期:2023-09-08 23:20
本发明专利技术涉及人工智能、机器学习、智慧医疗、金融科技技术领域,公开了一种知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质。本方法通过梯度下降算法迭代学生模型中的全体参数使得目标损失函数的值最小,确定最小损失值;通过梯度上升算法迭代温度参数使得目标损失函数的值最大,确定最大损失值;交替执行上述迭代过程,使得所述最大损失值与所述最小损失值之间差值的绝对值小于预设阈值。确定出当前蒸馏温度。由此提高了学生模型学习的效率,同时符合由易到难的学习过程,缓解了固定温度超参数带来的模型性能瓶颈问题。参数带来的模型性能瓶颈问题。参数带来的模型性能瓶颈问题。

【技术实现步骤摘要】
知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质


[0001]本专利技术涉及人工智能、深度学习、智慧医疗、金融科技
,尤其涉及一种知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质。

技术介绍

[0002]随着人工智能和深度学习技术的发展,深度学习通过从数据中自行学习出有效的特征表示,代替以往及其学习中繁琐的人工特征工程过程,使得机器智能化程度进一步加深。但是深度学习想要发挥出理想的效果,需要大规模的数据,当数据量偏小时,学习效果较差。
[0003]深度学习想要从数据中学习出更有效的特征表示,一般会通过加深模型层数的方法,但这导致了深度学习模型的体积过大,无法部署在资源受限的设备上,往往只是理论上能达到最优,真正落地使用较为困难。
[0004]知识蒸馏是一种经典的模型压缩方法,核心思想是通过引导轻量化的学生模型“模仿”性能更好、结构更复杂的教师模型,在不改变学生模型结构的情况下提高其性能。
[0005]现有的知识蒸馏方法中,一般采用温度作为超参数的相对熵进行损失函数的设计,从而让学生模型通过学习分布间的差异性从而蒸馏到教师模型的知识。但是现有的蒸馏框架普遍会通过在验证集上的最优性能选择一个固定的温度系数。由此,产生以下问题:
[0006]其一、不同的教师模型和学生模型在蒸馏过程中,温度这个超参数的最优值不是唯一的。在知识蒸馏的不同阶段温度的最优值是不一样的。如果要找到这个最佳的超参数,需要进行暴力搜索,而暴力搜索会导致计算量显著增大,使模型运行效率降低。
[0007]其二、一直保持静态固定的超参数对学生模型来说不是最优的,会导致知识蒸馏过程的效率降低。

技术实现思路

[0008]本专利技术提供一种知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质,以解决知识蒸馏过程中由于蒸馏温度选取不当造成的蒸馏效率降低的技术问题。
[0009]第一方面,本专利技术提供了一种知识蒸馏的对抗式蒸馏温度调整方法,包括:
[0010]获取教师模型、学生模型和温度参数;
[0011]根据所述教师模型、所述学生模型和所述温度参数构建知识蒸馏的目标损失函数;
[0012]根据梯度下降算法迭代更新所述学生模型的全体参数,以此更新所述学生模型,每一次迭代更新所述学生模型后,计算一次所述目标损失函数的值,直到所述目标损失函数的值最小,确定最小损失值;
[0013]保持使得所述目标损失函数最小的所述全体参数不变,根据梯度上升算法迭代更新所述温度参数,每一次迭代更新所述温度参数后,计算一次所述目标损失函数的值,直到所述目标损失函数的值最大,确定最大损失值;
[0014]确定所述最大损失值后,跳转到根据梯度下降算法迭代更新所述学生模型的全体参数的步骤重复执行,直到所述最大损失值与所述最小损失值之间差值的绝对值小于预设阈值,将所述绝对值小于预设阈值时对应的所述温度参数作为当前蒸馏温度。
[0015]第二方面,本专利技术提供了一种知识蒸馏的对抗式蒸馏温度调整装置,包括:
[0016]获取模块,用于获取教师模型、学生模型和温度参数;
[0017]构建模块,用于根据所述教师模型、所述学生模型和所述温度参数构建知识蒸馏的目标损失函数;
[0018]第一迭代模块,用于根据梯度下降算法迭代更新所述学生模型的全体参数,以此更新所述学生模型,每一次迭代更新所述学生模型后,计算一次所述目标损失函数的值,直到所述目标损失函数的值最小,确定最小损失值;
[0019]第二迭代模块,用于保持使得所述目标损失函数最小的所述全体参数不变,根据梯度上升算法迭代更新所述温度参数,每一次迭代更新所述温度参数后,计算一次所述目标损失函数的值,直到所述目标损失函数的值最大,确定最大损失值;
[0020]执行控制模块,用于确定所述最大损失值后,跳转到根据梯度下降算法迭代更新所述学生模型的全体参数的步骤重复执行,直到所述最大损失值与所述最小损失值之间差值的绝对值小于预设阈值,将所述绝对值小于预设阈值时对应的所述温度参数作为当前蒸馏温度。
[0021]第三方面,本专利技术提供了一种计算机设备,包括存储器、处理器以及存储在存储器中并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述知识蒸馏的对抗式蒸馏温度调整方法的步骤。
[0022]第四方面,本专利技术提供了一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,计算机程序被处理器执行时实现上述知识蒸馏的对抗式蒸馏温度调整方法的步骤。
[0023]上述知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质所实现的方案中,通过根据所述教师模型、所述学生模型和所述温度参数构建知识蒸馏的目标损失函数,通过梯度下降算法迭代更新所述学生模型的全体参数,直到所述目标损失函数的值最小,确定最小损失值。由此通过梯度下降算法迭代学生模型的所述全体参数,使得所述目标损失函数的值最小,得到了基于当前温度参数使得目标损失函数的值最小的所述全体参数。通过保持使得所述目标损失函数最小的所述全体参数不变,根据梯度上升算法迭代更新所述温度参数,使得目标损失函数的值最大,确定最大损失值。由此得到了基于当前全体参数的使得目标损失函数最大的所述温度参数。通过使所述最小损失值和所述最大损失值的差值的绝对值小于预设阈值,说明迭代满足预设要求,可以筛选出合适的温度参数。同时确定所述最大损失值后,跳转到根据梯度下降算法迭代更新所述学生模型的全体参数的步骤重复执行,代表当前学习难度级别的所述温度参数也在不断上升,也符合了由易到难的学习过程,提高了学生模型学习的效率。基于此,本专利技术提供的知识蒸馏的对抗式蒸馏温度调整方法、装置、设备及介质所实现的方案,通过梯度下降算法迭代学生模型中的全体参数使得目标损失函数的值最小,通过梯度上升算法迭代温度参数使得目标损失函数的值最大,交替执行上述迭代过程,使得所述最大损失值与所述最小损失值之间差值的绝对值小于预设阈值。确定出当前蒸馏温度。由此提高了学生模型学习的效率,同时符合由易到难的学习过
程。通过可变换的温度参数,缓解了固定温度超参数带来的模型性能瓶颈问题。将所述温度参数从超参数转换为参数,让模型自主学习,降低了暴力搜索最佳参数所需的算力成本和时间成本,使得模型可以自动学习最佳的所述温度参数,提高了学生模型的学习效率。
附图说明
[0024]为了更清楚地说明本专利技术实施例的技术方案,下面将对本专利技术实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本专利技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0025]图1是本专利技术一实施例中知识蒸馏的对抗式蒸馏温度调整方法的一应用环境示意图;
[0026]图2是本专利技术一实施例中知识蒸馏的对抗式蒸馏温度调整方法的一流程示意图;
[0027]图3是本专利技术一实施例中知识蒸馏的对抗式蒸馏温度调整装置的一结构示意图;
[0028]图4是本专利技术一实施例中计算机设本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种知识蒸馏的对抗式蒸馏温度调整方法,其特征在于,包括:获取教师模型、学生模型和温度参数;根据所述教师模型、所述学生模型和所述温度参数构建知识蒸馏的目标损失函数;根据梯度下降算法迭代更新所述学生模型的全体参数,以此更新所述学生模型,每一次迭代更新所述学生模型后,计算一次所述目标损失函数的值,直到所述目标损失函数的值最小,确定最小损失值;保持使得所述目标损失函数最小的所述全体参数不变,根据梯度上升算法迭代更新所述温度参数,每一次迭代更新所述温度参数后,计算一次所述目标损失函数的值,直到所述目标损失函数的值最大,确定最大损失值;确定所述最大损失值后,跳转到根据梯度下降算法迭代更新所述学生模型的全体参数的步骤重复执行,直到所述最大损失值与所述最小损失值之间差值的绝对值小于预设阈值,将所述绝对值小于预设阈值时对应的所述温度参数作为当前蒸馏温度。2.根据权利要求1所述的知识蒸馏的对抗式蒸馏温度调整方法,其特征在于,所述根据所述教师模型、所述学生模型和所述温度参数构建知识蒸馏的目标损失函数,包括:获取教师模型的第一损失函数;根据所述教师模型、所述学生模型和所述温度参数确定知识蒸馏的第二损失函数;根据所述第一损失函数和所述第二损失函数构建所述目标损失函数。3.根据权利要求2所述的知识蒸馏的对抗式蒸馏温度调整方法,其特征在于,所述根据所述教师模型、所述学生模型和所述温度参数确定知识蒸馏的第二损失函数,包括:获取模型输入样本集合;将所述模型输入样本集合中的任一样本分别输入所述教师模型和所述学生模型,分别得到第一教师模型输出向量和第一学生模型输出向量;根据所述第一教师模型输出向量确定教师模型概率向量,根据所述第一学生模型输出向量确定学生模型概率向量;计算所述教师模型概率向量和所述学生模型概率向量的第一相对熵;根据所述温度参数对所述第一相对熵进行放大或缩小,得到第二相对熵;根据所有所述模型输入样本集合中所有样本的所述第二相对熵,构建所述第二损失函数。4.根据权利要求3所述的知识蒸馏的对抗式蒸馏温度调整方法,其特征在于,所述根据所述第一教师模型输出向量确定教师模型概率向量,根据所述第一学生模型输出向量确定学生模型概率向量,包括:将所述第一教师模型输出向量除以所述温度参数,得到第二教师模型输出向量;将所述第一学生模型输出向量除以所述温度参数,得到第二学生模型输出向量;通过归一化指数函数将所述第二教师模型输出向量转换为教师模型概率向量;通过归一化指数函数将所述第二学生模型输出向量转换为学生模型概率向量。5.根据权利要求2所述的知识蒸馏的对抗式蒸馏温...

【专利技术属性】
技术研发人员:谯轶轩
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1