模型训练方法和装置制造方法及图纸

技术编号:39715773 阅读:6 留言:0更新日期:2023-12-17 23:23
本发明专利技术公开了模型训练方法和装置

【技术实现步骤摘要】
模型训练方法和装置、轨迹预测方法和装置


[0001]本专利技术涉及自动驾驶
,尤其涉及一种模型训练方法和装置

轨迹预测方法和装置


技术介绍

[0002]在自动驾驶车辆行驶过程中,需要对障碍物等智能体的轨迹进行预测,以避免车辆与障碍物发生碰撞,保证车辆的行驶安全

[0003]现有方法通常基于单一智能体的历史轨迹,拟合得到其未来轨迹

[0004]但是,该方法忽略了不同智能体之间的相互影响,预测准确度较低


技术实现思路

[0005]有鉴于此,本专利技术实施例提供一种模型训练方法和装置

轨迹预测方法和装置,能够提高轨迹预测准确度

[0006]第一方面,本专利技术实施例提供了一种模型训练方法,包括:获取智能体的历史轨迹和与所述智能体相关的地图线数据;分别将所述智能体的历史轨迹和所述地图线数据转换至相对坐标系下;其中,所述智能体包括预测智能体和其他智能体,所述其他智能体包括自车,所述相对坐标系的原点位于所述自车的轨迹上;对经过转换的预测智能体的历史轨迹中多个连续帧进行遮挡,得到所述预测智能体的目标轨迹;分别按照时间维度和空间维度,对多个预测智能体的目标轨迹进行聚合,得到预测智能体的时间聚合轨迹和空间聚合轨迹;分别按照时间维度和空间维度,对经过转换的多个其他智能体的历史轨迹进行聚合,得到其他智能体的时间聚合轨迹和空间聚合轨迹;基于多个预测智能体的目标轨迹,分别按照时间维度和空间维度,构建预测智能体的时间
mask
和空间
mask
;基于经过转换的多个其他智能体的历史轨迹,分别按照时间维度和空间维度,构建其他智能体的时间
mask
和空间
mask
;将多个预测智能体当前所在的位置,所述智能体的时间聚合轨迹

空间聚合轨迹

时间
mask
和空间
mask
,以及经过转换的地图线数据输入神经网络模型,得到一阶差分结果;其中,所述神经网络模型基于多头注意力机制,学习轨迹的时间特征和空间特征之间的关系;基于所述一阶差分结果和所述预测智能体当前所在的位置,计算所述预测智能体的预测轨迹;基于所述预测智能体的预测轨迹和历史轨迹,计算损失函数的损失值;基于所述损失值调整所述神经网络模型的参数

[0007]第二方面,本专利技术实施例提供了一种轨迹预测方法,包括:获取智能体的历史轨迹和与所述智能体相关的地图线数据;分别将所述智能体的历史轨迹和所述地图线数据转换至相对坐标系下;其中,所述智能体包括预测智能体和其他智能体,所述其他智能体包括自车,所述相对坐标系的原点位于所述自车的轨迹上;对经过转换的预测智能体的历史轨迹中多个连续帧进行遮挡,得到所述预测智能体的目标轨迹;分别按照时间维度和空间维度,对多个预测智能体的目标轨迹进行聚合,得到预测智能体的时间聚合轨迹和空间聚合轨迹;分别按照时间维度和空间维度,对经过转换的多个其他智能体的历史轨迹进行聚合,得到其他智能体的时间聚合轨迹和空间聚合轨迹;基于多个预测智能体的目标轨迹,分别按照时间维度和空间维度,构建预测智能体的时间
mask
和空间
mask
;基于经过转换的多个其他智能体的历史轨迹,分别按照时间维度和空间维度,构建其他智能体的时间
mask
和空间
mask
;将多个预测智能体当前所在的位置,所述智能体的时间聚合轨迹

空间聚合轨迹

时间
mask
和空间
mask
,以及经过转换的地图线数据输入神经网络模型,得到一阶差分结果;其中,所述神经网络模型基于多头注意力机制,学习轨迹的时间特征和空间特征之间的关系;基于所述一阶差分结果和所述预测智能体当前所在的位置,计算所述预测智能体的预测轨迹;基于所述预测智能体的预测轨迹和历史轨迹,计算损失函数的损失值;基于所述损失值调整所述神经网络模型的参数;基于训练好的所述神经网络模型预测当前智能体的轨迹

[0008]第三方面,本专利技术实施例提供了一种模型训练装置,包括:转换模块,配置为获取智能体的历史轨迹和与所述智能体相关的地图线数据;分别将所述智能体的历史轨迹和所述地图线数据转换至相对坐标系下;其中,所述智能体包括预测智能体和其他智能体,所述其他智能体包括自车,所述相对坐标系的原点位于所述自车的轨迹上;遮挡模块,配置为对经过转换的预测智能体的历史轨迹中多个连续帧进行遮挡,得到所述预测智能体的目标轨迹;聚合模块,配置为分别按照时间维度和空间维度,对多个预测智能体的目标轨迹进行聚合,得到预测智能体的时间聚合轨迹和空间聚合轨迹;分别按照时间维度和空间维度,对经过转换的多个其他智能体的历史轨迹进行聚合,得到其他智能体的时间聚合轨迹和空间聚合轨迹;构建模块,配置为基于多个预测智能体的目标轨迹,分别按照时间维度和空间维度,构建预测智能体的时间
mask
和空间
mask
;基于经过转换的多个其他智能体的历史轨迹,分别按照时间维度和空间维度,构建其他智能体的时间
mask
和空间
mask
;训练模块,配置为将多个预测智能体当前所在的位置,所述智能体的时间聚合轨


空间聚合轨迹

时间
mask
和空间
mask
,以及经过转换的地图线数据输入神经网络模型,得到一阶差分结果;其中,所述神经网络模型基于多头注意力机制,学习轨迹的时间特征和空间特征之间的关系;基于所述一阶差分结果和所述预测智能体当前所在的位置,计算所述预测智能体的预测轨迹;基于所述预测智能体的预测轨迹和历史轨迹,计算损失函数的损失值;基于所述损失值调整所述神经网络模型的参数

[0009]第四方面,本专利技术实施例提供了一种轨迹预测装置,包括:转换模块,配置为获取智能体的历史轨迹和与所述智能体相关的地图线数据;分别将所述智能体的历史轨迹和所述地图线数据转换至相对坐标系下;其中,所述智能体包括预测智能体和其他智能体,所述其他智能体包括自车,所述相对坐标系的原点位于所述自车的轨迹上;遮挡模块,配置为对经过转换的预测智能体的历史轨迹中多个连续帧进行遮挡,得到所述预测智能体的目标轨迹;聚合模块,配置为分别按照时间维度和空间维度,对多个预测智能体的目标轨迹进行聚合,得到预测智能体的时间聚合轨迹和空间聚合轨迹;分别按照时间维度和空间维度,对经过转换的多个其他智能体的历史轨迹进行聚合,得到其他智能体的时间聚合轨迹和空间聚合轨迹;构建模块,配置为基于多个预测智能体的目标轨迹,分别按照时间维度和空间维度,构建预测智能体的时间
mask
和空间
mask
;基于经过转换的多个其他智能体的历史轨迹,分别按照时间维度和空间维度,构建其他智能体的时间<本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.
一种模型训练方法,其特征在于,包括:获取智能体的历史轨迹和与所述智能体相关的地图线数据;分别将所述智能体的历史轨迹和所述地图线数据转换至相对坐标系下;其中,所述智能体包括预测智能体和其他智能体,所述其他智能体包括自车,所述相对坐标系的原点位于所述自车的轨迹上;对经过转换的预测智能体的历史轨迹中多个连续帧进行遮挡,得到所述预测智能体的目标轨迹;分别按照时间维度和空间维度,对多个预测智能体的目标轨迹进行聚合,得到预测智能体的时间聚合轨迹和空间聚合轨迹;分别按照时间维度和空间维度,对经过转换的多个其他智能体的历史轨迹进行聚合,得到其他智能体的时间聚合轨迹和空间聚合轨迹;基于多个预测智能体的目标轨迹,分别按照时间维度和空间维度,构建预测智能体的时间
mask
和空间
mask
;基于经过转换的多个其他智能体的历史轨迹,分别按照时间维度和空间维度,构建其他智能体的时间
mask
和空间
mask
;将多个预测智能体当前所在的位置,所述智能体的时间聚合轨迹

空间聚合轨迹

时间
mask
和空间
mask
,以及经过转换的地图线数据输入神经网络模型,得到一阶差分结果;其中,所述神经网络模型基于多头注意力机制,学习轨迹的时间特征和空间特征之间的关系;基于所述一阶差分结果和所述预测智能体当前所在的位置,计算所述预测智能体的预测轨迹;基于所述预测智能体的预测轨迹和历史轨迹,计算损失函数的损失值;基于所述损失值调整所述神经网络模型的参数
。2.
如权利要求1所述的方法,其特征在于,所述相对坐标系的原点,为所述自车的历史轨迹中,按照时间由远及近的顺序,最后一帧轨迹点;所述相对坐标系的
X
轴,为所述自车的历史轨迹中,按照时间由远及近的顺序,最后一帧轨迹点指向第一帧轨迹点的朝向;所述相对坐标系的
Y
轴,为当左手食指指向所述自车的朝向时,左手拇指的朝向
。3.
如权利要求1所述的方法,其特征在于,所述神经网络模型包括:矢量网络模块

时空交互模块

智能体交互模块和解码模块;所述矢量网络模块,用于从经过转换的地图线数据中提取地图特征,针对预测智能体和其他智能体分别执行:基于时间聚合轨迹和时间
mask
确定时间聚合特征,基于空间聚合轨迹和空间
mask
确定空间聚合特征;所述时空交互模块,用于针对预测智能体和其他智能体分别执行:基于时间聚合特征和空间聚合特征,确定时间与空间交互特征;所述智能体交互模块,用于基于地图特征

预测智能体的时间与空间交互特征以及其他智能体的时间与空间交互特征,确定预测智能体与地图的第一交互特征

其他智能体与地图的第一交互特征以及其他智能体与地图的第二交互特征;所述解码模块,用于基于预测智能体当前所在的位置

预测智能体与地图的第一交互
特征

其他智能体与地图的第一交互特征以及其他智能体与地图的第二交互特征,解码得到所述一阶差分结果
。4.
如权利要求3所述的方法,其特征在于,所述矢量网络模块,包括:第一空间矢量网络

第二空间矢量网络

第三空间矢量网络

第一时间矢量网络和第二时间矢量网络;所述第一空间矢量网络,用于基于预测智能体的空间聚合轨迹和空间
mask
,确定预测智能体的空间聚合特征;所述第二空间矢量网络,用于基于其他智能体的空间聚合轨迹和空间
mask
,确定其他智能体的空间聚合特征;所述第三空间矢量网络,用于从经过转换的地图线数据中提取地图特征;所述第一时间矢量网络,用于基于预测智能体的时间聚合轨迹和时间
mask
,确定预测智能体的时间聚合特征;所述第二时间矢量网络,用于基于其他智能体的时间聚合轨迹和时间
mask
,确定其他智能体的时间聚合特征;所述矢量网络模块中的各个矢量网络,均包括:输入层
、embedding


最大池化层

注意力模块

归一化层和输出层;其中,所述注意力模块基于多头注意力机制提取特征
。5.
如权利要求4所述的方法,其特征在于,所述时空交互模块,包括:第一时空交互网络和第二时空交互网络;所述第一时空交互网络以预测智能体的空间聚合特征为
query
,以预测智能体的时间聚合特征为
key

value
,采用带掩码的多头注意力机制进行特征提取,得到第一时空交互特征;所述第二时空互网络以其他智能体的空间聚合特征为
query
,以其他智能体的时间聚合特征为
key

value
,采用带掩码的多头注意力机制进行特征提取,得到第二时空交互特征
。6.
如权利要求5所述的方法,其特征在于,所述智能体交互模块,包括:预测智能体与地图的第一交互网络

预测智能体与其他智能体的第一交互网络

其他智能体与地图第一交互网络以及其他智能体与地图的第二交互网络;所述其他智能体与地图的第二交互网络以所述地图特征为
query
,以所述第二时空交互特征为
key

value
,采用带掩码的多头注意力机制进行特征提取,得到所述其他智能体与地图的第二交互特征;所述其他智能体与地图的第一交互网络以所述第二时空交互特征为
query
,以所述其他智能体与地图的第二交互特征为
key

value
,采用带掩码的多头注意力机制进行特征提取,得到所述其他智能体与地图的第一交互特征;所述预测智能体与其他智能体的第一交互网络以所述第一时空交互特征为
query
,以所述其他智能体与地图的第一交互特征为
key

value
,采用带掩码的多头注意力机制进行特征提取,得到预测智能体与其他智能体的第一交互特征;所述预测智能体与地图的第一交互网络以所述预测智能体与其他智能体的第一交互
特征为
query
,以所述其他智能体与地图的第二交互特征为
key

value
,采用带掩码的多头注意力机制进行特征提取,得到所述预测智能体与地图的第一交互特征
。7.
如权利要求6所述的方法,其特征在于,所述解码模块,包括:
embedding


第四空间矢量网络

历史与未来交互网络

预测智能体与其他智能体的第二交互网络

预测智能体与地图的第二交互网络和全连接层;所述<...

【专利技术属性】
技术研发人员:陈昌浩李勇强吕强苗乾坤
申请(专利权)人:新石器慧通北京科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1