视觉转换器和用于训练视觉转换器的方法技术

技术编号:39430359 阅读:18 留言:0更新日期:2023-11-19 16:15
公开视觉转换器和用于训练视觉转换器的方法。基于教师网络类标记和学生网络(训练期间的视觉转换器)的图单元重要性分数的输入图像的图单元蒸馏损失在视觉转换器的修剪层处被确定。在当前回合数为奇数时,输入图像的图单元的稀疏化被跳过,并且密集输入图像由修剪层之后的层处理。在当前回合数为偶数时,输入图像的图单元在修剪层处被修剪并且由修剪层之后的层处理。输入图像的标签损失和总损失由后续层确定,并且学生网络被更新。并且学生网络被更新。并且学生网络被更新。

【技术实现步骤摘要】
视觉转换器和用于训练视觉转换器的方法
[0001]本申请要求于2022年5月10日提交的第63/340,375号美国临时申请的优先权权益,所述美国临时申请的公开通过引用全部包含于此。


[0002]在此公开的主题涉及视觉转换器(vision transformer)。更具体地,在此公开的主题涉及训练视觉转换器的系统和方法。

技术介绍

[0003]卷积神经网络(CNN)已经促进了计算机视觉领域的快速发展。随着一些视觉转换器结果在广泛的任务(诸如,分类、语义分割和对象检测)中超过CNN性能,关于视觉转换器的新兴研究展现出令人鼓舞的结果。为了提高(特别是在边缘装置上的)CNN中的模型效率,模型压缩技术(诸如,修剪(pruning)、量化和知识蒸馏(knowledge distillation))已经被广泛使用。针对图像分类,显著的图块单元(salient patch token)的数量根据输入图像的难度而变化。高效且数据特定的图单元(token)修剪实现了有效的模型加速,但是高效且数据特定的图单元修剪是一个未解决的问题。然而,对视觉转换器中的稀疏性探索得较少。此外,针对精度敏感的应用,通常导致轻微精度损失的典型压缩技术是不理想的。

技术实现思路

[0004]示例实施例提供一种用于训练视觉转换器的方法,所述方法可包括:在视觉转换器的修剪层P处,基于教师网络类标记和学生网络在修剪层P处的图单元重要性分数TIS
P
来确定输入图像的图单元蒸馏损失L
distill
,其中,输入图像可以是用于将视觉转换器训练预定数量的回合的图像数据库的一部分,并且学生网络可包括训练期间的视觉转换器;基于当前回合是奇数,由视觉转换器的在修剪层P之后的层通过跳过在修剪层P处对输入图像的图单元的稀疏化来处理输入图像;基于当前回合是偶数,由视觉转换器的在修剪层P之后的层通过在修剪层P处对输入图像的图单元进行修剪来处理输入图像;在由视觉转换器的在修剪层P之后的层处理输入图像之后,确定输入图像的标签损失L
loss
和总损失L;以及基于输入图像的标签损失L
loss
和总损失L来更新视觉转换器的学生网络。在一个实施例中,对输入图像的图单元进行修剪的步骤可包括:对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。在另一实施例中,对输入图像的图单元进行修剪的步骤可包括:自适应地对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。在又一实施例中,对输入图像的图单元进行修剪的步骤可包括:对不在图单元重要性分数的总和等于或大于预定阈值的最小数量的具有最高权重的图单元的组中的图单元进行修剪。在又一实施例中,在修剪层P处对输入图像的图单元进行修剪的步骤使用图单元掩码M来对输入图像的图单元进行修剪。在一个实施例中,输入图像的图单元蒸馏损失L
distill
还可基于教师网络类标记与学生网络的图单元重要性分数TIS
P
的库尔贝克

莱布勒散度。在另一实施例中,修剪层P可包括视觉转换器的第三层。
[0005]示例实施例提供一种视觉转换器,所述视觉转换器可包括第一层组和第二层组。第一层组可基于教师网络类标记和学生网络的图单元重要性分数TIS
P
来输出输入图像的图单元蒸馏损失L
distill
,其中,输入图像可以是用于将视觉转换器训练第一预定数量的回合的图像数据库的一部分,并且学生网络可包括训练期间的视觉转换器。第二层组在第一层组之后,并且可通过以下处理被训练:基于当前回合是奇数,由第二层组通过跳过对第一层组内的输入图像的图单元的稀疏化来处理输入图像;基于当前回合是偶数,由第二层组通过对第一层组内的输入图像的图单元进行修剪来处理输入图像;在由第二层组处理输入图像之后,确定输入图像的标签损失L
loss
和总损失L;以及基于输入图像的标签损失L
loss
和总损失L来更新所述视觉转换器的学生网络。在一个实施例中,对输入图像的图单元进行修剪的处理可包括:对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。在另一实施例中,对输入图像的图单元进行修剪的处理可包括:自适应地对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。在又一实施例中,对输入图像的图单元进行修剪的处理可包括:对不在图单元重要性分数的总和等于或大于预定阈值的最小数量的具有最高权重的图单元的组中的图单元进行修剪。在又一实施例中,对输入图像的图单元进行修剪的处理可包括:使用图单元掩码M来对输入图像的图单元进行修剪。在一个实施例中,输入图像的图单元蒸馏损失L
distill
还可基于教师网络类标记与学生网络的图单元重要性分数TIS
P
的库尔贝克

莱布勒散度。在另一个实施例中,第一层组可包括所述视觉转换器的前三层。
[0006]示例实施例提供一种用于训练视觉转换器的方法,其中,所述方法可包括:在视觉转换器的第一层组的输出处,基于教师网络类标记和学生网络的图单元重要性分数TIS
P
来确定输入图像的图单元蒸馏损失L
distill
,其中,输入图像可以是用于将视觉转换器训练预定数量的回合的图像数据库的一部分,并且学生网络可以是训练期间的视觉转换器;基于当前回合是奇数,由视觉转换器的在第一层组之后的第二层组通过跳过对第一层组内的输入图像的图单元的稀疏化来处理输入图像;基于当前回合是偶数,由第二层组通过使用图单元掩码M对第一层组内的输入图像的图单元进行修剪来处理输入图像;在通过第二层组处理输入图像之后,确定输入图像的标签损失L
loss
和总损失L;以及基于输入图像的标签损失L
loss
和总损失L来更新视觉转换器的学生网络。在一个实施例中,对输入图像的图单元进行修剪的步骤可包括:对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。在另一实施例中,对输入图像的图单元进行修剪的步骤可包括:自适应地对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。在又一实施例中,对输入图像的图单元进行修剪的步骤可包括:对不在图单元重要性分数的总和等于或大于预定阈值的最小数量的最高加权图单元的组中的图单元进行修剪。在又一实施例中,输入图像的图单元蒸馏损失L
distill
还可基于教师网络类标记与学生网络的图单元重要性分数TIS
P
的库尔贝克

莱布勒散度。在另一实施例中,第一层组可包括视觉转换器的前三层。
附图说明
[0007]在下面的部分中,将参照附图中示出的示例性实施例来描述在此公开的主题的方面,其中:
[0008]图1描绘由在此公开的一般训练框架模型提供的自适应修剪技术的可视化;
[0009]图2描绘根据在此公开的主题的训练系统的示例实施例的框图,该训练系统可用于使用自适应图单元修剪技术来训练本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种用于训练视觉转换器的方法,所述方法包括:在视觉转换器的修剪层P处,基于教师网络类标记和学生网络在修剪层P处的图单元重要性分数TIS
P
来确定输入图像的图单元蒸馏损失L
distill
,输入图像是用于将视觉转换器训练预定数量的回合的图像数据库的一部分,并且学生网络包括训练期间的视觉转换器;基于当前回合数是奇数,由视觉转换器的在修剪层P之后的层通过跳过在修剪层P处对输入图像的图单元的稀疏化来处理输入图像;基于当前回合数是偶数,由视觉转换器的在修剪层P之后的层通过在修剪层P处对输入图像的图单元进行修剪来处理输入图像;在由视觉转换器的在修剪层P之后的层处理输入图像之后,确定输入图像的标签损失L
loss
和总损失L;以及基于输入图像的标签损失L
loss
和总损失L来更新视觉转换器。2.根据权利要求1所述的方法,其中,对输入图像的图单元进行修剪的步骤包括:对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。3.根据权利要求1所述的方法,其中,对输入图像的图单元进行修剪的步骤包括:自适应地对输入图像的具有小于预定阈值的图单元重要性分数的图单元进行修剪。4.根据权利要求3所述的方法,其中,对输入图像的图单元进行修剪的步骤包括:对不在图单元重要性分数的总和等于或大于预定阈值的最小数量的具有最高权重的图单元的组中的图单元进行修剪。5.根据权利要求1所述的方法,其中,在修剪层P处对输入图像的图单元进行修剪的步骤使用图单元掩码M来对输入图像的图单元进行修剪。6.根据权利要求1至5中的任一项所述的方法,其中,输入图像的图单元蒸馏损失L
distill
还基于教师网络类标记与学生网络的图单元重要性分数TIS
P
的库尔贝克

莱布勒散度。7.根据权利要求1至5中的任一项所述的方法,其中,修剪层P包括视觉转换器的第三层。8.一种视觉转换器,包括:第一层组,基于教师网络类标记和学生网络的图单元重要性分数TIS
P
来输出输入图像的图单元蒸馏损失L
distill
,输入图像是用于将视觉转换器训练第一预定数量的回合的图像数据库的一部分,并且学生网络包括训练期间的视觉转换器;以及第二层组,在第一层组之后,第二层组通过以下处理被训练:基于当前回合数是奇数,由第二层组通过跳过对第一层组内的输入图像的图单元的稀疏化来处理输入图像,基于当前回合数是偶数,由第二层组通过对第一层组内的输入图像的图单元进行修剪来处理输入图像,在由第二层组处理输入图像之后,确定输入图像的标签损失L
loss
和总损失L,以及基于输入图像的标签损失L
loss
和总损失L来更新所述视觉转换器。9.根据权利要求8所述的视觉转换器,其中,对输入图像的图单元进行修...

【专利技术属性】
技术研发人员:李玲阿里
申请(专利权)人:三星电子株式会社
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1