网络训练方法、装置、电子设备及存储介质制造方法及图纸

技术编号:37455719 阅读:23 留言:0更新日期:2023-05-06 09:27
本公开提供了一种网络训练方法、装置、电子设备及存储介质,该方法包括:获取训练样本图像及与训练样本图像对应的掩码样本图像;将掩码样本图像输入至第一网络,得到与掩码样本图像对应的第一多层级特征图;将样本图像输入至第二网络,得到与样本图像对应的第二多层级特征图;第二网络的规模大于第一网络的规模;基于解码器以及与第一多层级特征图中每个特征图分别对应的掩码图像,对第一多层级特征图进行特征恢复处理,得到恢复处理后的第一多层级特征图;每个特征图对应的掩码图像由目标掩码图像分别进行缩放处理后得到;基于恢复处理后的第一多层级特征图及第二多层级特征图对第一网络进行训练,以提升第一网络的性能。以提升第一网络的性能。以提升第一网络的性能。

【技术实现步骤摘要】
网络训练方法、装置、电子设备及存储介质


[0001]本公开涉及人工智能
,具体而言,涉及一种网络训练方法、装置、电子设备及存储介质。

技术介绍

[0002]知识蒸馏也称为师生学习,是一种有效的模型压缩和模型精度提升的技术,通过知识蒸馏可以将知识从容量更高的教师模型转移到可部署性更强、容量较小的学生模型,进而来提升学生模型的性能。
[0003]经研究发现,针对密集视觉检测任务,由于密集视觉检测任务对于图像的定位信息更加敏感,目前的知识蒸馏方法主要基于对教师特征图的模仿。然而,该基于特征图的知识蒸馏,通常将完整的图像输入学生网络,然后进行逐像素一对一的空间模仿,该模仿过程相对简单,导致学生模型的学习能力得不到较好的挖掘,学生模型的性能较差。

技术实现思路

[0004]本公开实施例至少提供一种网络训练方法、装置、电子设备及存储介质,能够提升学生网络的性能。
[0005]公开实施例提供了一种网络训练方法,包括:
[0006]获取训练样本图像以及与所述训练样本图像对应的掩码样本图像;其中,所述掩码样本图像由所述训练样本图像以及目标掩码图像生成;
[0007]将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图;
[0008]将所述样本图像输入至第二网络,并基于所述第二网络的第二特征提取网络对所述样本图像进行特征提取,得到与所述样本图像对应的第二多层级特征图;所述第二网络的规模大于所述第一网络的规模;
[0009]基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,得到恢复处理后的第一多层级特征图;其中,所述第一多层级特征图中每个特征图对应的掩码图像由所述目标掩码图像分别进行缩放处理后得到;
[0010]基于所述恢复处理后的第一多层级特征图以及所述第二多层级特征图对所述第一网络进行训练,并重复上述步骤,直到所述第一网络的训练结果符合预设要求,得到训练好的第一网络。
[0011]本公开实施例中,第一网络也称学生网络,第二网络也称教师网络,在基于特征进行知识蒸馏的过程中,通过对训练样本图像进行掩码处理,并通过模仿第二网络的输出的第二层级特征图恢复被掩码区域对应的特征,进而可以增加特征模仿的难度,也即,在不改变第一网络的网络结构的前提下,通过单独的解码器对蒸馏过程进行增强,进而可以提升第一网络的学习能力,如此,即使在输入的待检测图像存在部分被遮盖情况下,通过训练好
的第一网络也能够进行检测识别,从而提升了第一网络的检测性能。
[0012]在一种可能的实施方式中,所述第一特征提取网络包括呈金字塔结构的特征提取模块,所述将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图,包括:
[0013]将所述掩码样本图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,并将所述中间多层级特征图作为所述第一多层级特征图;其中,所述中间多层级特征图包括多张尺寸不同的中间特征图。
[0014]本公开实施例中,由于所述特征提取模块呈金字塔结构设计,导致所述提取出来的第一多层级特征图包括多张尺寸不同的中间特征图,进而可以适用于各类密集视觉检测任务,如目标检测、实例分割和语义分割等。
[0015]在一种可能的实施方式中,所述第一特征提取网络包括呈金塔结构的特征提取模块以及多个掩码卷积模块;所述将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图,包括:
[0016]将所述掩码样本图像以及所述目标掩码图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,其中,所述中间多层级特征图包括多张尺寸不同的中间特征图;
[0017]基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,并将掩码处理后的中间多层级特征图,作为所述第一多层级特征图。
[0018]本实施方式中,除了能够提升第一网络的适用性之外,通过对每个中间特征图进行掩码处理后,可以避免掩盖区域和可见区域之间的混淆的特征交互,也即,由于第一网络中的用于特征提取的骨干网络中使用了掩码卷积,进而可以防止卷积过程中被掩盖的图像块受到其他可见图像块的影响,有助于进一步提升第一网络的性能表现。
[0019]在一种可能的实施方式中,所述基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,包括:
[0020]根据所述中间多层级特征图中的每个中间特征图的尺寸大小,分别对所述目标掩码图像进行缩放处理,得到与所述中间多层级特征图中每个中间特征图分别对应的掩码图像;
[0021]针对每个中间特征图,基于所述掩码卷积模块对所述中间特征图以及与所述中间特征图对应的掩码图像进行点乘处理,得到掩码处理的中间特征图,并基于每个掩码处理的中间特征图得到所述掩码处理后的中间多层级特征图。
[0022]本实施方式中,在所述掩码卷积模块中对所述中间特征图以及与所述中间特征图对应的掩码图像进行点乘处理,即可得到掩码处理的中间特征图,进而可以提升对中间特征图进行掩码处理的效率。
[0023]在一种可能的实施方式中,所述解码器包括空间对齐模块、解码模块以及空间恢复模块,所述基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,包括:
[0024]基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,以使得所述第一多层级特征图中的各个特征图的尺寸大小对齐,得到空间对齐的多层级特征图;
[0025]基于所述空间对齐的多层级特征图分别对应的掩码图像,采用掩码标记替换所述空间对齐的多层级特征图中的掩码区域,得到带有掩码标记的多层级特征图,并基于所述解码模块对所述带有掩码标记的多层级特征图进行特征预测处理,得到特征预测处理的多层级特征图;
[0026]基于所述空间恢复模块将相同的空间分辨率的所述特征预测处理的多层级特征图恢复成原始尺寸的多层级特征图,得到所述恢复处理后的第一多层级特征图。
[0027]本实施方式中,通过对第一多层级特征图进行空间对齐后进行特征恢复操作,然后再将空间对齐后的特征图恢复至原始尺寸,不仅能够实现特征的恢复,还可以保证第一多层级特征图的尺寸,有助于确定后续的特征恢复损失。
[0028]在一种可能的实施方式中,所述基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,包括:
[0029]针对所述第一多层级特征图中本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种网络训练方法,其特征在于,包括:获取训练样本图像以及与所述训练样本图像对应的掩码样本图像;其中,所述掩码样本图像由所述训练样本图像以及目标掩码图像生成;将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图;将所述样本图像输入至第二网络,并基于所述第二网络的第二特征提取网络对所述样本图像进行特征提取,得到与所述样本图像对应的第二多层级特征图;所述第二网络的规模大于所述第一网络的规模;基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,得到恢复处理后的第一多层级特征图;其中,所述第一多层级特征图中每个特征图对应的掩码图像由所述目标掩码图像分别进行缩放处理后得到;基于所述恢复处理后的第一多层级特征图以及所述第二多层级特征图对所述第一网络进行训练,并重复上述步骤,直到所述第一网络的训练结果符合预设要求,得到训练好的第一网络。2.根据权利要求1所述的方法,其特征在于,所述第一特征提取网络包括呈金字塔结构的特征提取模块,所述将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图,包括:将所述掩码样本图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,并将所述中间多层级特征图作为所述第一多层级特征图;其中,所述中间多层级特征图包括多张尺寸不同的中间特征图。3.根据权利要求1所述的方法,其特征在于,所述第一特征提取网络包括呈金塔结构的特征提取模块以及多个掩码卷积模块;所述将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图,包括:将所述掩码样本图像以及所述目标掩码图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,其中,所述中间多层级特征图包括多张尺寸不同的中间特征图;基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,并将掩码处理后的中间多层级特征图,作为所述第一多层级特征图。4.根据权利要求3所述的方法,其特征在于,所述基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,包括:根据所述中间多层级特征图中的每个中间特征图的尺寸大小,分别对所述目标掩码图像进行缩放处理,得到与所述中间多层级特征图中每个中间特征图分别对应的掩码图像;针对每个中间特征图,基于所述掩码卷积模块对所述中间特征图以及与所述中间特征图对应的掩码图像进行点乘处理,得到掩码处理的中间特征图,并基于每个掩码处理的中间特征图得到所述掩码处理后的中间多层级特征图。
5.根据权利要求1

4中任一项所述的方法,其特征在于,所述解码器包括空间对齐模块、解码模块以及空间恢复模块,所述基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,包括:基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,以使得所述第一多层级特征图中的各个特征图的尺寸大小对齐,得到空间对齐的多层级特征图;基于所述空间对齐的多层级特征图分别对应的掩码图像,采用掩码标记替换所述空间对齐的多层级特征图中的掩码区域,得到带有掩码标记的多层级特征图,并基于所述解码模块对所述带有掩码标记的多层级特征图进行特征预测处理,得到特征预测处理的多层级特征图;基于所述空间恢复模块将相同的空间分辨率的所述特征预测处理的多层级特征图恢复成原始尺寸的多层级特征图,得到所述恢复处理后的第一多层级特征图。6.根据权利要求5所述的方法,其特征在于,所述基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,包括:针对所述第一多层级特征图中的每个特征图,将所述特征图与目标图像进行比较;当所述特征图的尺寸大于所述目标图像的尺寸的情况下,对所述特征图进行降维处理,使得所述特征图的尺寸与所述目标图像的尺寸一致;或者,当所述特征图的尺寸小于所述目标图像的尺寸的情况下,采用最近邻插值对所述特征图进行上采样处理,使得所述特征图的尺寸与所述目标图像的尺寸一致。7.根据权利要求5或6所述的方法,其特征在于,所述基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率之前,所述方法还包括:将所述第一多层级特征图的通道数与所述第二多层级特征图的通道数对...

【专利技术属性】
技术研发人员:劳珊珊宋广录刘博晓刘宇
申请(专利权)人:深圳市商汤科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1