一种神经网络训练方法、计算设备及存储介质技术

技术编号:21715263 阅读:26 留言:0更新日期:2019-07-27 19:26
本发明专利技术公开了一种神经网络训练方法,适于在计算设备中执行,神经网络适于对图像进行处理以输出表征图像特征的关键点,方法包括:从视频数据中提取预定数量的连续帧图像;将连续帧图像输入神经网络,得到连续帧图像的图像预测点;针对各图像预测点,对连续帧图像进行光流计算,以得到对应的光流预测点;基于连续帧图像的图像预测点和光流预测点,计算第一损失值;以及基于第一损失值,调整神经网络的参数以便得到训练后的神经网络。该方法能够提高神经网络输出预测点的稳定度。

A Neural Network Training Method, Computing Equipment and Storage Media

【技术实现步骤摘要】
一种神经网络训练方法、计算设备及存储介质
本专利技术涉及深度学习
,尤其涉及一种神经网络训练方法、关键点检测方法、计算设备及存储介质。
技术介绍
关键点/特征点检测的应用非常广泛,尤其人脸或者猫狗脸关键点的检测,是各种应用如人脸美化、猫狗脸可爱表情的重要基础。互联网短视频的实时要求,对脸关键点的检测是一个挑战,传统的一帧帧检测的方法会产生严重的点抖动问题,点的晃动对用户的体验也产生巨大的不良影响。如何调校神经网络以达到最优的应用效果仍是一个经验性的工作。因此,需要一种神经网络训练方法,来提高神经网络输出预测点的稳定度。
技术实现思路
为此,本专利技术提供了一种神经网络训练方法以及关键点检测方法,以力图解决或者至少缓解上面存在的至少一个问题。根据本专利技术的一个方面,提供了一种神经网络训练方法,适于在计算设备中执行,神经网络可以对图像进行处理以输出表征图像特征的关键点。在该方法中,可以首先从视频数据中提取预定数量的连续帧图像;将连续帧图像输入神经网络,得到连续帧图像的第一图像预测点。然后针对各第一图像预测点,对连续帧图像进行光流计算,以得到对应的光流预测点。从而可以基于连续帧图像的图像预测点和光流预测点,计算第一损失值。最后基于第一损失值,调整神经网络的参数,以得到训练后的神经网络。可选地,在上述方法中,可以首先基于光流法,计算连续帧图像中相邻帧之间各第一图像预测点的位移信息。然后基于位移信息,确定前一帧图像的各第一图像预测点在当前帧图像中的位置,以得到连续帧图像的光流预测点。可选地,在上述方法中,可以通过下述公式计算第一损失值:其中,registrationloss为第一损失值,Lt,i指连续帧图像中第t帧第i个网络预测点的坐标,为对应的第t帧第i个光流预测点的坐标,K为点坐标数,T为视频帧数。可选地,为了进一步优化神经网络,在上述方法中,还可以获取已标注关键点的图像,该组图像与连续帧图像具有一致的数量和格式。同样地,将该组图像输入神经网络,获得图像的第二图像预测点。然后就可以基于标注的关键点和第二图像预测点,计算第二损失值。最后基于第二损失值,调整神经网络的参数以优化神经网络。可选地,在上述方法中,可以通过下述公式计算第二损失值:其中,detectionloss为第二损失值,Li指图像的第i个网络预测点的坐标,为对应的第i个关键点的坐标,K为点坐标数。根据本专利技术的又一个方面,提供一种神经网络训练方法,可以首先从视频数据中提取预定数量的连续帧图像,并获取具有标注了关键点的图像,该组图像与连续帧图像具有一致的数量和格式。然后,将连续帧图像和该组图像输入神经网络,得到连续帧图像的第一图像预测点和该组图像的第二图像预测点。针对连续帧图像的各第一图像预测点,对连续帧图像进行光流计算,以得到连续帧图像的光流预测点。就可以基于第一图像预测点和光流预测点,计算第一损失值,并基于标注的关键点和第二图像预测点,计算第二损失值。最后,基于第一损失值和第二损失值,调整神经网络的参数以得到训练后的神经网络。以上方案通过两个损失值对神经网络进行优化,缩小图像预测点与实际目标点的差值,从而提高神经网络预测的稳定度。根据本专利技术的另一个方面,提供一种关键点检测方法,可以使用上述神经网络训练方法训练后的神经网络对视频进行关键点检测以输出表征图像特征的关键点。根据本专利技术另一个方面,提供了一种计算设备,包括一个或多个处理器;存储器;一个或多个程序,这一个或多个程序存储在存储器中并被配置为由一个或多个处理器执行,一个或多个程序用于执行神经网络训练方法和/或关键点检测方法的指令。根据本专利技术另一个方面,提供一种存储一个或多个程序的计算机可读存储介质,这一个或多个程序包括指令,当指令被计算设备执行时,使得计算设备执行神经网络训练方法和/或关键点检测方法。本方案提出一种基于光流对已初步训练好的神经网络进行优化训练的方法。其中复杂的光流计算只存在于网络训练阶段,在神经网络预测阶段,相比原网络不会有额外的计算成本,从而保证网络的运行速度不变,同时,神经网络因为利用光流技术大大提高了输出点的稳定度。附图说明为了实现上述以及相关目的,本文结合下面的描述和附图来描述某些说明性方面,这些方面指示了可以实践本文所公开的原理的各种方式,并且所有方面及其等效方面旨在落入所要求保护的主题的范围内。通过结合附图阅读下面的详细描述,本公开的上述以及其它目的、特征和优势将变得更加明显。遍及本公开,相同的附图标记通常指代相同的部件或元素。图1示出了根据本专利技术的一个实施例的计算设备100的构造示意图;图2示出了根据本专利技术的一个实施例的神经网络训练方法200的示意性流程图;图3示出了根据本专利技术的一个实施例的神经网络训练方法300的示意性流程图;图4示出了根据本专利技术的一个实施例的神经网络训练过程的示意图;图5示出了根据本专利技术的一个实施例的神经网络训练结果的示意性图。具体实施方式下面将参照附图更详细地描述本公开的示例性实施例。虽然附图中显示了本公开的示例性实施例,然而应当理解,可以以各种形式实现本公开而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更透彻地理解本公开,并且能够将本公开的范围完整的传达给本领域的技术人员。神经网络优化算法是通过改善训练方法来最小化损失函数。模型的参数可以用来计算预测值和目标值之间的偏差程度,基于这些参数就可以构成损失函数。在训练神经网络时,可以通过寻找最小值控制方差,更新模型参数,最终使模型收敛。本方案主要通过定义并优化损失函数,调整学习率对预训练的神经网络进行优化。图1示出了根据本专利技术的一个实施例的计算设备100的构造示意图。在基本的配置102中,计算设备100典型地包括系统存储器106和一个或者多个处理器104。存储器总线108可以用于在处理器104和系统存储器106之间的通信。取决于期望的配置,处理器104可以是任何类型的处理,包括但不限于:微处理器(μP)、微控制器(μC)、数字信息处理器(DSP)或者它们的任何组合。处理器104可以包括诸如一级高速缓存110和二级高速缓存112之类的一个或者多个级别的高速缓存、处理器核心114和寄存器116。示例的处理器核心114可以包括运算逻辑单元(ALU)、浮点数单元(FPU)、数字信号处理核心(DSP核心)或者它们的任何组合。示例的存储器控制器118可以与处理器104一起使用,或者在一些实现中,存储器控制器118可以是处理器104的一个内部部分。取决于期望的配置,系统存储器106可以是任意类型的存储器,包括但不限于:易失性存储器(诸如RAM)、非易失性存储器(诸如ROM、闪存等)或者它们的任何组合。系统存储器106可以包括操作系统120、一个或者多个程序122以及程序数据124。在一些实施方式中,程序122可以布置为在操作系统上利用程序数据124进行操作。计算设备100还可以包括有助于从各种接口设备(例如,输出设备142、外设接口144和通信设备146)到基本配置102经由总线/接口控制器130的通信的接口总线140。示例的输出设备142包括图形处理单元148和音频处理单元150。它们可以被配置为有助于经由一个或者多个A/V端口152与诸如显示器或者扬声器之类的各种外部设备进行通信。示例本文档来自技高网...

【技术保护点】
1.一种神经网络训练方法,适于在计算设备中执行,所述神经网络适于对图像进行处理以输出表征图像特征的关键点,其中,所述方法包括:从视频数据中提取预定数量的连续帧图像;将所述连续帧图像输入所述神经网络,得到连续帧图像的第一图像预测点;针对各第一图像预测点,对所述连续帧图像进行光流计算,以得到对应的光流预测点;基于所述连续帧图像的第一图像预测点和光流预测点,计算第一损失值;以及基于第一损失值,调整所述神经网络的参数以得到训练后的神经网络。

【技术特征摘要】
1.一种神经网络训练方法,适于在计算设备中执行,所述神经网络适于对图像进行处理以输出表征图像特征的关键点,其中,所述方法包括:从视频数据中提取预定数量的连续帧图像;将所述连续帧图像输入所述神经网络,得到连续帧图像的第一图像预测点;针对各第一图像预测点,对所述连续帧图像进行光流计算,以得到对应的光流预测点;基于所述连续帧图像的第一图像预测点和光流预测点,计算第一损失值;以及基于第一损失值,调整所述神经网络的参数以得到训练后的神经网络。2.如权利要求1所述的方法,其中,所述针对各第一图像预测点,对所述连续帧图像进行光流计算,以得到对应的光流预测点的步骤包括:基于光流法,计算所述连续帧图像中相邻帧之间各第一图像预测点的位移信息;基于所述位移信息,确定前一帧图像的各第一图像预测点在当前帧图像中的位置,以得到所述连续帧图像的光流预测点。3.如权利要求1所述的方法,其中,通过下述公式计算所述第一损失值:其中,registrationloss为第一损失值,Lt,i指所述连续帧图像中第t帧第i个网络预测点的坐标,为对应的第t帧第i个光流预测点的坐标,K为点坐标数,T为视频帧数。4.如权利要求1所述的方法,其中,所述方法还包括:获取已标注关键点的图像,所述图像与所述连续帧图像具有一致的数量和格式;将获取的图像输入神经网络中进行处理,获得第二图像预测点;基于标注的关键点和第二图像预测点,计算第二损失值;以及基于所述第二损失值,调整所述神经网络的参数以优化所述神经网络。5.如权利要求4所述的方法,其中,通过下述公式计算所述第二损失...

【专利技术属性】
技术研发人员:齐子铭李志阳周子健张伟许清泉
申请(专利权)人:厦门美图之家科技有限公司
类型:发明
国别省市:福建,35

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1