分类模型的训练方法和图像分类方法、装置、设备和介质制造方法及图纸

技术编号:31799214 阅读:20 留言:0更新日期:2022-01-08 10:59
本公开公开了一种分类模型的训练方法和图像处理方法、装置、电子设备和存储介质,应用于人工智能领域,具体涉及计算机视觉和深度学习领域。分类模型的训练方法的具体实现方案为:将样本图像输入特征提取网络,得到图像特征,该样本图像包括指示真实类别的信息;将图像特征输入类别预测网络,得到第一概率向量,该第一概率向量包括样本图像属于真实类别的预测概率;基于预测概率和类别预测网络的网络权重,对图像特征进行加权处理,得到加权后特征;将加权后特征输入类别预测网络,得到第二概率向量;以及基于第二概率向量,对分类模型进行训练。进行训练。进行训练。

【技术实现步骤摘要】
分类模型的训练方法和图像分类方法、装置、设备和介质


[0001]本公开涉及人工智能领域,具体涉及计算机视觉和深度学习领域,更具体地涉及一种分类模型的训练方法和图像分类方法、装置、设备和介质。

技术介绍

[0002]在图像分类的相关技术中,直接根据损失函数来训练分类模型。在类别较多时,可能会存在因模型无法使真实类别在众多类别中获得关注,导致分类精度不高的技术问题。

技术实现思路

[0003]提供了一种提高模型精度的分类模型的训练方法和图像处理方法、装置、电子设备和存储介质。
[0004]本公开的一个方面提供了一种分类模型的训练方法,分类模型包括特征提取网络和类别预测网络;该训练方法包括:将样本图像输入特征提取网络,得到图像特征,该样本图像包括指示真实类别的信息;将图像特征输入类别预测网络,得到第一概率向量,该第一概率向量包括样本图像属于真实类别的预测概率;基于预测概率和类别预测网络的网络权重,对图像特征进行加权处理,得到加权后特征;将加权后特征输入类别预测网络,得到第二概率向量;以及基于第二概率向量,对分类模型进行训练。
[0005]本公开的另一个方面提供了一种图像分类方法,包括:将待分类图像输入分类模型,得到第三概率向量;以及基于第三概率向量,确定待分类图像的类别,其中,分类模型是采用上文描述的分类模型的训练方法训练得到的。
[0006]本公开的另一个方面提供了一种分类模型的训练装置,其中,分类模型包括特征提取网络和类别预测网络,该装置包括:特征提取模块,用于将样本图像输入特征提取网络,得到图像特征,该样本图像包括指示真实类别的信息;第一概率预测模块,用于将图像特征输入类别预测网络,得到第一概率向量,该第一概率向量包括样本图像属于真实类别的预测概率;特征加权模块,用于基于预测概率和类别预测网络的网络权重,对图像特征进行加权处理,得到加权后特征;第二概率预测模块,用于将加权后特征输入类别预测网络,得到第二概率向量;以及模型训练模块,用于基于第二概率向量,对分类模型进行训练。
[0007]本公开的另一个方面提供了一种图像分类装置,包括:第三概率预测模块,用于将待分类图像输入分类模型,得到第三概率向量;以及类别确定模块,用于基于第三概率向量,确定待分类图像的类别,其中,分类模型是采用前文描述的分类模型的训练装置训练得到的。
[0008]根据本公开的另一个方面提供了一种电子设备,包括:至少一个处理器;以及与至少一个处理器通信连接的存储器;其中,存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行本公开提供的分类模型的训练方法和/或图像分类方法。
[0009]根据本公开的另一个方面提供了一种存储有计算机指令的非瞬时计算机可读存
储介质,其中,计算机指令用于使计算机执行本公开提供的分类模型的训练方法和/或图像分类方法。
[0010]根据本公开的另一个方面提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开提供的分类模型的训练方法和/或图像分类方法。
[0011]应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
[0012]附图用于更好地理解本方案,不构成对本公开的限定。其中:
[0013]图1是根据本公开实施例的分类模型的训练方法和图像分类方法、装置的应用场景示意图;
[0014]图2是根据本公开实施例的分类模型的训练方法的流程示意图;
[0015]图3是根据本公开实施例的对图像进行加权处理的原理示意图;
[0016]图4是根据本公开实施例的分类模型的训练方法的原理示意图;
[0017]图5是根据本公开实施例的图像分类方法的流程示意图;
[0018]图6是根据本公开实施例的分类模型的训练装置的结构框图;
[0019]图7是根据本公开实施例的图像分类装置的结构框图;以及
[0020]图8是用来实施本公开实施例的分类模型的训练方法和/或图像分类方法的电子设备的结构框图。
具体实施方式
[0021]以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
[0022]本公开提供了一种分类模型的训练方法,其中,分类模型包括特征提取网络和类别预测网络。训练方法包括特征提取阶段、第一概率预测阶段、特征加权阶段、第二概率预测阶段和模型训练阶段。在特征提取阶段中,将样本图像输入特征提取网络,得到图像特征,该样本图像包括指示真实类别的信息。在第一概率预测阶段,将图像特征输入类别预测网络,得到第一概率向量,该第一概率向量包括样本图像属于真实类别的预测概率。在特征加权阶段,基于预测概率和类别预测网络的网络权重,对图像特征进行加权处理,得到加权后特征。在第二概率预测阶段,将加权后特征输入类别预测网络,得到第二概率向量。在模型训练阶段,基于第二概率向量,对分类模型进行训练。
[0023]以下将结合图1对本公开提供的方法和装置的应用场景进行描述。
[0024]图1是根据本公开实施例的分类模型的训练方法和图像分类方法、装置的应用场景示意图。
[0025]如图1所示,该实施例的应用场景100可以包括电子设备110,该电子设备110可以为具有处理功能的任意电子设备,包括但不限于智能手机、平板电脑、膝上型便携计算机、台式计算机和服务器等等。
[0026]该电子设备110例如可以对输入的图像120进行分类,得到分类结果130。例如可以识别图像120中的目标对象,根据识别得到的目标对象的类型来对图像120进行分类。该分类结果130例如可以包括图像中目标对象属于多个预定类别中每个类别的概率。目标对象例如可以包括车辆、水杯、背包等可能具有多种形状类型的对象,也可以为具有多种颜色类型的对象等,本公开对此不做限定。
[0027]根据本公开的实施例,如图1所示,该应用场景100还可以包括服务器140。电子设备110可以通过网络与服务器140通信连接,该网络可以包括无线或有线通信链路。
[0028]示例性地,服务器140可以用于训练分类模型,并响应于电子设备110发送的模型获取请求,将训练好的分类模型150发送给电子设备110,便于电子设备110对图像进行分类。在一实施例中,电子设备110还可以通过网络将图像发送给服务器140,由服务器根据训练好的分类模型对获得的图像进行分类。
[0029]根据本公开的实施例,如图1所示,该应用场景100还可以包括数据库160,该数据库160可以维护有海量的图像,该些图像可以具有指示图像的真实类别的标签。服务器140可以访问该数本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种分类模型的训练方法,其中,所述分类模型包括特征提取网络和类别预测网络;所述方法包括:将样本图像输入所述特征提取网络,得到图像特征,所述样本图像包括指示真实类别的信息;将所述图像特征输入所述类别预测网络,得到第一概率向量,所述第一概率向量包括所述样本图像属于所述真实类别的预测概率;基于所述预测概率和所述类别预测网络的网络权重,对所述图像特征进行加权处理,得到加权后特征;将所述加权后特征输入所述类别预测网络,得到第二概率向量;以及基于所述第二概率向量,对所述分类模型进行训练。2.根据权利要求1所述的方法,其中,所述类别预测网络包括全连接层;对所述图像特征进行加权处理包括:确定所述预测概率在所述第一概率向量中的位置信息;基于所述位置信息,确定表示所述全连接层的网络权重的权重数据中与所述预测概率相关联的目标权重数据;以及基于所述目标权重数据,对所述图像特征进行加权处理。3.根据权利要求2所述的方法,其中,所述基于所述目标权重向量,对所述图像特征进行加权处理包括:对所述目标权重数据进行扩充处理,得到与所述图像特征相同尺寸的权重因子;以及将所述权重因子与所述图像特征点乘,得到所述加权后特征。4.根据权利要求2所述的方法,其中,所述类别预测网络还包括卷积层和分类层;所述将所述图像特征输入所述类别预测网络,得到第一概率向量包括:将所述图像特征输入所述卷积层,得到尺寸小于所述图像特征的低维特征;将所述低维特征输入所述全连接层,得到目标特征;以及将所述目标特征输入所述分类层,得到所述第一概率向量。5.根据权利要求1所述的方法,其中,所述样本图像包括多个图像;所述预测概率包括与所述多个图像分别对应的多个概率;基于所述预测概率和所述类别预测网络的网络权重,对所述图像特征进行加权处理包括:对于所述多个图像中的每个图像,基于所述多个概率中与所述每个图像对应的概率和所述类别预测网络的网络权重,对所述每个图像的图像特征进行加权处理。6.一种图像分类方法,包括:将待分类图像输入分类模型,得到第三概率向量;以及基于所述第三概率向量,确定所述待分类图像的类别,其中,所述分类模型是采用权利要求1~5中任一项所述的方法训练得到的。7.一种分类模型的训练装置,其中,所述分类模型包括特征提取网络和类别预测网络;所述装置包括:特征提取模块,用于将样本图像输入所述特征提取网络,得到图像特征,所述样本图像包括指示真实类别的信息;第一概率预测模块,用于将所述图像特征输入所述类别预测网络,得到第一概率向量,
所述第一概率向量包括所述样本图像属于...

【专利技术属性】
技术研发人员:龚琛婷谭啸孙昊
申请(专利权)人:北京百度网讯科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1