数据训练方法、装置、设备及存储介质制造方法及图纸

技术编号:20222349 阅读:42 留言:0更新日期:2019-01-28 20:40
本公开是关于一种数据训练方法、装置、设备及存储介质。所述方法,包括:以类别为单位对训练数据进行拆分,得到N个子数据集;所述类别为根据训练数据的第一属性进行确定;子数据集中的类别数量不超过预设数量;基于N个子数据集对深度学习模型进行M次迭代训练,得到训练后的深度学习模型;每次迭代训练的过程包括:基于第一个子数据集对深度学习模型进行训练,直至深度学习模型的损失函数达到第一预设值停止训练;基于第二个至第N个子数据集对深度学习模型进行训练。提高了训练深度学习模型时的GPU利用率以及训练速度。

【技术实现步骤摘要】
数据训练方法、装置、设备及存储介质
本公开涉及深度学习
,尤其涉及一种数据训练方法、装置、设备及存储介质。
技术介绍
相关技术中,深度学习技术目前已经广泛应用在计算机视觉中,其分类、检测效果已经远远超过传统的方法,深度学习本质上是数据驱动的技术,一般来说,数据量越大,其泛化效果越好。目前深度学习模型的训练方式大部分采用基于GPU(GraphicsProcessingUnit,图形处理器)的模式,GPU的并行处理能力可以大大提高模型的训练速度。而且,深度学习模型目前一般采用的是基于随机梯度下降的优化算法,训练过程中的每一步加载的样本量(batchsize)一般是在256至512的范围内。但是,GPU的显存一般是12G(GByte,吉字节),对于参数量比较大的深度学习模型来说,每次训练加载的数据量是有限的;而对于大规模数据的训练任务来讲,因为深度学习模型比较大,占用的显存比较多,所以每次只能加载少量的样本进行前馈运算,累加几次前馈运算后,再进行反向传播对模型参数进行调整。例如batchsize为512,但是因为显存的限制,每次只能加载64个样本,这样需要累加8次前馈运算之后才能进行本文档来自技高网...

【技术保护点】
1.一种数据训练方法,其特征在于,包括:以类别为单位对所述训练数据进行拆分,得到N个子数据集;所述类别为根据所述训练数据的第一属性进行确定;所述子数据集中的类别数量不超过预设数量;基于所述N个子数据集对深度学习模型进行M次迭代训练,得到训练后的深度学习模型;M、N为大于等于1的自然数;其中,每次迭代训练的过程包括:基于第一个子数据集对所述深度学习模型进行训练,直至所述深度学习模型的损失函数达到第一预设值停止训练;基于第二个至第N个子数据集对所述深度学习模型进行训练。

【技术特征摘要】
1.一种数据训练方法,其特征在于,包括:以类别为单位对所述训练数据进行拆分,得到N个子数据集;所述类别为根据所述训练数据的第一属性进行确定;所述子数据集中的类别数量不超过预设数量;基于所述N个子数据集对深度学习模型进行M次迭代训练,得到训练后的深度学习模型;M、N为大于等于1的自然数;其中,每次迭代训练的过程包括:基于第一个子数据集对所述深度学习模型进行训练,直至所述深度学习模型的损失函数达到第一预设值停止训练;基于第二个至第N个子数据集对所述深度学习模型进行训练。2.根据权利要求1所述的方法,其特征在于,所述基于第二个至第N个子数据集对所述深度学习模型进行训练,包括:固定所述深度学习模型中除最后一层参数外的其他参数;根据用以训练所述深度学习模型的当前子数据集中的类别数量,确定所述深度学习模型中最后一层参数的超参,并训练所述最后一层参数;当所述深度学习模型的损失函数达到与所述当前子数据集对应的第三预设值时,暂停训练所述深度学习模型;取消固定所述深度学习模型中的参数,并利用所述当前子数据集对所述深度学习模型的所有参数进行训练调整,直至所述深度学习模型的损失函数达到预设状态,停止本次训练,所述预设状态包括所述损失函数小于等于第二预设值,且所述损失函数的变化范围在预设阈值范围内;以所述当前子数据集的下一子数据集作为当前子数据集,然后进入固定所述深度学习模型中除最后一层参数外的其他参数的步骤,直至所述当前子数据集为第N个子数据集。3.根据权利要求1所述的方法,其特征在于,所述基于第二个至第N个子数据集对所述深度学习模型进行训练,包括:利用第二个至第N个子数据集中的各个子数据集对深度学习模型进行训练;其中,基于每个子数据集对所述对深度学习模型进行至少一次训练。4.根据权利要求1-3任一项所述的方法,其特征在于,在所述基于所述N个子数据集对深度学习模型进行M次迭代训练,得到训练后的深度学习模型之后,还包括:删除所述深度学习模型的最后一层,得到更新后的深度学习模型;将待提取特征的第一数据输入所述更新后的深度学习模型;通过所述更新后的深度学习模型获取所述第一数据的特征。5.根据权利要求1-3任一项所述的方法,其特征在于,在所述基于所述N个子数据集对深度学习模型进行M次迭代训练,得到训练后的深度学习模型之后,还包括:将待分类的第二数据输入训练后的所述深度学习模型;利用训练后的所述深度学习模型确定所述第二数据的所属类别。6.根据权利要求1所述的方法,其特征在于,所述预设数量为小于等于10万的自然数。7.一种数据训练装置,其特征在于,包括:训练数据拆分模块,被配置为以类别为单位对所述训练数据进行拆分,得到N个子数据集;所述类别为根据所述训练数据的第一属性进行确定;所述子数据集中的类别数量不超过预设数量;数据...

【专利技术属性】
技术研发人员:吴丽军杨帆
申请(专利权)人:北京达佳互联信息技术有限公司
类型:发明
国别省市:北京,11

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1