自动激活值检查点搜索和自动张量并行搜索的融合系统技术方案

技术编号:37443165 阅读:9 留言:0更新日期:2023-05-06 09:15
本发明专利技术公开了自动激活值检查点搜索和自动张量并行搜索的融合系统,涉及深度学习技术领域,包括网络线性化、元信息抽取、自动激活值检查点求解器、自动张量并行求解器、两阶段求解器,通过使用元信息抽取为自动张量并行求解器提供所需的内存开销和运算开销信息,自动张量并行求解器根据所给的内存预算和元信息进行求解,两阶段求解器将一系列张量并行策略进行网络线性化,将线性化网络传给自动激活值检查点求解器,自动激活值检查点求解器将内存开销控制在真实内存预算之下,将自动张量并行和自动激活值检查点的策略复合起来,进一步减缓深度学习训练中遇到的内存墙问题,同时两阶段求解器也能保证解决内存问题的同时兼顾模型训练的性能。训练的性能。训练的性能。

【技术实现步骤摘要】
自动激活值检查点搜索和自动张量并行搜索的融合系统


[0001]本专利技术涉及深度学习
,具体为自动激活值检查点搜索和自动张量并行搜索的融合系统。

技术介绍

[0002]随深度学习模型逐渐变得越来越大以达到更佳的精度要求。深度学习模型的参数量已经达到百亿级别,甚至千亿级别。智源悟道2.0模型甚至有1.75万亿的参数规模。
[0003]在这一背景下,使用多种技术减缓GPU内存压力是深度学习大模型训练的关键。激活值检查点和张量并行是两种行之有效的方法,很多相关工作也给出了基于动态规划和整数规划的策略搜索来给出更好的优化策略。然而,尚未有系统将这两者有机的结合起来。
[0004]当前的主流自动激活值检查点搜索方案有rotor和checkmate
[0005]rotor:该搜索系统提供了一套完整的线性化网络自动激活值检查点搜索框架,
[0006]系统假设
[0007]网络线性化:该系统假设网络可以被视为一个线性化的执行序列,即前向传播中,每一个网络层只和前一层的输出有关。内存一致性:决定保存的激活值会留存于内存之中,在被反向计算使用之前不会在中途被丢弃。
[0008]系统流程
[0009]通过测试获得计算序列中每一层的各项开销(激活值存储带来的内存开销、计算时间)
[0010]在线性化假设之下,将内存开销进行离散化,可以将最优激活值检查点的安排建模为一个动态规划问题,在多项式时间内进行求解,论文地址:https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b

Abstract.html
[0011]checkmate:该系统给出了任意网络(有向无环图)情况下的激活值检查点最优解,系统假设
[0012]网络为任意有向无环图,模糊了前向传播和反向传播的分解,整体的考虑所有计算。
[0013]系统流程,通过测试获得计算序列中每一层的各项开销(激活值存储带来的内存开销、计算时间),针对整个计算图以及激活值检查点序列,可以将问题建模为一个标准的整数规划问题,使用标准整数规划求解器进行求解,
[0014]当前的主流自动张量并行搜索方案有Alpa
[0015]Alpa:该系统针对自动张量并行和自动流水并行进行了细致的建模,在给定网络的代价模型的前提下,使用整数规划求解最优策略
[0016]论文地址:https://arxiv.org/pdf/2201.12023.pdf
[0017]Flexflow:Flexflow是一个在SOAP空间中,为任意DNN模型和设备拓扑,根据执行模拟器(execution simulator)cost model,自动搜索最优并行策略的系统,
[0018]但是rotor:基于真实张量计算,需要大量时间对内存和计算时间进行估算,有时会发生内存不足的错误。对内存的估算并不是非常准确,导致最终的优化结果存在问题。现存的线性化需要手动调整模型,使用起来并不方便。线性化导致搜索空间缩小了很多。
[0019]checkmate:通过真实张量进行节点开销测算。整数规划导致求解时间非常长。
[0020]Alpa:没有考虑自动激活值检查点搜索,没有PyTorch的版本,对此我们提出了一种自动激活值检查点搜索和自动张量并行搜索的融合系统。

技术实现思路

[0021]针对现有技术的不足,本专利技术提供了自动激活值检查点搜索和自动张量并行搜索的融合系统,解决了上述
技术介绍
中提出的问题。
[0022]为实现以上目的,本专利技术通过以下技术方案予以实现:自动激活值检查点搜索和自动张量并行搜索的融合系统,包括网络线性化、元信息抽取、自动激活值检查点求解器、自动张量并行求解器、两阶段求解器、优化计算图,
[0023]还包括以下步骤:
[0024]S1:初始化模型,
[0025]S2:使用元信息抽取为自动张量并行求解器提供所需的内存开销和运算开销信息;
[0026]S3:自动张量并行求解器根据所给的内存预算和元信息进行求解;
[0027]S4:两阶段求解器将一系列张量并行策略进行网络线性化,将线性化网络传给自动激活值检查点求解器;
[0028]S5:自动激活值检查点求解器将内存开销控制在真实内存预算之下,得到一系列策略,从中选取估算运行时间最短的策略;
[0029]S6:根据给出的复合策略优化计算图,编译执行。
[0030]优选的,所述网络线性化包括以下步骤:
[0031]通过元跟踪来获取网络的计算图根据用户标注(例如语言模型的注意力掩码),以及节点特性来查找图中的通用型节点
[0032]根据节点的依赖关系(忽略通用型节点),来获取可以线性化网络计算图分割点
[0033]优选的,所述元信息抽取在不真实执行的情况下,通过pytorch的__torch_dispatch__机制,来获取张量在我们希望的设备上(GPU、CPU)进行运行时的底层算子,并由此估算运行时的内存开销和计算量;
[0034]使用__torch_dispatch__机制获取各个算子的底层算子;
[0035]根据获取的算子将模型的运算开销和模型的内存开销抽象成依赖于输入和输出形状的数学公式,方便自动张量并行获取算子在不同切分方式下的。
[0036]优选的,所述自动激活值检查点求解器以rotor算法为基础,重新修改了建模,考虑了张量并行的多卡通信场景,在用户输入的内存上限下,给出在考虑张量并行通信开销的情况下的最优解。
[0037]优选的,所述自动张量并行求解器以Alpa为基础在PyTorch上构建的自动张量并行,在给定元信息的情况下进行张量并行策略的求解
[0038]优选的,所述两阶段求解器包括以下步骤,在给定自动张量并行求解器和自动激
活值检查点求解器的情况下,依据用户实际可用内存预算向自动张量并行求解器输入一组虚拟内存预算,给出一系列的张量并行策略,再使用自动激活值检查点求解器给出激活值检查点策略将内存控制到真实内存预算之下,比较一系列的策略,选出估算运行时间最短的策略。
[0039]优选的,所述优化计算图依据动态规划求解器的最优解,重构PyTorch原有计算图,生成新的计算图,使自动搜索得到的激活值检查点最终被使用,从而大幅减少训练内存开销。
[0040]本专利技术提供了自动激活值检查点搜索和自动张量并行搜索的融合系统,具备以下有益效果:
[0041]1、该自动激活值检查点搜索和自动张量并行搜索的融合系统,通过使用元信息抽取为自动张量并行求解器提供所需的内存开销和运算开销信息,自动张量并行求解器根据所给的内存预算和元信息进行求解,两阶段求解器将一系列张量并行策略进行网络线性化,将线性化网络传给自动激活值检查点求解器,自动激活值检查点求解器将内存开销控制在真实内存预算之下,得到一本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.自动激活值检查点搜索和自动张量并行搜索的融合系统,其特征在于:包括网络线性化、元信息抽取、自动激活值检查点求解器、自动张量并行求解器、两阶段求解器、优化计算图,还包括以下步骤:S1:初始化模型,S2:使用元信息抽取为自动张量并行求解器提供所需的内存开销和运算开销信息;S3:自动张量并行求解器根据所给的内存预算和元信息进行求解;S4:两阶段求解器将一系列张量并行策略进行网络线性化,将线性化网络传给自动激活值检查点求解器;S5:自动激活值检查点求解器将内存开销控制在真实内存预算之下,得到一系列策略,从中选取估算运行时间最短的策略;S6:根据给出的复合策略优化计算图,编译执行。2.根据权利要求1所述的自动激活值检查点搜索和自动张量并行搜索的融合系统,其特征在于:所述网络线性化包括以下步骤:通过元跟踪来获取网络的计算图根据用户标注(例如语言模型的注意力掩码),以及节点特性来查找图中的通用型节点;根据节点的依赖关系(忽略通用型节点),来获取可以线性化网络计算图分割点。3.根据权利要求1所述的自动激活值检查点搜索和自动张量并行搜索的融合系统,其特征在于:所述元信息抽取在不真实执行的情况下,通过pytorch的__torch_dispatch__机制,来获取张量在我们希望的设备上(GPU、CPU)进行运行时的底层算子,并由此估算运行时的内存开销和计算量;使用__torch_dispatch__机制获取各个算子的底层算子...

【专利技术属性】
技术研发人员:刘育良李升桂姚博远邵彦骏方佳瑞卞正达柳泓鑫李永彬麦思琪吴俊铭陈巍文黄海晨路广阳娄宇轩
申请(专利权)人:北京潞晨科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1