一种用于轻量级网络的知识蒸馏热启动训练方法和系统技术方案

技术编号:39009321 阅读:14 留言:0更新日期:2023-10-07 10:40
本发明专利技术提供一种用于轻量级网络的知识蒸馏热启动训练方法和系统,该方法包括获取至少一个随机初始化的轻量级网络模型;将低分辨率图像和对应高分辨率图像作为一对训练样本,基于多个训练样本获取训练样本集合;将训练样本集合中的低分辨率图像输入轻量级网络模型,训练轻量级网络模型,获得第一预训练网络模型;将第一预训练网络模型的权重加载于轻量级网络模型,并保持训练的各项参数不变,基于训练样本集合,对轻量级网络模型进行一轮或多轮训练,获得第二预训练网络模型;将第二预训练网络模型的权重加载于轻量级网络模型,基于教师网络的监督,训练轻量级网络模型,获得第三预训练网络模型。训练网络模型。训练网络模型。

【技术实现步骤摘要】
一种用于轻量级网络的知识蒸馏热启动训练方法和系统


[0001]本说明书涉及人工智能
,特别涉及一种用于轻量级网络的知识蒸馏热启动训练方法和系统。

技术介绍

[0002]图像超分辨率技术是从相应的低分辨率图像生成高分辨率图像的技术,而轻量化图像超分辨率模型具有较高的计算效率和较少的参数量,使得其在图像超分辨率技术上得到广泛应用,但同时存在图像超分辨率效果不足等问题。研究者们通常使用网络剪枝、网络量化和知识蒸馏等方法基于教师网络的监督,提高轻量化图像超分辨率模型的性能。但由于教师网络与学生网络之间的表示空间等差异,会阻碍学生网络的训练,使得训练效果不佳。
[0003]因此,为了解决以上问题,希望提出一种用于轻量级网络的知识蒸馏热启动训练方法和系统。

技术实现思路

[0004]本说明书一个或多个实施例提供一种用于轻量级网络的知识蒸馏热启动训练方法。所述用于轻量级网络的知识蒸馏热启动训练方法包括:获取至少一个随机初始化的轻量级网络模型;将低分辨率图像和对应高分辨率图像作为一对训练样本,基于多个训练样本获取训练样本集合;将训练样本集合中的低分辨率图像输入所述轻量级网络模型,训练所述轻量级网络模型,获得第一预训练网络模型,基于L1损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的差距小于预设阈值;将所述第一预训练网络模型的权重加载于所述轻量级网络模型,并保持训练的各项参数不变,基于所述训练样本集合,对所述轻量级网络模型进行一轮或多轮训练,获得第二预训练网络模型,基于L1损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的差距小于预设阈值;将所述第二预训练网络模型的权重加载于所述轻量级网络模型,基于教师网络的监督,训练所述轻量级网络模型,获得第三预训练网络模型,基于混合损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据及教师网络输出的高分辨率图像数据的差距小于预设阈值。
[0005]本说明书一个或多个实施例提供一种用于轻量级网络的知识蒸馏热启动训练系统,所述系统包括:第一获取模块,用于获取至少一个随机初始化的轻量级网络模型;第二获取模块,用于将低分辨率图像和对应高分辨率图像作为一对训练样本,基于多个训练样本获取训练样本集合;第一训练模块,用于将训练样本集合中的低分辨率图像输入所述轻量级网络模型,训练所述轻量级网络模型,获得第一预训练网络模型;基于L1损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的差距小于预设阈值;第二训练模块,用于将所述第一预训练网络模型的权重加载于所述轻量级网络模型,并保持训练的各项参数不变,基于所述训练样本集合,对所述轻量级网络模型进行一轮或多
轮训练,获得第二预训练网络模型,基于L1损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的差距小于预设阈值;第三训练模块,用于将所述第二预训练网络模型的权重加载于所述轻量级网络模型,基于教师网络的监督,训练所述轻量级网络模型,获得第三预训练网络模型,基于混合损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据及教师网络模型输出的高分辨率图像数据的差距小于预设阈值。
[0006]本说明书一个或多个实施例提供一种计算机可读存储介质,所述存储介质存储计算机指令,当计算机读取存储介质中的计算机指令后,计算机执行上述的任意一项的方法。
[0007]在本说明书的一些实施例中,处理器通过对初始轻量级网络模型进行热启动训练及知识蒸馏训练,得到更高性能的轻量级网络模型。通过这种方式,可以让轻量级网络模型保持较低的计算量和数据量,得到更高的性能。通过热启动训练可以降低模型的异构性,让学生模型可以更好的获得教师模型的监督,提高训练效果与稳定性,并且可以让训练方法有更广的应用范围。
附图说明
[0008]本说明书将以示例性实施例的方式进一步说明,这些示例性实施例将通过附图进行详细描述。这些实施例并非限制性的,在这些实施例中,相同的编号表示相同的结构,其中:
[0009]图1是根据本说明书一些实施例所示的一种用于轻量级网络的知识蒸馏热启动训练系统的模块示意图;
[0010]图2是根据本说明书一些实施例所示的一种用于轻量级网络的知识蒸馏热启动训练方法的示例性流程图;
[0011]图3是根据本说明书一些实施例所示的第一预训练网络模型的示例性示意图;
[0012]图4是根据本说明书一些实施例所示的第二预训练网络模型的示例性示意图;
[0013]图5是根据本说明书一些实施例所示的第三预训练网络模型的示例性示意图。
具体实施方式
[0014]为了更清楚地说明本说明书实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单的介绍。显而易见地,下面描述中的附图仅仅是本说明书的一些示例或实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图将本说明书应用于其它类似情景。除非从语言环境中显而易见或另做说明,图中相同标号代表相同结构或操作。
[0015]应当理解,本文使用的“系统”、“装置”、“单元”和/或“模块”是用于区分不同级别的不同组件、元件、部件、部分或装配的一种方法。然而,如果其他词语可实现相同的目的,则可通过其他表达来替换所述词语。
[0016]如本说明书和权利要求书中所示,除非上下文明确提示例外情形,“一”、“一个”、“一种”和/或“该”等词并非特指单数,也可包括复数。一般说来,术语“包括”与“包含”仅提示包括已明确标识的步骤和元素,而这些步骤和元素不构成一个排它性的罗列,方法或者设备也可能包含其它的步骤或元素。
[0017]本说明书中使用了流程图用来说明根据本说明书的实施例的系统所执行的操作。应当理解的是,前面或后面操作不一定按照顺序来精确地执行。相反,可以按照倒序或同时处理各个步骤。同时,也可以将其他操作添加到这些过程中,或从这些过程移除某一步或数步操作。
[0018]图1是根据本说明书一些实施例所示的一种用于轻量级网络的知识蒸馏热启动训练系统的模块示意图。
[0019]在一些实施例中,所述一种用于轻量级网络的知识蒸馏热启动训练系统100可以包括第一获取模块110、第二获取模块120、第一训练模块130、第二训练模块140和第三训练模块150。
[0020]第一获取模块110可以用于获取至少一个随机初始化的轻量级网络模型。关于获取轻量级网络模型的更多细节可以参见图2及其相关描述。
[0021]第二获取模块120可以用于将低分辨率图像和对应高分辨率图像作为一对训练样本,基于多个训练样本获取训练样本集合。关于获取训练样本集合的更多细节可以参见图2及其相关描述。
[0022]第一训练模块130可以用于将训练样本集合中的低分辨率图像输入轻量级网络模型,训练轻量级网络模型,获得第一预训练网络模型;基于L1损失函数调整轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种用于轻量级网络的知识蒸馏热启动训练方法,其特征在于,所述方法包括:获取至少一个随机初始化的轻量级网络模型;将低分辨率图像和对应高分辨率图像作为一对训练样本,基于多个训练样本获取训练样本集合;将训练样本集合中的低分辨率图像输入所述轻量级网络模型,训练所述轻量级网络模型,基于L1损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的差距小于预设阈值,获得第一预训练网络模型;将所述第一预训练网络模型的权重加载于所述轻量级网络模型,并保持训练的各项参数不变,基于所述训练样本集合,对所述轻量级网络模型进行一轮或多轮训练,基于L1损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据的差距小于预设阈值,获得第二预训练网络模型;将所述第二预训练网络模型的权重加载于所述轻量级网络模型,基于教师网络的监督,训练所述轻量级网络模型,基于混合损失函数调整所述轻量级网络模型的参数,直到输出数据与原始高分辨率图像数据及教师网络模型输出的高分辨率图像数据的差距小于预设阈值,获得第三预训练网络模型。2.根据权利要求1所述的方法,其特征在于,所述基于L1损失函数调整所述轻量级网络模型的参数包括:基于L1损失函数计算第二预训练网络模型输出的高分辨率图像与原始高分辨率图像的差距;基于所述差距调整第二预训练网络模型的各项参数,再进行训练;直到所述差距低于所述预设阈值,停止训练,得到训练好的第二预训练网络模型。3.根据权利要求1所述的方法,其特征在于,所述L1损失函数为:其中,p是一个像素点,P指图像块;N是块中像素p的数量,x(p)和y(p)分别表示模型输出的图像数据和原始高分辨率图像数据。4.根据权利要求1所述的方法,其特征在于,所述混合损失函数包括训练网络的输出数据与教师网络模型输出数据之间的损失函数以及训练网络模型的输出数据与原始高分辨率图像数据之间的损失函数。5.根据权利要求4所述的方法,其特征在于,所述训练网络的输出数据与教师网络模型输出数...

【专利技术属性】
技术研发人员:邵杰雷敏武鑫梁爽陈飞宇许辉赵磊
申请(专利权)人:四川省人工智能研究院宜宾
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1