图像分类模型的训练方法、装置及电子设备制造方法及图纸

技术编号:37249919 阅读:9 留言:0更新日期:2023-04-20 23:28
本公开提供了图像分类模型的训练方法、装置及电子设备,涉及人工智能技术领域,尤其涉及自然语言处理、计算机视觉、深度学习技术领域。具体实现方案为:获取多个训练数据集,训练数据集包括支持集和查询集;获取初始的图像分类模型;针对每个训练数据集,根据训练数据集中的支持集和查询集,确定训练数据集中的多个样本图像对以及对应的样本相似度;依次针对每个训练数据集,采用其中的多个样本图像对以及对应的样本相似度,对图像分类模型进行训练,得到训练好的图像分类模型,从而能够根据较少的样本图像以及对应的类别,训练得到准确度较高的图像分类模型,能够适用于图像标注数据比较缺乏的图像分类任务,提高图像分类任务下的准确度。准确度。准确度。

【技术实现步骤摘要】
图像分类模型的训练方法、装置及电子设备


[0001]本公开涉及人工智能
,尤其涉及自然语言处理、计算机视觉、深度学习
,尤其涉及一种图像分类模型的训练方法、装置及电子设备。

技术介绍

[0002]目前,针对图像识别任务,需要对大量图像进行标记,得到图像标注数据;采用图像标注数据对深度学习模型进行训练,得到识别准确度较高的图像识别模型,用于图像识别任务。
[0003]其中,针对物种细粒度识别任务,由于很多物种的体型、外貌相似,
[0004]特征差异较小,只有相应领域的专家才能区分不同的物种,导致该任务下的图像标注数据缺乏,难以训练得到识别准确度较高的物种细粒度识别模型。

技术实现思路

[0005]本公开提供了一种图像分类模型的训练方法、装置及电子设备。
[0006]根据本公开的一方面,提供了一种图像分类模型的训练方法,所述方法包括:获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
[0007]根据本公开的另一方面,提供了一种图像分类方法,所述方法包括:获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于如上所述的图像分类模型的训练方法训练得到;将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。
[0008]根据本公开的另一方面,提供了一种图像分类模型的训练装置,所述装置包括:第一获取模块,用于获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集
包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;第二获取模块,用于获取初始的图像分类模型;确定模块,用于针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;训练模块,用于依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
[0009]根据本公开的另一方面,提供了一种图像分类装置,所述装置包括:
[0010]获取模块,用于获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;生成模块,用于根据所述5待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;第一输入模块,用于将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于如上所述0的图像分类模型的训练方法训练得到;第二输入模块,用于将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;确定模块,用于根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样5本图像的类别,确定所述待处理图像的类别。
[0011]根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开上述提0出的图像分类模型的训练方法,或者,执行本公开上述提出的图像分类方法。
[0012]根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使计算机执行本公开上述提出的图像分类模型的训练方法,或者,执行本公开上述提出的图像分5类方法。
[0013]根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开上述提出的图像分类模型的训练方法,或者,实现本公开上述提出的图像分类方法。
[0014]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0015]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0016]图1是根据本公开第一实施例的示意图;
[0017]图2是根据本公开第二实施例的示意图;
[0018]图3是根据本公开第三实施例的示意图;
[0019]图4是根据本公开第四实施例的示意图;
[0020]图5是根据本公开第五实施例的示意图;
[0021]图6是根据本公开第六实施例的示意图;
[0022]图7是用来实现本公开实施例的图像分类模型的训练方法或者图像分类方法的电子设备的框图。
具体实施方式
[0023]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0024]目前,针对图像识别任务,需要对大量图像进行标记,得到图像标注数据;采用图像标注数据对深度学习模型进行训练,得到识别准确度较高的图像识别模型,用于图像识别任务。
[0025]其中,针对物种细粒度识别任务,由于很本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种图像分类模型的训练方法,包括:获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。2.根据权利要求1所述的方法,其中,所述获取多个训练数据集,包括:获取原始数据集,其中,所述原始数据集中包括大于预设数量的样本图像,以及所述样本图像的类别;从所述原始数据集的多个类别中抽取第一类别,并从所述原始数据集中具有所述第一类别的样本图像中抽取支持样本图像,得到支持集;从所述第一类别中抽取一个类别作为第二类别,并从所述原始数据集中具有所述第二类别的样本图像中抽取查询样本图像,得到所述支持集对应的查询集;根据所述支持集以及所述支持集对应的查询集,生成训练数据集。3.根据权利要求1所述的方法,其中,所述针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度,包括:针对每个训练数据集,根据所述训练数据集中的所述支持样本图像以及所述查询样本图像,生成多个所述样本图像对;针对每个样本图像对,根据所述样本图像对中支持样本图像的类别,以及所述样本图像对中查询样本图像的类别,确定所述样本图像对中支持样本图像与查询样本图像之间的样本相似度。4.根据权利要求1所述的方法,其中,所述图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;所述特征提取网络与所述注意力机制网络,用于提取样本图像对中支持样本图像的支持图像特征,以及提取所述样本图像对中查询样本图像的查询图像特征;所述相似度计算网络,用于对所述支持图像特征以及所述查询图像特征进行拼接处理以及相似度计算处理,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度。5.根据权利要求4所述的方法,其中,所述特征提取网络与所述注意力机制网络,分别为视觉Vision Transformer模型中的特征提取网络以及注意力机制网络。
6.根据权利要求1所述的方法,其中,所述依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型,包括:依次针对每个训练数据集,将所述训练数据集中样本图像对中的支持样本图像和查询样本图像输入所述图像分类模型,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度;根据所述预测相似度,以及所述样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数;根据所述损失函数的数值,对所述图像分类模型进行参数调整,实现训练。7.根据权利要求1-6中任一项所述的方法,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;所述查询样本图像的类别,为所述查询样本图像中目标对象所属的物种。8.一种图像分类方法,包括:获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于权利要求1-7中任一项所述的方法训练得到;将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。9.根据权利要求8所述的方法,其中,所述根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别,包括:根据所述待处理图像与所述支持样本图像之间的相似度,从多个所述支持样本图像中选择目标样本图像;将所述目标样本图像的类别,确定为所述待处理图像的类别。10.根据权利要求8或9所述的方法,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;所述待处理图像的类别,为所述待处理图像中目标对象所属的物种。11.一种图像分类模型的训练装置,包括:第一获取模块,用于获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述...

【专利技术属性】
技术研发人员:徐彤彤迟恺
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1