基于元信息估计的自动激活值检查点搜索方法技术

技术编号:37079997 阅读:11 留言:0更新日期:2023-03-29 19:55
本发明专利技术公开了基于元信息估计的自动激活值检查点搜索方法,具体涉及深度大模型优化领域,包括以下步骤:S1、初始化模型;S2、线性化模型的建立;S3、元信息抽取;S4、动态规划求解;S5、计算图优化:S6、执行,本发明专利技术通过设计通用节点特性的传播方式,解除了部分依赖,使得模型内可视为线性化分割点的节点增多,增加了线性化网络的链长,从而增加了搜索空间,同时,本发明专利技术的线性化是全自动的,不需要对模型进行改写。写。写。

【技术实现步骤摘要】
基于元信息估计的自动激活值检查点搜索方法


[0001]本专利技术涉及深度大模型优化领域,尤其涉及基于元信息估计的自动激活值检查点搜索方法。

技术介绍

[0002]随深度学习模型逐渐变得越来越大以达到更佳的精度要求。深度学习模型的参数量已经达到百亿级别,甚至千亿级别。智源悟道2.0模型甚至有1.75万亿的参数规模。
[0003]在这一背景下,使用多种技术减缓GPU内存压力是深度学习大模型训练的关键。
[0004]激活值检查点正是其中的关键之一,通过在前向传播时丢弃中间结果节省内存,再由反向传播时重计算来保证正确性。然而,虽然理论上使用激活值检查点技术,可以大幅减少训练时内存的使用,大部分机器学习工程师缺乏对如何使用该技术的直觉,导致优化的效果不理想。
[0005]当前的主流自动激活值检查点搜索方案有rotor和checkmate
[0006]rotor:该搜索系统提供了一套完整的线性化网络自动激活值检查点搜索框架
[0007]系统假设;
[0008]网络线性化:该系统假设网络可以被视为一个线性化的执行序列,即前向传播中,每一个网络层只和前一层的输出有关。
[0009]内存一致性:决定保存的激活值会留存于内存之中,在被反向计算使用之前不会在中途被丢弃;
[0010]系统流程;
[0011]通过测试获得计算序列中每一层的各项开销(激活值存储带来的内存开销、计算时间)
[0012]在线性化假设之下,将内存开销进行离散化,可以将最优激活值检查点的安排建模为一个动态规划问题,在多项式时间内进行求解;
[0013]checkmate:该系统给出了任意网络情况下的激活值检查点最优解
[0014]系统假设;
[0015]网络为任意有向无环图,模糊了前向传播和反向传播的分解,整体的考虑所有计算;
[0016]系统流程;
[0017]通过测试获得计算序列中每一层的各项开销(激活值存储带来的内存开销、计算时间)
[0018]针对整个计算图以及激活值检查点序列,可以将问题建模为一个标准的整数规划问题,使用标准整数规划求解器进行求解;
[0019]然而,上述两种方法均具有一定的缺陷,例如:基于真实张量计算,需要大量时间对内存和计算时间进行估算,有时会发生内存不足的错误;对内存的估算并不是非常准确,导致最终的优化结果存在问题;
[0020]现存的线性化需要手动调整模型,使用起来并不方便;
[0021]线性化导致搜索空间缩小了很多;
[0022]通过真实张量进行节点开销测算;
[0023]整数规划导致求解时间非常长;
[0024]而自动激活值检查点搜索可以提供一个最优的策略,帮助工程师优化模型。

技术实现思路

[0025]本专利技术的目的是为了解决现有技术中存在的缺点,而提出的基于元信息估计的自动激活值检查点搜索方法。
[0026]为了实现上述目的,本专利技术采用了如下技术方案:
[0027]基于元信息估计的自动激活值检查点搜索方法,包括以下步骤:
[0028]S1、初始化模型;
[0029]S2、线性化模型的建立;
[0030]S3、元信息抽取;
[0031]S4、动态规划求解;
[0032]S5、计算图优化;
[0033]S6、执行。
[0034]优选的,所述步骤S1具体为:通过元跟踪来获取网络的计算图。
[0035]优选的,所述步骤S2具体为:根据用户标注,以及节点特性来查找图中的通用型节点,根据节点的依赖关系,来获取可以线性化网络计算图分割点。
[0036]优选的,所述步骤S3具体为:
[0037]S3.1、使用新的张量数据结构虚拟产生一个位于设备上的张量;
[0038]S3.2、注册一系列特殊的算子作为虚拟执行的工具;
[0039]S3.3、元信息抽取的内容如下,fwd_in(算子前向传播输入),fwd_tmp(算子前向传播的临时内存开销),fwd_out(算子前向传播的输出),bwd_out(算子反向传播的输出),bwd_tmp(算子反向传播的临时内存开销)分别对应动态规划求解器中需要的变量信息。
[0040]优选的,所述步骤S4具体为:以rotor算法为基础,重新修改建模,在用户输入的内存上限下,计算出在考虑张量并行通信开销的情况下的最优解。
[0041]优选的,所述步骤S5具体为:依据动态规划求解器的最优解,重构PyTorch原有计算图,生成新的计算图,使自动搜索得到的激活值检查点最终被使用。
[0042]本专利技术的有益效果为:
[0043]1.本专利技术通过设计通用节点特性的传播方式,解除了部分依赖,使得模型内可视为线性化分割点的节点增多,增加了线性化网络的链长,从而增加了搜索空间,同时,本专利技术的线性化是全自动的,不需要对模型进行改写。
[0044]2.本专利技术在初始化模型的部分进行元信息估计,利用注册的特殊算子,虚拟执行模型,并在虚拟执行的过程中抽取等同于真实执行可获得的元信息。
[0045]3.本专利技术中修改了建模,在原本的开销函数中加入了定义张量并行的通讯开销的相关变量,从而将建模推广到了多卡场景,支持了在分布式并行计算中可以使用自动激活值检查点的搜索。
附图说明
[0046]图1为本专利技术中流程图。
[0047]图2为本专利技术中通用节点传播的流程图。
[0048]图3为本专利技术中元信息抽取内容图。
具体实施方式
[0049]下面将结合本专利技术实施例中的附图,对本专利技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本专利技术一部分实施例,而不是全部的实施例。
[0050]如图1

3所示,基于元信息估计的自动激活值检查点搜索方法,包括以下步骤:
[0051]S1、初始化模型;
[0052]S2、线性化模型的建立;
[0053]S3、元信息抽取;
[0054]S4、动态规划求解;
[0055]S5、计算图优化;
[0056]S6、执行。
[0057]其中,所述步骤S1具体为:通过元跟踪来获取网络的计算图。
[0058]其中,所述步骤S2具体为:根据用户标注(例如语言模型的注意力掩码),以及节点特性来查找图中的通用型节点,根据节点的依赖关系(忽略通用型节点),来获取可以线性化网络计算图分割点,该过程包括两条流,第一条流在于传递通用节点的信息,有两个规则:
[0059]父节点均为通用节点的节点为通用节点;
[0060]部分特殊操作(例如取节点信息的节点)也为通用节点;
[0061]第二条流在于寻找节点的依赖关系,通过torch.fx产生的DAG(有向无环图),我们通过节点的生存分析(按照DAG的拓扑排序),将可把所有现存依赖(当节点的一个使用者被访问后,我们称节点关于该使用本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.基于元信息估计的自动激活值检查点搜索方法,其特征在于,包括以下步骤:S1、初始化模型;S2、线性化模型的建立;S3、元信息抽取;S4、动态规划求解;S5、计算图优化;S6、执行。2.根据权利要求1所述的基于元信息估计的自动激活值检查点搜索方法,其特征在于,所述步骤S1具体为:通过元跟踪来获取网络的计算图。3.根据权利要求2所述的基于元信息估计的自动激活值检查点搜索方法,其特征在于,所述步骤S2具体为:根据用户标注,以及节点特性来查找图中的通用型节点,根据节点的依赖关系,来获取可以线性化网络计算图分割点。4.根据权利要求5所述的基于元信息估计的自动激活值检查点搜索方法,其特征在于,所述步骤S3具体为:S3.1、使用新的张量数据结构虚拟产生一个位于设备上的张量;S3.2、注册一系列特殊的算子作为虚拟执行的工具,例如:多种不同的卷积算子,线性层算子,激活函数,批量归一化算子,以及一系列PyTorch原生...

【专利技术属性】
技术研发人员:李升桂刘育良邵彦骏姚博远方佳瑞卞正达柳泓鑫李永彬麦思琪吴俊铭陈巍文黄海晨路广阳娄宇轩
申请(专利权)人:北京潞晨科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1