【技术实现步骤摘要】
图像分类模型的训练方法、装置及电子设备
[0001]本公开涉及人工智能
,尤其涉及自然语言处理、计算机视觉、深度学习
,尤其涉及一种图像分类模型的训练方法、装置及电子设备。
技术介绍
[0002]目前,针对图像识别任务,需要对大量图像进行标记,得到图像标注数据;采用图像标注数据对深度学习模型进行训练,得到识别准确度较高的图像识别模型,用于图像识别任务。
[0003]其中,针对物种细粒度识别任务,由于很多物种的体型、外貌相似,
[0004]特征差异较小,只有相应领域的专家才能区分不同的物种,导致该任务下的图像标注数据缺乏,难以训练得到识别准确度较高的物种细粒度识别模型。
技术实现思路
[0005]本公开提供了一种图像分类模型的训练方法、装置及电子设备。
[0006]根据本公开的一方面,提供了一种图像分类模型的训练方法,所述方法包括:获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中 ...
【技术保护点】
【技术特征摘要】
1.一种图像分类模型的训练方法,包括:获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。2.根据权利要求1所述的方法,其中,所述获取多个训练数据集,包括:获取原始数据集,其中,所述原始数据集中包括大于预设数量的样本图像,以及所述样本图像的类别;从所述原始数据集的多个类别中抽取第一类别,并从所述原始数据集中具有所述第一类别的样本图像中抽取支持样本图像,得到支持集;从所述第一类别中抽取一个类别作为第二类别,并从所述原始数据集中具有所述第二类别的样本图像中抽取查询样本图像,得到所述支持集对应的查询集;根据所述支持集以及所述支持集对应的查询集,生成训练数据集。3.根据权利要求1所述的方法,其中,所述针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度,包括:针对每个训练数据集,根据所述训练数据集中的所述支持样本图像以及所述查询样本图像,生成多个所述样本图像对;针对每个样本图像对,根据所述样本图像对中支持样本图像的类别,以及所述样本图像对中查询样本图像的类别,确定所述样本图像对中支持样本图像与查询样本图像之间的样本相似度。4.根据权利要求1所述的方法,其中,所述图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;所述特征提取网络与所述注意力机制网络,用于提取样本图像对中支持样本图像的支持图像特征,以及提取所述样本图像对中查询样本图像的查询图像特征;所述相似度计算网络,用于对所述支持图像特征以及所述查询图像特征进行拼接处理以及相似度计算处理,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度。5.根据权利要求4所述的方法,其中,所述特征提取网络与所述注意力机制网络,分别为视觉Vision Transformer模型中的特征提取网络以及注意力机制网络。
6.根据权利要求1所述的方法,其中,所述依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型,包括:依次针对每个训练数据集,将所述训练数据集中样本图像对中的支持样本图像和查询样本图像输入所述图像分类模型,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度;根据所述预测相似度,以及所述样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数;根据所述损失函数的数值,对所述图像分类模型进行参数调整,实现训练。7.根据权利要求1-6中任一项所述的方法,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;所述查询样本图像的类别,为所述查询样本图像中目标对象所属的物种。8.一种图像分类方法,包括:获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于权利要求1-7中任一项所述的方法训练得到;将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。9.根据权利要求8所述的方法,其中,所述根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别,包括:根据所述待处理图像与所述支持样本图像之间的相似度,从多个所述支持样本图像中选择目标样本图像;将所述目标样本图像的类别,确定为所述待处理图像的类别。10.根据权利要求8或9所述的方法,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;所述待处理图像的类别,为所述待处理图像中目标对象所属的物种。11.一种图像分类模型的训练装置,包括:第一获取模块,用于获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述...
【专利技术属性】
技术研发人员:徐彤彤,迟恺,
申请(专利权)人:北京百度网讯科技有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。