基于域自适应的模型训练方法、目标比对方法和相关装置制造方法及图纸

技术编号:33446131 阅读:53 留言:0更新日期:2022-05-19 00:31
本申请实施例公开了一种基于域自适应的模型训练方法、目标比对方法和相关装置,通过将同一三维目标不同视角下的三维目标特征聚合成多视角三维目标特征,基于多视角三维目标特征进行模型训练,使训练的特征编码器从单一视角三维目标图像推理多视角三维目标特征,以更加全面地描述三维目标,缓解不同视角间三维目标形态差异对比对准确率的挑战,提高了三维目标比对准确率;通过三重对抗学习,即相机分类器和特征编码器对抗训练,特征编码器和身份

【技术实现步骤摘要】
基于域自适应的模型训练方法、目标比对方法和相关装置


[0001]本申请属于人工智能领域,尤其涉及一种基于域自适应的模型训练方法、目标比对方法和相关装置。

技术介绍

[0002]目前,基于无监督域自适应的三维目标比对方法主要有三类:基于自训练的算法;基于图像风格迁移的算法;基于域不变特征提取的算法。
[0003]但是,这些方法忽略了不同视角下三维目标图像的差异,还忽略了三维目标图像的域差异,导致三维目标比对准确率较低。

技术实现思路

[0004]本申请实施例提供一种基于域自适应的模型训练方法、目标比对方法和相关装置,可以解决现有三维目标比对准确率低下的问题。
[0005]第一方面,本申请实施例提供一种基于域自适应的模型训练方法,包括:
[0006]获取源域数据集和目标域数据集,源域数据集包括源域三维目标图像、身份标签和第一相机标签,目标域数据集包括目标域三维目标图像和第二相机标签;
[0007]使用特征编码器对源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征;
[0008]将源域三维目标特征输入至身份分类器,获得身份分类器输出的第一分类结果,并根据第一分类结果和身份标签,计算第一交叉熵损失值;
[0009]将同一三维目标在不同相机视角下的源域三维目标特征进行特征聚合,得到多视角三维目标特征,并将多视角三维目标特征输入至身份分类器,获得身份分类器输出的第二分类结果,并根据第二分类结果和身份标签,计算第二交叉熵损失值,根据多视角三维目标特征和源域三维目标特征,计算第一损失值;
[0010]基于源域三维目标图像和目标域三维目标图像,对相机分类器和特征编码器进行对抗训练,并根据第一相机标签和第二相机标签,计算第一对抗损失值,基于源域三维目标图像和目标域三维目标图像,对特征编码器和身份

域分类器进行对抗训练,并计算第二对抗损失值,基于源域三维目标图像,对特征编码器、身份分类器和去除域类别的身份

域分类器进行对抗训练,并计算第三对抗损失值;
[0011]根据第一交叉熵损失值、第二交叉熵损失值、第一损失值、第一对抗损失值、第二对抗损失值和第三对抗损失值,计算总损失值,并根据总损失值调整特征编码器、身份分类器、相机分类器、身份

域分类器以及去除域类别的身份

域分类器的参数;
[0012]迭代训练多次,当总损失值达到预设条件时,则获得训练好的特征编码器。
[0013]由上可见,将同一三维目标在不同相机视角下的三维目标特征聚合成多视角三维目标特征,并基于多视角三维目标特征进行模型训练,使得训练出的特征编码器可以从单一视角三维目标图像中推理出多视角三维目标特征,这样可以更加全面地描述三维目标,
缓解不同视角间三维目标形态差异对三维目标比对准确率的挑战,提高了三维目标比对准确率;另外,通过三重对抗学习,即相机分类器和特征编码器进行对抗训练,特征编码器和身份

域分类器进行对抗训练,特征编码器、身份分类器和去除域类别的身份

域分类器进行对抗训练,缓解了源域三维目标和目标域三维目标之间的域偏移,使得训练好的模型在目标域上有更好的性能表现,提高了三维目标比对准确率。
[0014]在第一方面的一些可能的实现方式中,使用特征编码器对源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征,包括:
[0015]将源域三维目标图像输入至特征编码器,获得特征编码器输出的源域特征图;
[0016]将目标域三维目标图像输入至特征编码器,获得特征编码器输出的目标域特征图;
[0017]将源域特征图等分成第一特征图和第二特征图,并对第一特征图分别进行实例归一化操作和批归一化操作,获得第一实例归一化结果和第一批归一化结果,对第二特征图进行实例归一化操作和批归一化操作,获得第二实例归一化结果和第二批归一化结果;
[0018]将第一实例归一化结果和第二批归一化结果进行拼接,得到第一拼接结果,将第一批归一化结果和第二实例归一化结果进行拼接,得到第二拼接结果,并将第一拼接结果和第二拼接结果进行加权平均,得到源域三维目标特征;
[0019]将目标域特征图等分成第三特征图和第四特征图,并对第三特征图分别进行实例归一化操作和批归一化操作,获得第三实例归一化结果和第三批归一化结果,对第四特征图进行实例归一化操作和批归一化操作,获得第四实例归一化结果和第四批归一化结果;
[0020]将第三实例归一化结果和第四批归一化结果进行拼接,得到第三拼接结果,将第三批归一化结果和第四实例归一化结果进行拼接,得到第四拼接结果,并将第三拼接结果和第四拼接结果进行加权平均,得到目标域三维目标特征。
[0021]在该实现方式中,通过本申请实施例提供的交叉归一化方法,即同时在特征图上进行实例归一化和批归一化,提高了模型的泛化能力。
[0022]在第一方面的一些可能的实现方式中,将同一三维目标在不同相机视角下的源域三维目标特征进行特征聚合,得到多视角三维目标特征,包括:
[0023]将源域三维目标特征输入至身份分类器,获得身份分类器输出的身份类别概率;
[0024]根据身份类别概率,确定源域三维目标特征的权重;
[0025]基于每个源域三维目标特征的权重,将同一三维目标在不同相机视角下的源域三维目标特征进行加权平均,得到加权平均后的三维目标特征;
[0026]将加权平均后的三维目标特征和源域三维目标特征进行拼接后再进行全局池化,得到多视角三维目标特征。
[0027]在第一方面的一些可能的实现方式中,基于源域三维目标图像和目标域三维目标图像,对相机分类器和特征编码器进行对抗训练,并根据第一相机标签和第二相机标签,计算第一对抗损失值,包括:
[0028]固定特征编码器的参数,使用源域三维目标图像和目标域三维目标图像训练相机分类器,获得训练后的相机分类器,并根据第一相机标签和第二相机标签,计算第二损失值;
[0029]固定训练后的相机分类器的参数,使用源域三维目标图像和目标域三维目标图像
更新特征编码器的参数,并计算得到第三损失值;
[0030]第一对抗损失值包括第二损失值和第三损失值。
[0031]在第一方面的一些可能的实现方式中,基于源域三维目标图像和目标域三维目标图像,对特征编码器和身份

域分类器进行对抗训练,并计算第二对抗损失值,包括:
[0032]使用所述源域三维目标图像和所述目标域三维目标图像对所述特征编码器和所述身份

域分类器进行训练,并计算第四损失值,得到训练后特征编码器和训练后身份

域分类器;
[0033]固定所述训练后特征编码器的参数,使用所述目标三维目标图像训练所述身份

域分类器,获得优化本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于域自适应的模型训练方法,其特征在于,包括:获取源域数据集和目标域数据集,所述源域数据集包括源域三维目标图像、身份标签和第一相机标签,所述目标域数据集包括目标域三维目标图像和第二相机标签;使用特征编码器对所述源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征;将所述源域三维目标特征输入至身份分类器,获得身份分类器输出的第一分类结果,并根据所述第一分类结果和所述身份标签,计算第一交叉熵损失值;将同一三维目标在不同相机视角下的所述源域三维目标特征进行特征聚合,得到多视角三维目标特征,并将所述多视角三维目标特征输入至所述身份分类器,获得身份分类器输出的第二分类结果,并根据所述第二分类结果和所述身份标签,计算第二交叉熵损失值,根据所述多视角三维目标特征和所述源域三维目标特征,计算第一损失值;基于所述源域三维目标图像和所述目标域三维目标图像,对相机分类器和所述特征编码器进行对抗训练,并根据所述第一相机标签和所述第二相机标签,计算第一对抗损失值,基于所述源域三维目标图像和所述目标域三维目标图像,对所述特征编码器和身份

域分类器进行对抗训练,并计算第二对抗损失值,基于所述源域三维目标图像,对所述特征编码器、所述身份分类器和去除域类别的所述身份

域分类器进行对抗训练,并计算第三对抗损失值;根据所述第一交叉熵损失值、所述第二交叉熵损失值、所述第一损失值、所述第一对抗损失值、所述第二对抗损失值和所述第三对抗损失值,计算总损失值,并根据所述总损失值调整所述特征编码器、所述身份分类器、所述相机分类器、所述身份

域分类器以及所述去除域类别的身份

域分类器的参数;迭代训练多次,当所述总损失值达到预设条件时,则获得训练好的特征编码器。2.如权利要求1所述的方法,其特征在于,使用特征编码器对所述源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征,包括:将所述源域三维目标图像输入至所述特征编码器,获得所述特征编码器输出的源域特征图;将所述目标域三维目标图像输入至所述特征编码器,获得所述特征编码器输出的目标域特征图;将所述源域特征图等分成第一特征图和第二特征图,并对所述第一特征图分别进行实例归一化操作和批归一化操作,获得第一实例归一化结果和第一批归一化结果,对所述第二特征图进行实例归一化操作和批归一化操作,获得第二实例归一化结果和第二批归一化结果;将所述第一实例归一化结果和所述第二批归一化结果进行拼接,得到第一拼接结果,将所述第一批归一化结果和所述第二实例归一化结果进行拼接,得到第二拼接结果,并将所述第一拼接结果和所述第二拼接结果进行加权平均,得到所述源域三维目标特征;将所述目标域特征图等分成第三特征图和第四特征图,并对所述第三特征图分别进行实例归一化操作和批归一化操作,获得第三实例归一化结果和第三批归一化结果,对所述第四特征图进行实例归一化操作和批归一化操作,获得第四实例归一化结果和第四批归一化结果;
将所述第三实例归一化结果和所述第四批归一化结果进行拼接,得到第三拼接结果,将所述第三批归一化结果和所述第四实例归一化结果进行拼接,得到第四拼接结果,并将所述第三拼接结果和所述第四拼接结果进行加权平均,得到所述目标域三维目标特征。3.如权利要求1所述的方法,其特征在于,将同一三维目标在不同相机视角下的所述源域三维目标特征进行特征聚合,得到多视角三维目标特征,包括:将所述源域三维目标特征输入至所述身份分类器,获得所述身份分类器输出的身份类别概率;根据所述身份类别概率,确定所述源域三维目标特征的权重;基于每个所述源域三维目标特征的权重,将同一三维目标在不同相机视角下的所述源域三维目标特征进行加权平均,得到加权平均后的三维目标特征;将所述加权平均后的三维目标特征和所述源域三维目标特征进行拼接后再进行全局池化,得到所述多视角三维目标特征。4.如权利要求1所述的方法,其特征在于,基于所述源域三维目标图像和所述目标域三维目标图像,对相机分类器和所述特征编码器进行对抗训练,并根据所述第一相机标签和所述第二相机标签,计算第一对抗损失值,包括:固定所述特征编码器的参数,使用所述源域三维目标图像和所述目标域三维目标图像训练所述相机分类器,获得训练后的相机分类器,并根据所述第一相机标签和所述第二相机标签,计算第二损失值;固定所述训练后的相机分类器的参数,使用所述源域三维目标图像和所述目标域三维目标图像更新所述特征编码器的参数,并计算得到第三损失值;所述第一对抗损失值包括所述第二损失值和第三损失值。5.如权利要求1所述的方法,其特征在于,基于所述源域三维目标图像和所述目标域三维目标图像,对所述特征编码器和身份

域分类器进行对...

【专利技术属性】
技术研发人员:陶大鹏李华锋林旭
申请(专利权)人:云南联合视觉科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1