模型训练方法、装置以及计算机可读存储介质制造方法及图纸

技术编号:27262566 阅读:17 留言:0更新日期:2021-02-06 11:23
本发明专利技术涉及数据分类技术领域,具体提供了一种模型训练方法、装置以及计算机可读存储介质,旨在解决代价敏感学习算法与数据增强方法无法进行有效的结合,导致数据分类模型的精准度和性能无法一同得到提升的技术问题。为此目的,根据本发明专利技术实施例的方法,可以采用代价敏感学习算法并且根据初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;对初始训练样本组中的训练样本进行数据增强处理,以生成增强样本;采用知识蒸馏算法,使初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型。通过上述步骤,可以将代价敏感学习算法和数据增强方法进行有效的结合,同时提升了模型分类的精准度和性能。精准度和性能。精准度和性能。

【技术实现步骤摘要】
模型训练方法、装置以及计算机可读存储介质


[0001]本专利技术涉及数据分类
,具体涉及一种模型训练方法、装置以及计算机可读存储介质。

技术介绍

[0002]随着信息技术的高速发展,深度学习技术在图像分类任务上的性能已经远远超越了传统的图像识别方法。深度卷积神经网络(Convolutional Neural Network,CNN)是特别设计用于识别图像的多层感知器,CNN的权重共享网络结构与生物神经网络类似,通过对图像进行多次的卷积核池化操作,逐渐提取到图像的高层表达,再使用神经网络对特征进行分类,以此来实现对图像分类的功能。此外,通过对数据进行标注,CNN在图像分类领域表现出极大的优势。
[0003]然而,在实际的图像分类过程中可能会出现数据不平衡的情况,标注为某一类别的数据量远远小于标注为其他类别的数据量,神经网络模型往往会忽略该类别从而使得模型分类的精准度下降。为解决该问题,代价敏感学习算法是其中一种有效的方法;另一方面,在实际的图像分类过程中还可能会因为数据量较少而导致模型分类的性能差,现有技术中往往采用数据增强方法来提高神经网络模型的性能,但是代价敏感学习算法与数据增强方法无法进行有效的结合,导致神经网络模型分类的精准度和性能得到无法一同得到提升。

技术实现思路

[0004]为了克服上述缺陷,提出了本专利技术,以提供解决或至少部分地解决代价敏感学习算法与数据增强方法无法进行有效的结合,导致数据分类模型的精准度和性能无法一同得到提升的技术问题的模型训练方法、装置以及计算机可读存储介质。
[0005]第一方面,提供一种模型训练方法,所述模型训练方法包括:采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;对所述初始训练样本组中的训练样本进行数据增强处理,以生成增强样本;采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型;其中,所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
[0006]在上述模型训练方法的一个技术方案中,“采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型”的步骤具体包括:采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:
其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中训练样本的个数;所述l
i
表示所述初始训练样本组中第i个训练样本的训练误差,i=1,2,3,...,N;所述m表示所述初始训练样本组中样本类别的总数;所述W
j
表示第j个样本类别的权重且所述n
j
表示第j个样本类别的训练样本的个数;所述p
ij
表示第i个训练样本被分类为第j个样本类别的预测概率;所述q
ij
表示第i个训练样本被标记为第j个样本类别的标签值。
[0007]在上述模型训练方法的一个技术方案中,“采用知识蒸馏算法并且利用所述初始数据分类模型与所述增强样本对第二分类模型进行模型训练,得到最终的数据分类模型”的步骤具体包括:将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训练:其中,所述L2表示所述知识蒸馏函数,所述l
a
表示所述第二分类模型在对所述增强样本进行训练时确定的损失函数,所述l
b
表示利用所述初始数据分类模型对所述第二分类模型使用所述增强样本进行训练指导学习时确定的知识蒸馏损失函数。
[0008]在上述模型训练方法的一个技术方案中,每个所述增强样本分别由所述初始训练样本组中的任意两个训练样本各自对应的一部分样本数据组成;所述第二分类模型的损失函数l
a
如下式所示:其中,所述r表示浮点数且r∈[0,1];所述c
uj
表示与增强样本相关的一个训练样本被标记为第j个样本类别的标签值,所述c
vj
表示与当前增强样本相关的另一个训练样本被标记为第j个样本类别的标签值,所述s
j
表示增强样本被分类为第j个样本类别的预测概率;并且/或者,所述知识蒸馏损失函数l
b
如下式所示:
其中,所述T表示超参数,T为[2,5]之间的整数;所述f
j
表示利用所述初始数据分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率,所述h
j
表示利用所述第二分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率;所述z
j
表示所述初始数据分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量;所述k
j
表示所述第二分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量。
[0009]在上述模型训练方法的一个技术方案中,“对所述初始训练样本组进行数据增强处理,以生成增强样本”的步骤具体包括:采用混合样本数据增强算法对所述初始训练样本组进行数据增强处理。
[0010]第二方面,提供一种模型训练装置,所述模型训练装置包括:代价敏感学习模块,其被配置成采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;数据增强模块,其被配置成对所述初始训练样本组进行数据增强处理,以生成增强样本;知识蒸馏模块,其被配置成采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型;其中,所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
[0011]在上述模型训练装置的一个技术方案中,所述代价敏感学习模块还被配置成执行以下操作:采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中训练样本的个数;所述l
i
表示所述初始训练样本组中第i个训练样本的训练误差,i=1,2,3,...,N;
所述m表示所述初始训练样本组中样本类别的总数;所述W
j
表示第j个样本类别的权重且所述n
j
表示第j个样本类别的训练样本的个数;所述p
ij
表示第i个训练样本被分类为第j个样本类别的预测概率;所述q
ij
表示第i个训练样本被标记为第j个样本类别的标签值。
[0012]在上述模型训练装置的一个技术方案中,所述知识蒸馏模块还被配置成执行以下操作:将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,其特征在于,所述模型训练方法包括:采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;对所述初始训练样本组进行数据增强处理,以生成增强样本;采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型;其中,所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。2.根据权利要求1所述的模型训练方法,其特征在于,“采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型”的步骤具体包括:采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中训练样本的个数;所述l
i
表示所述初始训练样本组中第i个训练样本的训练误差,i=1,2,3,...,N;所述m表示所述初始训练样本组中样本类别的总数;所述W
j
表示第j个样本类别的权重且所述n
j
表示第j个样本类别的训练样本的个数;所述p
ij
表示第i个训练样本被分类为第j个样本类别的预测概率;所述q
ij
表示第i个训练样本被标记为第j个样本类别的标签值。3.根据权利要求1所述的模型训练方法,其特征在于,“采用知识蒸馏算法并且利用所述初始数据分类模型与所述增强样本对第二分类模型进行模型训练,得到最终的数据分类模型”的步骤具体包括:将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训练:其中,所述L2表示所述知识蒸馏函数,所述l
a
表示所述第二分类模型在对所述增强样本进行训练时确定的损失函数,所述l
b
表示利用所述初始数据分类模型对所述第二分类模型使用所述增强样本进行训练指导学习时确定的知识蒸馏损失函数。4.根据权利要求3所述的模型训练方法,其特征在于,每个所述增强样本分别由所述初始训练样本组中的任意两个训练样本各自对应的一部分样本数据组成;
所述第二分类模型的损失函数l
a
如下式所示:其中,所述r表示浮点数且r∈[0,1];所述c
uj
表示与增强样本相关的一个训练样本被标记为第j个样本类别的标签值,所述c
vj
表示与当前增强样本相关的另一个训练样本被标记为第j个样本类别的标签值,所述s
j
表示增强样本被分类为第j个样本类别的预测概率;并且/或者,所述知识蒸馏损失函数l
b
如下式所示:其中,所述T表示超参数,T为[2,5]之间的整数;所述f
j
表示利用所述初始数据分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率,所述h
j
表示利用所述第二分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率;所述z
j
表示所述初始数据分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量;所述k
j
表示所述第二分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量。5.根据权利要求1至4中任一项所述的模型训练方法,其特征在于,“对所述初始训练样本组进行数据增强处理”的步骤具体包括:采用混合样本数据增强算法对所述初始训练样本组进行数据增强处理。6.一种模型训练装置,其特征在于,所述训练装置包括:代价敏感学习模块,其被配置成采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训...

【专利技术属性】
技术研发人员:冯于树
申请(专利权)人:江苏云从曦和人工智能有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1