时空交通预测模型的对抗训练方法、装置、设备及介质制造方法及图纸

技术编号:39248800 阅读:23 留言:0更新日期:2023-10-30 12:00
本发明专利技术公开了一种时空交通预测模型的对抗训练方法、装置、设备及介质,方法包括:获取时空交通预测模型,将所述时空交通预测模型的对抗训练过程定义为优化问题;利用训练好的策略网络,动态选取对抗训练过程中的对抗节点;通过知识蒸馏技术,利用所述对抗节点和所述优化问题对所述时空交通预测模型进行对抗训练。本发明专利技术实施例能够提高时空交通预测模型的对抗鲁棒性,在时空流量预测任务中更有效地防御动态对抗攻击。动态对抗攻击。动态对抗攻击。

【技术实现步骤摘要】
时空交通预测模型的对抗训练方法、装置、设备及介质


[0001]本专利技术涉及模型训练
,尤其涉及一种时空交通预测模型的对抗训练方法、装置、设备及介质。

技术介绍

[0002]目前各种对抗训练方法都旨在通过生成对抗样本并在训练过程中加入对抗样本,以提高模型在对抗攻击下的鲁棒性。对于时空交通数据,由于时空类数据通常具有强烈的时序依赖和空间依赖性,攻击者可能利用这种依赖性来设计针对性的对抗训练所需要的对抗攻击。然而,现有的对抗训练方法主要针对静态数据,很难直接应用于具有时序和空间依赖性的数据。因此,有必要研究一种时空交通预测模型的对抗训练方法。

技术实现思路

[0003]本专利技术实施例的目的是提供一种时空交通预测模型的对抗训练方法、装置、设备及介质,能够提高时空交通预测模型的对抗鲁棒性,在时空流量预测任务中更有效地防御动态对抗攻击。
[0004]为实现上述目的,本专利技术实施例提供了一种时空交通预测模型的对抗训练方法,包括:
[0005]获取时空交通预测模型,将所述时空交通预测模型的对抗训练过程定义为优化问题;本文档来自技高网...

【技术保护点】

【技术特征摘要】
1.一种时空交通预测模型的对抗训练方法,其特征在于,包括:获取时空交通预测模型,将所述时空交通预测模型的对抗训练过程定义为优化问题;利用训练好的策略网络,动态选取对抗训练过程中的对抗节点;通过知识蒸馏技术,利用所述对抗节点和所述优化问题对所述时空交通预测模型进行对抗训练。2.如权利要求1所述的时空交通预测模型的对抗训练方法,其特征在于,所述获取时空交通预测模型,将所述时空交通预测模型的对抗训练过程定义为优化问题,包括:获取时空交通预测模型:其中,为时空交通预测模型,是时空交通预测模型预测的t时刻到t+T时刻的交通状态,表示交通网络表示交通网络是n个节点的集合,ε是一组边,表示节点的属性,其中x
i,t
表示t时刻i节点的交通属性以及地理信息;定义对抗样本空间:其中,是时空对抗样本,Δ
t
是时空对抗性扰动,矩阵I
t
∈{0,1}
n
×
n
是对抗节点指标,其第j个对角元素表示是否节点v
j
在t时刻被选为对抗节点,若普通节点被选为对抗节点,则第j个对角线元素等于1,否则为0,表示第t时刻的时空特征,p表示p范数,∈表示对抗扰动的限制大小;将所述时空交通预测模型的对抗训练过程定义为优化问题:其中,是t

τ时刻到t时刻的对抗样本,为对抗训练的损失函数,为训练集合的时间索引集合,表示t+1:t+T时刻的预测真值,表示t

τ时刻到t时刻的对抗样本。3.如权利要求2所述的时空交通预测模型的对抗训练方法,其特征在于,所述知识蒸馏技术的损失函数为:其中,表示知识蒸馏损失函数,表示教师模型,其采用自上一时刻训练的时空交通预测模型,表示MSE损失;所述对抗训练的损失函数为:
其中,α表示控制从教师模型传递的知识量的参数。4.如权利要求1所述的时空交通预测模型的对抗训练方法,其特征在于,所述利用训练好的策略网络,动态选取对抗训练过程中的对抗节点,包括:将从时空地理分布数据源中选择最优节点子集的问题表示为s,包括n个节点,每个节点由时空特征表示,从时间t

τ到t,目标是从n个节点的集合中选择η个节点,表示为节点子集Ω=(ω1,


η
),其中ω
k
∈{v1,

,v
n
}且}且其中,p
φ
(Ω∣s)表示随机策略。5.如权利要求4所述的时空交通预测模型的对抗训练方法,其特征在于,所述策略网络包括编码器和解码器,所述编码器将时空交通数据作为输入,并生成节点嵌入表征作为输出,所述解码器将所述节点嵌入表征和第k次迭代选择的节点作为输入,生成节点序列;其中,所述将时空交通数据作为输入,并生成节点嵌入表征作为输出,包括:在空间层通过聚合邻近节点的隐藏状态,通过自适应图卷积更新隐藏层嵌入:其中,Z

l
表示第l层隐藏嵌入的输出,W
i
是深度i的模型参数,A
ada
是可学习的邻接矩阵;在时间层处理序列数据:Z

l
=tanh(θ1★
E
l
)

σ(θ2★
E
l
)其中,σ表示sigmoid函数,θ1和θ2是模型参数,

表示扩张卷积操作,

表示逐元素乘法,E
l
是l块的输入和l

1块的输出,每个块都添加残差连接:E
l+1
=Z
l
+E
l
将不同层的隐藏状态连接,并传递到多层感知器中以获得最终的节点嵌入:F=MLP(||
l=1
Z
l
)其中,F是节点嵌入的集合,F
i
是节点v
i
的嵌入表征,所有节点嵌入的平均值表示为图嵌入,其表示为所述将所述节点嵌入表征和第k次迭代选择的节点作为输入,生成节点序列,包括:上下文节点嵌入定义为:其中,F
i
是图嵌入表征,V是第一次迭代中学习到的嵌入,Uω
k
‑1是k

1步迭代
中最后一个选择节点的嵌入;更新上下文节点嵌入以获取消息信息,通过多头注意力计算新的上下文节点嵌入:U

(c)
=||
j=1
MHA
j
(q
(c)
,k
j
,v
j
)其中,U

(c)
表示...

【专利技术属性】
技术研发人员:刘帆刘浩
申请(专利权)人:广州市香港科大霍英东研究院
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1