一种神经网络的训练方法、装置及设备制造方法及图纸

技术编号:33503778 阅读:54 留言:0更新日期:2022-05-19 01:13
本发明专利技术公开了一种神经网络的训练方法、装置及设备,其中,所述方法包括:构建至少一个计算单元;将待训练神经网络分别放置到所述至少一个计算单元内,每个计算单元执行参数生成操作,分别得到每个计算单元生成的梯度数据;将所述每个计算单元得出的梯度数据进行平均化,得到平均化后的梯度数据;根据所述平均化后的梯度数据更新所述待训练神经网络;任一所述计算单元的参数生成操作包括:运行所述计算单元中的待训练神经网络,生成轨迹数据;根据所述轨迹数据,生成梯度数据。通过上述方式,本发明专利技术提高了神经网络的训练效率。提高了神经网络的训练效率。提高了神经网络的训练效率。

【技术实现步骤摘要】
一种神经网络的训练方法、装置及设备


[0001]本专利技术涉及强化学习
,具体涉及一种神经网络的训练方法、装置及设备。

技术介绍

[0002]在超大动作空间和状态空间下,深度强化学习训练时间较长,仅一次实验就需要耗费大量训练时间,而生成有效的策略模型往往需要进行大量实验,这使得整个生产策略模型流程效率低下。
[0003]深度强化学习训练过程耗时较长的一个核心原因在于强化学习算法的采样效率偏低,需要基于大量的环境交互轨迹数据进行学习才能收敛到预期效果,得到较好策略模型。在相同的强化学习算法和超参约束下,单位时间轨迹数据吞吐量决定了训练效率。
[0004]而目前的强化学习算法架构,例如在专利文件(CN110705705A,CN108021395A)中,在提高吞吐率上有以下问题:第一,随着将深度强化学习应用到越来越复杂的决策问题中,复杂问题的模拟环境往往具有超大的动作空间和状态空间,这会导致网络中仍然会传输大量的观测和动作数据,大大降低了轨迹数据的收集效率。
[0005]第二,当收集的轨迹数据量很大时,单个GPU会成为采样轨迹数据进行学习的瓶颈。最后,针对不同算法的轨迹数据,并没有通用的可配置的高性能轨迹存储结构。

技术实现思路

[0006]为解决上述问题,提出了本专利技术实施例的神经网络的训练方法、装置及设备。
[0007]根据本专利技术实施例的一个方面,提供了一种神经网络的训练方法,包括:构建至少一个计算单元;将待训练神经网络分别放置到所述至少一个计算单元内,每个计算单元执行参数生成操作,分别得到每个计算单元生成的梯度数据;将所述每个计算单元得出的梯度数据进行平均化,得到平均化后的梯度数据;根据所述平均化后的梯度数据更新所述待训练神经网络;任一所述计算单元的参数生成操作包括:运行所述计算单元中的待训练神经网络,生成轨迹数据;根据所述轨迹数据,生成梯度数据。
[0008]可选的,所述每个计算单元均包括:至少一个CPU、至少一个独立内存以及一个GPU。
[0009]可选的,在构建至少一个计算单元后,还包括:预先配置轨迹数据的维度。
[0010]可选的,运行所述计算单元中的待训练神经网络,生成轨迹数据,包括:将所述待训练神经网络与初始环境进行交互,得到第一动作分布以及第一环境;对所述第一动作分布进行动作采样,得到第一动作;
判断得到所述第一动作后是否满足预设停止交互条件;若满足预设停止交互条件,将所述第一动作相关的运行数据作为轨迹数据,所述第一动作相关的运行数据包括以下至少一个:所述初始环境、所述第一动作分布、所述第一环境、所述第一动作;若不满足预设停止交互条件,将所述第一动作传递给所述第一环境,根据第一动作以及第一环境重新执行将所述待训练神经网络与初始环境进行交互的步骤。
[0011]可选的,在运行所述计算单元中的待训练神经网络,生成轨迹数据之后,还包括:通过所述至少一个独立内存,将轨迹数据进行储存。
[0012]可选的,根据所述轨迹数据,生成梯度数据,包括:通过神经网络的训练算法对所述轨迹数据进行处理,生成梯度数据。
[0013]可选的,在生成梯度数据之后,还包括:将所述梯度数据储存在所述至少一个独立内存中。
[0014]可选的,将所述每个计算单元得出的梯度数据进行平均化,得到平均化后的梯度数据,包括:通过所述至少一个独立内存,读取每个计算单元生成的梯度数据;将所述每个计算单元生成的梯度数据进行求和,得到求和后的梯度数据;对所述求和后的数据进行平均化,得到平均化后的梯度数据。
[0015]根据本专利技术实施例的另一方面,提供了一种神经网络的训练装置,所述装置包括:构建模块,用于构建至少一个计算单元;处理模块,用于将待训练神经网络分别放置到所述至少一个计算单元内,每个计算单元执行参数生成操作,分别得到每个计算单元生成的梯度数据;将所述每个计算单元得出的梯度数据进行平均化,得到平均化后的梯度数据;运行所述计算单元中的待训练神经网络,生成轨迹数据;根据所述轨迹数据,生成梯度数据;更新模块,用于根据所述平均化后的梯度数据更新所述待训练神经网络。
[0016]根据本专利技术实施例的又一方面,提供了一种计算设备,包括:处理器、存储器、通信接口和通信总线,所述处理器、所述存储器和所述通信接口通过所述通信总线完成相互间的通信;所述存储器用于存放至少一可执行指令,所述可执行指令使所述处理器执行上述神经网络的训练方法对应的操作。
[0017]根据本专利技术实施例的再一方面,提供了一种计算机存储介质,所述存储介质中存储有至少一可执行指令,所述可执行指令使处理器执行如上述神经网络的训练方法对应的操作。
[0018]根据本专利技术上述实施例提供的方案,通过构建至少一个计算单元;将待训练神经网络分别放置到所述至少一个计算单元内,每个计算单元执行参数生成操作,分别得到每个计算单元生成的梯度数据;将所述每个计算单元得出的梯度数据进行平均化,得到平均化后的梯度数据;根据所述平均化后的梯度数据更新所述待训练神经网络;任一所述计算单元的参数生成操作包括:运行所述计算单元中的待训练神经网络,生成轨迹数据;根据所述轨迹数据,生成梯度数据,可以提升单位时间轨迹数据吞吐量,解决超大动作空间和状态空间下的深度强化学习训练时间过长的问题,提高了训练效率。
[0019]上述说明仅是本专利技术实施例技术方案的概述,为了能够更清楚了解本专利技术实施例的技术手段,而可依照说明书的内容予以实施,并且为了让本专利技术实施例的上述和其它目的、特征和优点能够更明显易懂,以下特举本专利技术实施例的具体实施方式。
附图说明
[0020]通过阅读下文优选实施方式的详细描述,各种其他的优点和益处对于本领域普通技术人员将变得清楚明了。附图仅用于示出优选实施方式的目的,而并不认为是对本专利技术实施例的限制。而且在整个附图中,用相同的参考符号表示相同的部件。在附图中:图1示出了本专利技术实施例提供的神经网络的训练方法的方法流程图;图2示出了本专利技术实施例提供的一种具体的基于轨迹数据本地化的通用可配置分布式策略学习流程示意图;图3示出了本专利技术实施例提供的一种具体的基于轨迹数据本地化的通用可配置分布式策略学习架构示意图;图4示出了本专利技术实施例提供的神经网络的训练装置的结构示意图;图5示出了本专利技术实施例提供的计算设备的结构示意图。
具体实施方式
[0021]下面将参照附图更详细地描述本专利技术的示例性实施例。虽然附图中显示了本专利技术的示例性实施例,然而应当理解,可以以各种形式实现本专利技术而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更透彻地理解本专利技术,并且能够将本专利技术的范围完整的传达给本领域的技术人员。
[0022]图1示出了本专利技术实施例提供的神经网络的训练方法的方法流程图。如图1所示,该方法包括以下步骤:步骤11,构建至少一个计算单元;步骤12,将待训练神经网络分别放置到所述至少一个计算单元内,每个计算单元执行参数生成操作,分别得到每个计算单元生成的梯度数据;步骤13,将所述每个计算单本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种神经网络的训练方法,其特征在于,所述方法包括:构建至少一个计算单元;将待训练神经网络分别放置到所述至少一个计算单元内,每个计算单元执行参数生成操作,分别得到每个计算单元生成的梯度数据;将所述每个计算单元得出的梯度数据进行平均化,得到平均化后的梯度数据;根据所述平均化后的梯度数据更新所述待训练神经网络;任一所述计算单元的参数生成操作包括:运行所述计算单元中的待训练神经网络,生成轨迹数据;根据所述轨迹数据,生成梯度数据。2.根据权利要求1所述的神经网络的训练方法,其特征在于,所述每个计算单元均包括:至少一个CPU、至少一个独立内存以及一个GPU。3.根据权利要求1所述的神经网络的训练方法,其特征在于,在构建至少一个计算单元后,还包括:预先配置轨迹数据的维度。4.根据权利要求1所述的神经网络的训练方法,其特征在于,运行所述计算单元中的待训练神经网络,生成轨迹数据,包括:将所述待训练神经网络与初始环境进行交互,得到第一动作分布以及第一环境;对所述第一动作分布进行动作采样,得到第一动作;判断得到所述第一动作后是否满足预设停止交互条件;若满足预设停止交互条件,将所述第一动作相关的运行数据作为轨迹数据,所述第一动作相关的运行数据包括以下至少一个:所述初始环境、所述第一动作分布、所述第一环境、所述第一动作;若不满足预设停止交互条件,将所述第一动作传递给所述第一环境,根据第一动作以及第一环境重新执行将所述待训练神经网络与初始环境进行交互的步骤。5.根据权利要求2所述的神经网络的训练方法,其特征在于,在运行所述计算单元中的待训练神经网络,生成轨迹数据之后,还包括:通过所述至少一个独立内存,将轨迹数据进行储存。6.根据权利要求1所述的神经网络的训练方法,其特征在于,根据所述轨迹数据,生成梯度数据,包括:通过神经网络的...

【专利技术属性】
技术研发人员:徐波唐伟徐博
申请(专利权)人:中国科学院自动化研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1