当前位置: 首页 > 专利查询>北京大学专利>正文

一种基于参数分离和知识蒸馏的图持续学习方法技术

技术编号:38675034 阅读:14 留言:0更新日期:2023-09-02 22:50
本发明专利技术属于深度学习技术领域,具体公开一种基于参数分离和知识蒸馏的图持续学习方法,该方法的输入是一张原图,由节点和边构成,Trans

【技术实现步骤摘要】
一种基于参数分离和知识蒸馏的图持续学习方法


[0001]本专利技术属于深度学习
,更具体地,涉及一种基于参数分离和知识蒸馏的图持续学习方法。

技术介绍

[0002]随着现实世界的迅速发展,人类生活的环境变得越来越复杂,为了更好的探索真实世界,研究者将世界中的大量系统网络化,运用图神经网络(Graph neural netowork/GNN)。挖掘网络数据信息对生活中的许多应用场景具有重要的意义,例如:推荐系统、链路预测、节点分类、节点聚类和图分类等任务。图数据作为一种非结构化数据,其拓扑结构能够更好的描述数据之间的关系,从图数据的拓扑结构中可以提取大量的关系属性特征,如何处理、挖掘图数据中的信息正逐渐受到研究者们的重视并成为一个热门的研究方向。
[0003]然而,大多数现有的图神经网络模型都假设数据是共同可用的,这与现实世界中数据按顺序到达的情况相反。因此,基于参数分离和知识蒸馏的图持续学习(Continual garph learning/CGL)考虑在数据流中连续学习不同的任务。在CGL的情况下,GNN通常受到灾难性遗忘问题的困扰。具体来说,当GNN模型在旧任务的数据上收敛后继续训练一个新的任务时,它就会忘记旧任务的知识,导致模型在旧任务上的性能变差。图像领域的持续学习出现灾难性遗忘的主要原因是,当训练一个新任务时,模型会改变参数,从而导致旧任务的特征分布被破坏。这一问题在图上更为复杂,图节点的特征向量主要由节点本身的特征和拓扑特征组成,由于不同任务的拓扑特征相互影响,在新任务的训练过程中,旧任务的特征分布被打乱的程度更大,从而导致旧任务的灾难性遗忘更多,我们将其定义为拓扑特征引起的灾难性遗忘(TCF)。现有的CGL节点分类方法主要分为三类:基于重放的方法、基于正则化的方法、以及参数分离的方法。现有的CGL方法没有对TCF给予足够的重视。虽然他们考虑了拓扑特征,但没有很好地区分哪些拓扑特征是跨任务共享的。

技术实现思路

[0004]提供了本专利技术以解决现有技术中存在的上述问题。因此,需要一种基于参数分离和知识蒸馏的图持续学习方法,本专利技术能有效缓解图神经网络持续学习过程产生的灾难性遗忘问题,并且有效地提高了图神经网络的持续学习能力和节点分类的准确率。
[0005]根据本专利技术的第一方案,提供了一种基于参数分离和知识蒸馏的图持续学习方法,所述方法包括:
[0006]构造子图;
[0007]参数分离;
[0008]知识蒸馏;
[0009]模型输出。
[0010]进一步地,所述构造子图具体包括:
[0011]获取图G={V,E,A},其中V是n个节点的集合,E是边的集合,A是相邻矩阵;
[0012]构建模型,所述模型按顺序学习一连串不相干的任务每个任务包含一个节点集V
k
,V
t
中的每个节点v
i
对应于标签集Y
t
中的一个类标签y
i
,其中c
t
是任务t中的类别数量;
[0013]基于每个节点v
i
∈V
t
,根据相似度为每个节点构造子图,其中所有节点的子图个数均为k,每个节点子图的其他节点根据相似度或根据在原图G中的邻居来选择,得到任务t的子图集合
[0014]对于节点v
i
,其邻居节点v
k
在g
i
中的初始嵌入分别设为和其中d
e
表示维度,其中I
j
∈R
|V|
代表识别节点v
j
的one

hot向量,是一个可学习嵌入矩阵;
[0015]g
i
中所有节点的嵌入组织成一个矩阵第l层的嵌入矩阵表示为:
[0016][0017]其中Q
t
、K
t
和分别是任务t的query、key和value向量;是任务t的注意力矩阵;Ψ
t
∈R
(k+1)
×
(k+1)
是任务t的相似度矩阵;σ(
·
)是Softmax激活函数;
[0018]Q
t
、K
t
和的定义如下:
[0019][0020]其中和是投影矩阵。
[0021]进一步地,所述参数分离具体包括:
[0022]为每个任务分配可学习的二进制掩码,对于任务t,得到:
[0023][0024]其中,Q
t
、K
t
和V
t
分别为任务t的query、key和value向量;和分别是任务t的query、key和value的二进制掩码;

表示点乘。
[0025]引入一个可微分函数,得到Q
t
的可微分二进制表示为:
[0026][0027]其中I
t
∈R
T
代表一个识别当前任务t的单热向量,是一个可学习嵌入矩阵,δ是一个具有大值的变形因子;
[0028]在引入可微分的二进制掩码后,任务t的门控注意力表示为:
[0029][0030]确定任务t的相似度门控注意力A
t
的二进制掩码被定义为:
[0031][0032]其中是一个可学习的嵌入矩阵;
[0033]第l层的嵌入矩阵表示为
[0034]进一步地,所述知识蒸馏具体包括:
[0035]将用于约束参数更新的掩膜注意力正则化损失定义为:
[0036][0037]其中和表示任务t

1和t的屏蔽注意力;和分别表示和中第i行和第j列的值;共享的二进制注意掩码是和的位和运算结果,是共享拓扑参数的掩码。
[0038]为了防止蒸馏的时候,模型对新知识的获得进行抑制,本专利技术引入一个不对称的掩膜注意正则化损失,将公式(7)改写为:
[0039][0040]其中是一个不对称的函数。
[0041]进一步地,所述模型输出具体包括:
[0042]更新节点嵌入,得到第L层上的嵌入矩阵H
L
,基于所述第L层上的嵌入矩阵,得到任务t的最终输出
[0043][0044]其中是任务t的投影矩阵,是v
i
节点在L层的嵌入。
[0045]在模型训练过程中,训练损失定义为:
[0046][0047]其中CE(
·
)是交叉熵损失,α是一个加权系数。
[0048]进一步地,在模型输出之后,所述方法还包括结果的获得与验证。
[0049]进一步地,所述结果的获得与验证,具体包括:
[0050]通过AA和AF两个指标,来验证模型对于灾难性遗忘的缓解效果,其中,AA表示模型在所有任务上的节点分类的准确率,AF表示所有任务的遗忘程度,所述遗忘程度表示当模型不断持续学习的过程中本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于参数分离和知识蒸馏的图持续学习方法,其特征在于,所述方法包括:构造子图;参数分离;知识蒸馏;模型输出。2.根据权利要求1所述的方法,其特征在于,所述构造子图具体包括:获取图G={V,E,A
adj
},其中V是n个节点的集合,E是边的集合,A
adj
是相邻矩阵;构建模型,所述模型按顺序学习一连串不相干的任务每个任务包含一个节点集V
t
,V
t
中的每个节点v
i
对应于标签集Y
t
中的一个类标签y
i
,其中c
t
是任务t中的类别数量;基于每个节点v
i
∈V
t
,根据相似度为每个节点构造子图,其中所有节点的子图个数均为k,每个节点子图的其他节点根据相似度或根据在原图G中的邻居来选择,得到任务t的子图集合对于节点v
i
,其邻居节点v
j
在g
i
中的初始嵌入分别设为和其中d
e
表示维度,其中I
j
∈R
|V|
代表识别节点v
j
的one

hot向量,是一个可学习嵌入矩阵;gi中所有节点的嵌入组织成一个矩阵第l层的嵌入矩阵表示为:其中Q
t
、K
t
和分别是任务t的query、key和value向量;是任务t的注意力矩阵;是任务t的相似度矩阵;σ(
·
)是Softmax激活函数;Q
t
、K
t
和的定义如下:其中和是投影矩阵。3.根据权利要求2所述的方法,其特征在于,所述参数分离具体包括:为每个任务分配可学习的二进制掩码,对于任务t,得到:
其中,Q
t
、K
t
和V
t
分别为任务t的query、key和value向量;和分别是任务t的query、key和value的二进制掩码;

表示点乘。引入一个可微分函数,得到Q
t
的可微分二进制表示为:其中I
t
∈R
T
代表一个识别当前任务t的单热向量,是一个可学习嵌入矩阵,δ是一个具有大值的变形因子;在引入可微分的二进制掩码后,任务t的门控注意力表示为:确定任务t的相似度门控注意力A
t
的二进制掩码被定义为:其中是一个可学习的嵌入矩阵;第l层的嵌入矩阵表示为4.根据权利要求3所述的方法,其特征在于,所述知识蒸馏具体包括:将用于约束参数更新的掩膜注意力正则化损失定义为:其中和表示任务t

1和t的屏蔽注意力;和分别表示和中第i行和第j列的值;共享的二进制注意掩码是和的位和运算结果,是共享拓扑参数的掩码。引入一...

【专利技术属性】
技术研发人员:吕肖庆林鸿翔贾瑞琪赵堉萌
申请(专利权)人:北京大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1