一种模型训练方法及装置、电子设备和存储介质制造方法及图纸

技术编号:35017514 阅读:21 留言:0更新日期:2022-09-24 22:44
本公开涉及一种模型训练方法及装置、电子设备和存储介质。所述方法包括:获取目标域数据和源域数据,目标域数据和源域数据中包含的目标对象相同,源域数据中包含的目标对象具有标签,目标域数据和源域数据包括影像数据;根据目标域数据和源域数据,得到拟合数据,拟合数据包含目标域数据的风格特征和源域数据的结构特征,结构特征为目标对象的空间位置及布局的信息;根据拟合数据和标签,训练初始模型,得到中间模型;根据目标域数据,对中间模型进行自监督训练,得到目标模型,目标模型用于识别目标域数据中包含的目标对象。通过上述过程,有效提高了模型对跨域的无标注数据的识别能力,从而提升了模型在不同数据集之间的泛化能力。能力。能力。

【技术实现步骤摘要】
一种模型训练方法及装置、电子设备和存储介质


[0001]本公开涉及计算机视觉
,尤其涉及一种模型训练方法及装置、电子设备和存储介质。

技术介绍

[0002]深度学习模型在跨域的不同数据集上的性能可能会出现大幅度的下降。例如:在核磁共振成像(MRI)数据上训练的分割模型只能引用在核磁共振成像数据上,拿到电子计算机断层扫描(CT)数据上就会表现的很糟糕。在仅标注了MRI数据的情况下,我们无法用常规的训练方法,使得模型在没有CT数据的标注信息的情况下,很好的应对CT数据。

技术实现思路

[0003]有鉴于此,本公开提出了一种模型训练技术方案。
[0004]根据本公开的一方面,提供了一种模型训练方法,包括:获取目标域数据和源域数据,所述目标域数据和所述源域数据中包含的目标对象相同,所述源域数据中包含的目标对象具有标签,所述目标域数据和所述源域数据包括影像数据;根据所述目标域数据和所述源域数据,得到拟合数据,所述拟合数据包含所述目标域数据的风格特征和所述源域数据的结构特征,所述结构特征为所述目标对象的空间位置及布局的信息;根据所述拟合数据和所述标签,训练初始模型,得到中间模型;根据所述目标域数据,对所述中间模型进行自监督训练,得到目标模型,所述目标模型用于识别所述目标域数据中包含的目标对象。
[0005]在一种可能的实现方式中,所述根据所述目标域数据和所述源域数据,得到拟合数据,包括:将所述目标域数据和所述源域数据输入对抗生成网络,所述对抗生成网络包括提取网络和生成网络;根据所述提取网络,提取所述目标域数据的风格特征和所述源域数据的结构特征;根据所述生成网络,生成包含所述源域数据的风格特征和所述目标域数据的结构特征的拟合数据;其中,所述对抗生成网络通过所述目标域数据和所述源域数据进行对抗训练得到。
[0006]在一种可能的实现方式中,所述对抗生成网络通过所述目标域数据和所述源域数据进行对抗训练得到,包括:根据所述提取网络,提取所述目标域数据的第一风格特征和第一结构特征、所述源域数据的第二结构特征和第二风格特征;根据所述生成网络,生成第一重建数据、第一转换数据、第二转换数据和第二重建数据,所述第一重建数据包含第一风格特征和第一结构特征,所述第一转换数据包含第一风格特征和第二结构特征,所述第二转换数据包含所述第一结构特征和第二风格特征,所述第二重建数据包含第二结构特征和第二风格特征;根据所述判别网络,判别所述第一重建数据、所述第一转换数据、所述第二转换数据和所述第二重建数据的真假,得到判别结果;根据所述判别结果,对所述对抗生成网络的参数进行调整。
[0007]在一种可能的实现方式中,根据所述判别结果,对所述对抗生成网络的参数进行调整,包括:根据所述目标域数据和所述第一重建数据之间的差异,生成第一重建损失;根
据所述源域数据和所述第二重建数据之间的差异,生成第二重建损失;根据所述判别网络输出的对所述第一重建数据的判别结果,生成第一判别损失;根据所述判别网络输出的对所述第一转换数据的判别结果,生成第二判别损失;根据所述判别网络输出的对所述第二转换数据的判别结果,生成第三判别损失;根据所述判别网络输出的对所述第二重建数据的判别结果,生成第四判别损失;根据所述第一重建损失、所述第二重建损失、所述第一判别损失、所述第二判别损失、所述第三判别损失和所述第四判别损失,对所述生成网络和所述判别网络进行对抗训练,获得训练后的神经网络。
[0008]在一种可能的实现方式中,所述根据所述目标域数据,对所述中间模型进行自监督训练,得到目标模型,包括:将所述目标域数据输入所述中间模型,得到输出结果,作为所述目标域数据的伪标签;根据所述目标域数据和所述伪标签,训练所述中间模型,得到更新后的中间模型;对所述中间模型迭代地执行上述步骤,直至得到满足预设条件的目标模型。
[0009]在一种可能的实现方式中,在所述将所述目标域数据输入所述中间模型,得到输出结果,作为所述目标域数据的伪标签之后,所述根据所述目标域数据和所述伪标签,训练所述中间模型,得到更新后的中间模型之前,所述根据所述目标域数据,对所述中间模型进行自监督训练,得到目标模型,包括:将所述中间模型的参数进行初始化。
[0010]在一种可能的实现方式中,所述对所述中间模型迭代地执行上述步骤,直至得到满足预设条件的目标模型,包括:按照概率P
N

i
抽取第N

i次迭代得到的中间模型的输出结果,作为第N次迭代时所使用的目标域数据的伪标签,其中N为大于1的整数,i为小于N的正整数,P
N

i
的值与N

i的值正相关。
[0011]根据本公开的另一方面,提供了一种模型训练装置,包括:数据获取模块,用于获取目标域数据和源域数据,所述目标域数据和所述源域数据中包含的目标对象相同,所述源域数据中包含的目标对象具有标签,所述目标域数据和所述源域数据包括影像数据;拟合数据生成模块,用于根据所述目标域数据和所述源域数据,得到拟合数据,所述拟合数据包含所述目标域数据的风格特征和所述源域数据的结构特征,所述结构特征为所述目标对象的空间位置及布局的信息;全监督训练模块,用于根据所述拟合数据和所述标签,训练初始模型,得到中间模型;自监督训练模块,用于根据所述目标域数据,对所述中间模型进行自监督训练,得到目标模型,所述目标模型用于识别所述目标域数据中包含的目标对象。
[0012]在一种可能的实现方式中,所述拟合数据生成模块,包括:数据输入子模块,用于将所述目标域数据和所述源域数据输入对抗生成网络,所述对抗生成网络包括提取网络和生成网络;特征提取子模块,用于根据所述提取网络,提取所述目标域数据的风格特征和所述源域数据的结构特征;拟合数据生成子模块,用于根据所述生成网络,生成包含所述源域数据的风格特征和所述目标域数据的结构特征的拟合数据;其中,所述对抗生成网络通过所述目标域数据和所述源域数据进行对抗训练得到。
[0013]在一种可能的实现方式中,所述对抗生成网络通过所述目标域数据和所述源域数据进行对抗训练得到,包括:根据所述提取网络,提取所述目标域数据的第一风格特征和第一结构特征、所述源域数据的第二结构特征和第二风格特征;根据所述生成网络,生成第一重建数据、第一转换数据、第二转换数据和第二重建数据,所述第一重建数据包含第一风格特征和第一结构特征,所述第一转换数据包含第一风格特征和第二结构特征,所述第二转换数据包含所述第一结构特征和第二风格特征,所述第二重建数据包含第二结构特征和第
二风格特征;根据所述判别网络,判别所述第一重建数据、所述第一转换数据、所述第二转换数据和所述第二重建数据的真假,得到判别结果;根据所述判别结果,对所述对抗生成网络的参数进行调整。
[0014]在一种可能的实现方式中,所述根据所述判别结果,对所述对抗生成网络的参数进行调整,包括:根据所述目标域数据和所述第一重建数据之间的差异,生成第一重建损失;根据所述源域数据和所本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,包括:获取目标域数据和源域数据,所述目标域数据和所述源域数据中包含的目标对象相同,所述源域数据中包含的目标对象具有标签,所述目标域数据和所述源域数据包括影像数据;根据所述目标域数据和所述源域数据,得到拟合数据,所述拟合数据包含所述目标域数据的风格特征和所述源域数据的结构特征,所述结构特征为所述目标对象的空间位置及布局的信息;根据所述拟合数据和所述标签,训练初始模型,得到中间模型;根据所述目标域数据,对所述中间模型进行自监督训练,得到目标模型,所述目标模型用于识别所述目标域数据中包含的目标对象。2.根据权利要求1所述的方法,其特征在于,所述根据所述目标域数据和所述源域数据,得到拟合数据,包括:将所述目标域数据和所述源域数据输入对抗生成网络,所述对抗生成网络包括提取网络和生成网络;根据所述提取网络,提取所述目标域数据的风格特征和所述源域数据的结构特征;根据所述生成网络,生成包含所述源域数据的风格特征和所述目标域数据的结构特征的拟合数据;其中,所述对抗生成网络通过所述目标域数据和所述源域数据进行对抗训练得到。3.根据权利要求2所述的方法,其特征在于,所述对抗生成网络通过所述目标域数据和所述源域数据进行对抗训练得到,包括:根据所述提取网络,提取所述目标域数据的第一风格特征和第一结构特征、所述源域数据的第二结构特征和第二风格特征;根据所述生成网络,生成第一重建数据、第一转换数据、第二转换数据和第二重建数据,所述第一重建数据包含第一风格特征和第一结构特征,所述第一转换数据包含第一风格特征和第二结构特征,所述第二转换数据包含所述第一结构特征和第二风格特征,所述第二重建数据包含第二结构特征和第二风格特征;根据所述判别网络,判别所述第一重建数据、所述第一转换数据、所述第二转换数据和所述第二重建数据的真假,得到判别结果;根据所述判别结果,对所述对抗生成网络的参数进行调整。4.根据权利要求3所述的方法,其特征在于,所述根据所述判别结果,对所述对抗生成网络的参数进行调整,包括:根据所述目标域数据和所述第一重建数据之间的差异,生成第一重建损失;根据所述源域数据和所述第二重建数据之间的差异,生成第二重建损失;根据所述判别网络输出的对所述第一重建数据的判别结果,生成第一判别损失;根据所述判别网络输出的对所述第一转换数据的判别结果,生成第二判别损失;根据所述判别网络输出的对所述第二转换数据的判别结果,生成第三判别损失;根据所述判别网络输出的对所述第二重建数据的判别结果,生成第四判别损失;根据所述第一重建损失、所述第二重建损失、所述第一判别损失、...

【专利技术属性】
技术研发人员:陈亦新王焱陈晓天任艺柯张培芳吴振洲
申请(专利权)人:北京安德医智科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1