分类模型的训练方法和训练装置制造方法及图纸

技术编号:26378327 阅读:27 留言:0更新日期:2020-11-19 23:47
本公开提出一种分类模型的训练方法和训练装置,涉及机器学习领域。本公开单独训练分类模型,相对于同时训练相关的两个模型,模型的稳定性更好;并且基于生成样本数据设置用于抑制所有输出类别激活的损失函数,在真实类别的基础上不需要额外增加虚假类别,有利于降低训练的复杂度。

【技术实现步骤摘要】
分类模型的训练方法和训练装置
本公开涉及机器学习领域,特别涉及一种分类模型的训练方法和训练装置。
技术介绍
基于生成式对抗网络(GenerativeAdversarialNetworks,GAN)的半监督分类方法:在训练阶段,同时训练生成式对抗网络的生成模型和分类模型。一般来说,训练分类模型所需要的迭代次数比训练生成模型所需要的迭代次数少,这会使得生成式对抗网络不太稳定。分类模型在训练时需要增加一个额外的虚假类别,专门用于识别生成模型生成的“虚假数据”,但该虚假类别在测试阶段不会被使用,这在一定程度上增加了训练的复杂性。此外,生成模型有时会生成足够真实的“虚假数据”,这样的训练数据对于训练没有帮助。
技术实现思路
本公开可以单独训练分类模型,相对于同时训练相关的两个模型,模型的稳定性更好;并且基于生成样本数据设置用于抑制所有输出类别激活的损失函数,在真实类别的基础上不需要额外增加虚假类别,有利于降低训练的复杂度。此外,通过在特征层添加噪声的方法,在一定程度上避免生成模型生成过于真实的“虚假数据”,有利于提升训练数据的有效性和本文档来自技高网...

【技术保护点】
1.一种分类模型的训练方法,包括:/n将真实样本数据和所述真实样本数据的标签数据输入待训练的分类模型,得到所述分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算所述第一损失函数在所述分类模型当前参数下的第一梯度信息;/n将生成样本数据输入所述分类模型,得到所述分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算所述第二损失函数在所述分类模型当前参数下的第二梯度信息;/n根据所述第一损失和所述第二损失判断所述分类模型是否收敛,在所述分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照...

【技术特征摘要】
1.一种分类模型的训练方法,包括:
将真实样本数据和所述真实样本数据的标签数据输入待训练的分类模型,得到所述分类模型输出的第一组输出值,基于预设的第一损失函数和第一组输出值计算第一损失,并计算所述第一损失函数在所述分类模型当前参数下的第一梯度信息;
将生成样本数据输入所述分类模型,得到所述分类模型输出的第二组输出值,基于预设的用于抑制所有输出类别激活的第二损失函数和第二组输出值计算第二损失,并计算所述第二损失函数在所述分类模型当前参数下的第二梯度信息;
根据所述第一损失和所述第二损失判断所述分类模型是否收敛,在所述分类模型未收敛的情况下,基于第一梯度信息和第二梯度信息的梯度叠加信息,按照梯度下降的方法更新所述分类模型的参数,并对所述分类模型继续进行训练。


2.如权利要求1所述的方法,其中,第二损失函数根据所述第二组输出值中每个输出类别上的输出值与数值小的预设值之间的差值信息确定。


3.如权利要求2所述的方法,其中,第二损失函数的公式表示为:



其中,c表示输入输出类别的数量,i表示其中某个输出类别,表示所述分类模型在输出类别i上的输出值,T表示数值小的预设值,max表示取最大值的运算,Lss,m表示第二损失。


4.如权利要求3所述的方法,其中,T小于或等于log0.0001。


5.如权利要求1所述的方法,其中,所述生成样本数据通过生成模型生成,其中,所述生成模型的特征层被配置为添加噪声。


6.如权利要求1所述的方法,还包括:
利用收敛的分类模型对输入的图像数据进行分类。


7.如权利要求1所述的方法,其中,
所述分类模型为图像分类模型;
所述真实样本数据为真实事物的图像数据,所述真实样本数据的标签数据为标注的真实事物的种类,所述第一组输出值为真实事物的图像数据在各个种类上的概率;
所述生成样本数据为对真实事物的图像数据添加噪声...

【专利技术属性】
技术研发人员:叶韵
申请(专利权)人:北京京东尚科信息技术有限公司北京京东世纪贸易有限公司
类型:发明
国别省市:北京;11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1