模型训练方法及装置、电子设备和存储介质制造方法及图纸

技术编号:34620514 阅读:14 留言:0更新日期:2022-08-20 09:27
本公开涉及一种模型训练方法及装置、电子设备和存储介质,所述方法包括:获取混合图像,混合图像是通过将两个样本图像中的图像块拼接而成的图像;通过预设的图像重建模型中的编码器对混合图像进行编码,得到混合图像的目标特征图;通过图像重建模型中的解码器对目标特征图进行解码,得到解码出的两个重建图像;根据两个重建图像与两个样本图像之间的损失,训练图像重建模型,以得到训练后的目标编码器。本公开实施例可实现提高了整体的模型训练效率以及训练后的模型性能。率以及训练后的模型性能。率以及训练后的模型性能。

【技术实现步骤摘要】
模型训练方法及装置、电子设备和存储介质


[0001]本公开涉及计算机
,尤其涉及一种模型训练方法及装置、电子设备和存储介质。

技术介绍

[0002]相关技术中提出了基于自监督学习“视觉表征(visual representations,或称视觉表达、图像表征)”的掩码图像建模(Masked Image Modeling,MIM)任务,好的“视觉表征”也即好的编码特征,能够提供对任务重要的信息,忽略与任务不相关的信息。在MIM任务中,首先会根据将原始图像分割成不重叠的图像补丁(patch),然后采用一个随机掩码来掩盖部分图像补丁,并用特殊符号填充掩盖的部分图像补丁得到待处理图像,再通过编码器对待处理图像进行处理得到隐含的“视觉表征”,并利用轻量级解码器基于“视觉表征”,生成重建图像,再利用重建图像与原始图像之间的相对均方差作为重建的损失来训练编码器和解码器,经过多次训练后的编码器可以用于下游视觉任务的网络模型中。
[0003]目前的MIM任务,虽然在自监督学习“视觉表征”的模型训练方面取得了进展,但为使得任务有足够的难度,使用大量特殊符号来填充掩码掩盖的部分图像补丁是不可避免的。然而真实的图像中并不存在这些无意义的特殊符号,这种输入上的区别会对模型训练产生的潜在负面影响,并且会在大量无意义的特殊符号(也即人造输入)上消耗大量计算资源,这种训练方式的训练时间较长、训练效率降低、通用性较差。

技术实现思路

[0004]本公开提出了一种模型训练技术方案。
[0005]根据本公开的一方面,提供了一种模型训练方法,包括:获取混合图像,所述混合图像是通过将两个样本图像中的图像块拼接而成的图像;通过预设的图像重建模型中的编码器对所述混合图像进行编码,得到所述混合图像的目标特征图;通过所述图像重建模型中的解码器对所述目标特征图进行解码,得到解码出的两个重建图像;根据所述两个重建图像与所述两个样本图像之间的损失,训练所述图像重建模型,以得到训练后的目标编码器。
[0006]在一种可能的实现方式中,所述编码器包括N个子编码器,每个子编码器包括多头注意力机制层,N为正整数;其中,所述通过预设的图像重建模型中的编码器对所述混合图像进行编码,得到所述混合图像的目标特征图,包括:确定每个子编码器中的多头注意力机制层所采用的注意力掩码,以及确定每个子编码器中的多头注意力机制层所采用的注意力窗口;通过所述N个子编码器,根据所述N个子编码器中所采用的注意力掩码与注意力窗口,对所述混合图像进行编码,得到所述混合图像的目标特征图;其中,所述注意力掩码用于指示所述多头注意力机制层计算同一样本图像的特征之间的多头注意力,所述注意力窗口用于指示所述多头注意力机制层计算同一注意力窗口内的特征之间的多头注意力。
[0007]在一种可能的实现方式中,所述确定每个子编码器中的多头注意力机制层所采用
的注意力掩码,包括:根据拼接所述混合图像时所采用的掩码图,确定所述N个子编码器中的第一个子编码器所采用的注意力掩码;根据所述N个子编码器中的第n个子编码器所编码的特征图尺度,对第n

1个子编码器所采用的注意力掩码进行下采样,得到第n个子编码器所采用的注意力掩码,2≤n≤N。
[0008]在一种可能的实现方式中,所述确定每个子编码器中的多头注意力机制层所采用的注意力窗口,包括:根据针对每个子编码器中的多头注意力机制层所预先设置的窗口尺寸,确定每个子编码器中的多头注意力机制层所采用的注意力窗口;其中,所述注意力窗口包括用于计算全局多头注意力的注意力窗口,以及用于分块计算局部多头注意力的注意力窗口中的至少一种。
[0009]在一种可能的实现方式中,所述通过所述N个子编码器,根据所述N个子编码器中所采用的注意力掩码与注意力窗口,对所述混合图像进行编码,得到所述混合图像的目标特征图,包括:将所述混合图像转换成指定维度的输入向量;通过第一个子编码器,根据所述第一个子编码器中所采用的注意力掩码与注意力窗口,对所述输入向量进行编码,得到第一个输出特征图;对第n

1输出特征图进行下采样,得到分辨率缩小且通道数增加的第n

1个输入特征图;通过第n个子编码器,根据第n个子编码器中所采用的注意力掩码以及注意力窗口,对第n

1个输入特征图进行编码,得到第n个输出特征图,2≤n≤N;将通过第N个子编码器编码得到的第N个输出特征图作为所述目标特征图。
[0010]在一种可能的实现方式中,所述将所述混合图像转换成指定维度的输入向量,包括:将拼接成所述混合图像的多个图像块进行通道展开与线性变换,得到序列向量;将所述混合图像对应的位置编码向量嵌入所述序列向量中,得到输入向量,其中,所述位置编码向量用于指示所述多个图像块各自在所述两个样本图像中的位置信息,不同样本图像中的图像块所采用的位置编码向量不同。
[0011]在一种可能的实现方式中,所述通过所述图像重建模型中的解码器对所述目标特征图进行解码,得到解码出的两个重建图像,包括:根据编码所述目标特征图时所采用的注意力掩码,将所述目标特征图拆解为两个子特征图;利用所述解码器对所述两个子特征图进行解码,得到解码出的两个重建图像。
[0012]在一种可能的实现方式中,所述两个样本图像包括第一样本图像与第二样本图像,所述获取混合图像,包括:根据所述第一样本图像与预设的掩码图,确定从所述第一样本图像中抽取的第一图像块;根据所述第二样本图像与所述掩码图对应的反向掩码图,确定从所述第二样本图像中抽取的第二图像块;将所述第一图像块与所述第二图像块进行拼接,得到所述混合图像。
[0013]在一种可能的实现方式中,所述两个样本图像包括第一样本图像与第二样本图像,所述两个重建图像包括所述第一样本图像对应的第一重建图像以及所述第二样本图像对应的第二重建图像,其中,所述根据所述两个重建图像与所述两个样本图像之间的损失,训练所述图像重建模型,以得到训练后的目标编码器,包括:确定所述第一重建图像与所述第一样本图像之间的第一图像差异,以及所述第二重建图像与所述第二样本图像之间的第二图像差异;根据所述第一图像差异与所述第二图像差异,确定所述损失,并根据所述损失,训练所述图像重建模型,以得到训练后的目标编码器。
[0014]在一种可能的实现方式中,所述确定所述第一重建图像与所述第一样本图像之间
的第一图像差异,以及所述第二重建图像与所述第二样本图像之间的第二图像差异,包括:根据从所述第二样本图像中抽取图像块时所采用的反向掩码图,确定所述第一重建图像与所述第一样本图像之间的第一图像差异;根据从所述第一样本图像中抽取图像块时所采用的掩码图,确定所述第二重建图像与所述第二样本图像之间的第二图像差异;其中,所述掩码图与所述反向掩码图对应相反。
[0015]在一种可能的实现方式中,所述目标编码器应用于下游任务的网络模型中,所述下游任务包括目标检测、图像补全、图像分割、图像分类中的至少一种。
[0016]根据本公开的一方面本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述方法包括:获取混合图像,所述混合图像是通过将两个样本图像中的图像块拼接而成的图像;通过预设的图像重建模型中的编码器对所述混合图像进行编码,得到所述混合图像的目标特征图;通过所述图像重建模型中的解码器对所述目标特征图进行解码,得到解码出的两个重建图像;根据所述两个重建图像与所述两个样本图像之间的损失,训练所述图像重建模型,以得到训练后的目标编码器。2.根据权利要求1所述的方法,其特征在于,所述编码器包括N个子编码器,每个子编码器包括多头注意力机制层,N为正整数;其中,所述通过预设的图像重建模型中的编码器对所述混合图像进行编码,得到所述混合图像的目标特征图,包括:确定每个子编码器中的多头注意力机制层所采用的注意力掩码,以及确定每个子编码器中的多头注意力机制层所采用的注意力窗口;通过所述N个子编码器,根据所述N个子编码器中所采用的注意力掩码与注意力窗口,对所述混合图像进行编码,得到所述混合图像的目标特征图;其中,所述注意力掩码用于指示所述多头注意力机制层计算同一样本图像的特征之间的多头注意力,所述注意力窗口用于指示所述多头注意力机制层计算同一注意力窗口内的特征之间的多头注意力。3.根据权利要求2所述的方法,其特征在于,所述确定每个子编码器中的多头注意力机制层所采用的注意力掩码,包括:根据拼接所述混合图像时所采用的掩码图,确定所述N个子编码器中的第一个子编码器所采用的注意力掩码;根据所述N个子编码器中的第n个子编码器所编码的特征图尺度,对第n

1个子编码器所采用的注意力掩码进行下采样,得到第n个子编码器所采用的注意力掩码,2≤n≤N。4.根据权利要求2或3所述的方法,其特征在于,所述确定每个子编码器中的多头注意力机制层所采用的注意力窗口,包括:根据针对每个子编码器中的多头注意力机制层所预先设置的窗口尺寸,确定每个子编码器中的多头注意力机制层所采用的注意力窗口;其中,所述注意力窗口包括用于计算全局多头注意力的注意力窗口,以及用于分块计算局部多头注意力的注意力窗口中的至少一种。5.根据权利要求2至4任一项所述的方法,其特征在于,所述通过所述N个子编码器,根据所述N个子编码器中所采用的注意力掩码与注意力窗口,对所述混合图像进行编码,得到所述混合图像的目标特征图,包括:将所述混合图像转换成指定维度的输入向量;通过第一个子编码器,根据所述第一个子编码器中所采用的注意力掩码与注意力窗口,对所述输入向量进行编码,得到第一个输出特征图;对第n

1输出特征图进行下采样,得到分辨率缩小且通道数增加的第n

1个输入特征图;
通过第n个子编码器,根据第n个子编码器中所采用的注意力掩码以及注意力窗口,对第n

1个输入特征图进行编码,得到第n个输出特征图,2≤n≤N;将通过第N个子编码器编码得到的第N个输出特征图作为所述目标特征图。6.根据权利要求5所述的方法,其特征在于,所述将所述混合图像转换成指定维度的输入向量,包括:将拼接成所述混合图像的多个图像块进行通道展开与线性变换,得到序列向量;将所述混合图像对应的位置编码向量嵌入所述序列向量中...

【专利技术属性】
技术研发人员:刘吉豪刘宇黄鑫
申请(专利权)人:上海商汤智能科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1