【技术实现步骤摘要】
基于域自适应的模型训练方法、目标比对方法和相关装置
[0001]本申请属于人工智能领域,尤其涉及一种基于域自适应的模型训练方法、目标比对方法和相关装置。
技术介绍
[0002]目前,基于无监督域自适应的三维目标比对方法主要有三类:基于自训练的算法;基于图像风格迁移的算法;基于域不变特征提取的算法。
[0003]但是,这些方法忽略了不同视角下三维目标图像的差异,还忽略了三维目标图像的域差异,导致三维目标比对准确率较低。
技术实现思路
[0004]本申请实施例提供一种基于域自适应的模型训练方法、目标比对方法和相关装置,可以解决现有三维目标比对准确率低下的问题。
[0005]第一方面,本申请实施例提供一种基于域自适应的模型训练方法,包括:
[0006]获取源域数据集和目标域数据集,源域数据集包括源域三维目标图像、身份标签和第一相机标签,目标域数据集包括目标域三维目标图像和第二相机标签;
[0007]使用特征编码器对源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征;
[0008]将源域三维目标特征输入至身份分类器,获得身份分类器输出的第一分类结果,并根据第一分类结果和身份标签,计算第一交叉熵损失值;
[0009]将同一三维目标在不同相机视角下的源域三维目标特征进行特征聚合,得到多视角三维目标特征,并将多视角三维目标特征输入至身份分类器,获得身份分类器输出的第二分类结果,并根据第二分类结果和身份标签,计算第二交叉熵损失值, ...
【技术保护点】
【技术特征摘要】
1.一种基于域自适应的模型训练方法,其特征在于,包括:获取源域数据集和目标域数据集,所述源域数据集包括源域三维目标图像、身份标签和第一相机标签,所述目标域数据集包括目标域三维目标图像和第二相机标签;使用特征编码器对所述源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征;将所述源域三维目标特征输入至身份分类器,获得身份分类器输出的第一分类结果,并根据所述第一分类结果和所述身份标签,计算第一交叉熵损失值;将同一三维目标在不同相机视角下的所述源域三维目标特征进行特征聚合,得到多视角三维目标特征,并将所述多视角三维目标特征输入至所述身份分类器,获得身份分类器输出的第二分类结果,并根据所述第二分类结果和所述身份标签,计算第二交叉熵损失值,根据所述多视角三维目标特征和所述源域三维目标特征,计算第一损失值;基于所述源域三维目标图像和所述目标域三维目标图像,对相机分类器和所述特征编码器进行对抗训练,并根据所述第一相机标签和所述第二相机标签,计算第一对抗损失值,基于所述源域三维目标图像和所述目标域三维目标图像,对所述特征编码器和身份
‑
域分类器进行对抗训练,并计算第二对抗损失值,基于所述源域三维目标图像,对所述特征编码器、所述身份分类器和去除域类别的所述身份
‑
域分类器进行对抗训练,并计算第三对抗损失值;根据所述第一交叉熵损失值、所述第二交叉熵损失值、所述第一损失值、所述第一对抗损失值、所述第二对抗损失值和所述第三对抗损失值,计算总损失值,并根据所述总损失值调整所述特征编码器、所述身份分类器、所述相机分类器、所述身份
‑
域分类器以及所述去除域类别的身份
‑
域分类器的参数;迭代训练多次,当所述总损失值达到预设条件时,则获得训练好的特征编码器。2.如权利要求1所述的方法,其特征在于,使用特征编码器对所述源域三维目标图像和目标域三维目标图像进行特征提取,得到源域三维目标特征和目标域三维目标特征,包括:将所述源域三维目标图像输入至所述特征编码器,获得所述特征编码器输出的源域特征图;将所述目标域三维目标图像输入至所述特征编码器,获得所述特征编码器输出的目标域特征图;将所述源域特征图等分成第一特征图和第二特征图,并对所述第一特征图分别进行实例归一化操作和批归一化操作,获得第一实例归一化结果和第一批归一化结果,对所述第二特征图进行实例归一化操作和批归一化操作,获得第二实例归一化结果和第二批归一化结果;将所述第一实例归一化结果和所述第二批归一化结果进行拼接,得到第一拼接结果,将所述第一批归一化结果和所述第二实例归一化结果进行拼接,得到第二拼接结果,并将所述第一拼接结果和所述第二拼接结果进行加权平均,得到所述源域三维目标特征;将所述目标域特征图等分成第三特征图和第四特征图,并对所述第三特征图分别进行实例归一化操作和批归一化操作,获得第三实例归一化结果和第三批归一化结果,对所述第四特征图进行实例归一化操作和批归一化操作,获得第四实例归一化结果和第四批归一化结果;
将所述第三实例归一化结果和所述第四批归一化结果进行拼接,得到第三拼接结果,将所述第三批归一化结果和所述第四实例归一化结果进行拼接,得到第四拼接结果,并将所述第三拼接结果和所述第四拼接结果进行加权平均,得到所述目标域三维目标特征。3.如权利要求1所述的方法,其特征在于,将同一三维目标在不同相机视角下的所述源域三维目标特征进行特征聚合,得到多视角三维目标特征,包括:将所述源域三维目标特征输入至所述身份分类器,获得所述身份分类器输出的身份类别概率;根据所述身份类别概率,确定所述源域三维目标特征的权重;基于每个所述源域三维目标特征的权重,将同一三维目标在不同相机视角下的所述源域三维目标特征进行加权平均,得到加权平均后的三维目标特征;将所述加权平均后的三维目标特征和所述源域三维目标特征进行拼接后再进行全局池化,得到所述多视角三维目标特征。4.如权利要求1所述的方法,其特征在于,基于所述源域三维目标图像和所述目标域三维目标图像,对相机分类器和所述特征编码器进行对抗训练,并根据所述第一相机标签和所述第二相机标签,计算第一对抗损失值,包括:固定所述特征编码器的参数,使用所述源域三维目标图像和所述目标域三维目标图像训练所述相机分类器,获得训练后的相机分类器,并根据所述第一相机标签和所述第二相机标签,计算第二损失值;固定所述训练后的相机分类器的参数,使用所述源域三维目标图像和所述目标域三维目标图像更新所述特征编码器的参数,并计算得到第三损失值;所述第一对抗损失值包括所述第二损失值和第三损失值。5.如权利要求1所述的方法,其特征在于,基于所述源域三维目标图像和所述目标域三维目标图像,对所述特征编码器和身份
‑
域分类器进行对...
【专利技术属性】
技术研发人员:陶大鹏,李华锋,林旭,
申请(专利权)人:云南联合视觉科技有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。