一种基于深度学习的Cox模型预测疾病预后生存率方法技术

技术编号:37221456 阅读:12 留言:0更新日期:2023-04-20 23:07
本发明专利技术公开了属于计算机深度学习应用于生存分析领域的一种基于深度学习的Cox模型预测疾病预后生存率方法;该方法是将Cox模型与Transformer相结合预测疾病预后生存率,采用包括数据输入模块、数据预处理模块、预测模型构建模块、模型训练模块、模型测试模块、预测结果展示模块的Cox模型,并应用机器深度学习方法进行预测;预测模型直接对生存概率建模,对潜在的随机生存分布不做任何假设;可以拟合时间依赖效应,不依赖比例风险假设;本发明专利技术能提高预后风险预测的能力,具有神经网络连续建模事件时间的灵活性。降低预测成本。同时,Transformer的注意力机制增加预测结果的可解释性和显示特征之间的相关性。释性和显示特征之间的相关性。释性和显示特征之间的相关性。

【技术实现步骤摘要】
一种基于深度学习的Cox模型预测疾病预后生存率方法


[0001]本专利技术属于计算机深度学习应用于生存分析领域,特别涉及一种基于深度学习的Cox模型预测疾病预后生存率方法;具体说是一种基于Transformer的Cox模型用于预测疾病预后生存率的方法。

技术介绍

[0002]Transformer是目前最流行的成功应用于生存分析的深度学习模型,生存分析方法大致可分为三种类型:参数法、非参数法和半参数法。其中,参数化方法通常具有较好的分析结果。然而,它的应用范围较窄,因为样本的生存时间需要满足特定的分布类型,如Weibull分布、指数分布、伽马分布等。非参数方法更适合单变量分析,不需要生存数据满足任何特定的分布。非参数方法的代表是寿命表法和Kaplan

Merier方法。半参数法的典型算法是Cox比例风险回归模型(Cox Proportional Hazards Regression Model,CPH)。CPH是多元线性回归模型,它隐含了一项比例风险假设。
[0003]许多研究人员对CPH模型进行了扩展,以更好地预测生存时间和估计生存率。Faraggi和Simon首先提出用神经网络代替CPH模型协变量的线性组合。然而,该模型未能超过标准的Cox模型。之后,Xiang,Katzman,Luck和Yousefi等人通过使用更复杂的网络架构和训练损失改进了CPH模型。Mobadersany和Zhu用卷积神经网络取代了Katzma所提模型的多层感知机(MLP),并将Cox模型扩展到图像。此外,多任务学习、迁移学习、主动学习也被用于增强CPH模型。然而,上述基于Cox的模型仍然保持了比例风险假设(Proportional Hazard,PH),没有考虑协变量对生存状况的时变影响。
[0004]许多机器学习方法也被应用于生存分析。Chen等人使用支持向量机(SVM)来估计生存概率。随机生存森林(RSF)构造了一个用于分析右删失生存数据的集合树。BoXHE模型通过梯度提升以完全非参数的方式来估计风险函数。最近,一些新的生存模型提出使用深度学习。RNN

SURV和DRSA模型使用循环神经网络(RNN)计算患者的生存函数。DeepHit模型构建了一个深度神经网络来学习生存时间和竞争风险的分布。Nagpal等人提出了一个具有竞争风险的删失数据的参数生存回归模型。然而,这些方法并没有充分利用深度神经网络学习复杂风险函数的潜在能力。一个原因是他们只采用简单的MLP,这限制了它们内在的学习能力。
[0005]生存分析主要研究对某一特定事件产生显著影响的因素和事件发生的时间。它在医学上有广泛的应用。Cox回归模型是医学领域应用最广泛的生存方法。该模型可以同时分析多个变量对生存的影响,并且具有宽松的风险假设。然而,Cox模型隐含了许多限制性假设,这些假设在实践中难以满足。具体来说,Cox模型的风险率不随时间变化是常数,风险函数的对数是协变量的线性组合。神经网络可以在没有统计假设的情况下学习非线性和高度复杂的函数。利用深度学习和优化算法的最新发展,基于神经网络的生存模型可以学习协变量与风险之间的复杂关系,其可解释性较差,从而需要提供更有效的解决方案。
[0006]Transformer是目前最流行的深度学习模型,已经成功应用于各个领域。一些研究
者提出了基于Transformer的生存分析模型,并取得了令人满意的仿真结果。但是他们都将事件时间进行离散,并在预设的时间网格上计算风险或生存函数。与上述模型不同的是,我们的模型以Transformer为主体,结合Cox模型连续学习事件时间。据目前所知,我们首创将Transformer和Cox模型结合起来进行生存分析的。

技术实现思路

[0007]本专利技术的目的是提出一种基于深度学习的Cox模型预测疾病预后生存率方法,其特征在于,预测疾病预后生存率是采用包括:数据输入模块;数据预处理模块;预测模型构建模块;模型训练模块;模型测试模块;预测结果展示模块的的Cox模型应用机器深度学习方法进行预测,其步骤如下:
[0008]步骤1:数据输入模块,用于获取疾病预后数据;
[0009]步骤2:数据预处理模块,用于对待预测的疾病数据进行数据的预处理,同时数据按照7:3的比值分为训练样本集和测试集两组;
[0010]步骤3:预测模型构建模块,用于构建基于Transformer的Cox模型和预测模型的损失函数;
[0011]步骤4:模型训练模块,将步骤3构建好的预测模型用于处理步骤2构造的训练集;
[0012]步骤5:模型测试模块,步骤4训练好的模型作为最终应用模型,利用步骤2得到的测试样本集对训练好的模型进行测试。
[0013]步骤6:预测结果展示模块,根据步骤4中拟合好的模型,建立相应的预测模型。以患者疾病相关数据作为输入,评估生存风险预测模型的优劣,得到相应的生存率

时间曲线。
[0014]所述步骤2中,数据预处理模块,对待预测的疾病预后预处理,具体是将样本数据分为连续变量和分类变量,对分类变量进行one

hot编码,对连续变量进行归一化处理。
[0015]所述步骤3中,预测模型构建模块,用于构建基于Transformer的Cox模型和预测模型的损失函数,具体包括:
[0016]3‑
1:嵌入层:嵌入过程是通过将分类和数值协变量输入两个独立的全连接层来实现的;其中数值协变量还包括个体的事件时间t、分类协变数量D
c
和数值协变数量D
n
;对于分类协变量通过一个嵌入矩阵在d
k
维空间中表示;对于数值协变量通过一个嵌入矩阵对其进行投影;由于不处理序列数据,所以在嵌入层中不包括位置编码;事件时间t的嵌入t

直接传递到子网络层,其他协变量的嵌入在进入Transformer层之前被拼接起来;
[0017]3‑
2:Transformer层
[0018]为了进一步促进协变量之间的交互以实现高级组合嵌入,叠加了多个Transformer层;在每个Transformer层中,每两个子层之间都添加一个残差连接,然后进行层归一化;对于每个患者,Transformer层的输入如下:
[0019]x

=[x
′1,x
′2,...,x

D
],
[0020]其中D为所有输入层协变量的个数,D=D
c
+D
n

1。
[0021]首先,将Transformer的多头自注意力应用于输入x

来自动学习协变量之间的交
互:
[0022]x

=LayerNorm(MultiHeadAttention(x

)+x

).
[0023]对于h个平行的注意力本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于深度学习的Cox模型预测疾病预后生存率方法,其特征在于,预测疾病预后生存率是采用包括:数据输入模块;数据预处理模块;预测模型构建模块;模型训练模块;模型测试模块;预测结果展示模块的Cox模型应用机器深度学习方法进行预测,其步骤如下:步骤1:数据输入模块,用于获取疾病预后数据;步骤2:数据预处理模块,用于对待预测的疾病数据进行数据的预处理,同时数据按照7:3的比值分为训练样本集和测试集两组;步骤3:预测模型构建模块,用于构建基于Transformer的Cox模型和预测模型的损失函数;步骤4:模型训练模块,将步骤3构建好的预测模型用于处理步骤2构造的训练集;步骤5:模型测试模块,步骤4训练好的模型作为最终应用模型,利用步骤2得到的测试样本集对训练好的模型进行测试;步骤6:预测结果展示模块,根据步骤4中拟合好的模型,建立相应的预测模型,以患者疾病相关数据作为输入,评估生存风险预测模型的优劣,得到相应的生存率

时间曲线。2.根据权利要求1所述的基于深度学习的Cox模型预测疾病预后生存率方法,其特征在于,所述步骤2中,数据预处理模块,对待预测的疾病预后预处理,具体是将样本数据分为连续变量和分类变量,对分类变量进行one

hot编码,对连续变量进行归一化处理。3.根据权利要求1所述的基于深度学习的Cox模型预测疾病预后生存率方法,其特征在于,所述步骤3中,预测模型构建模块,用于构建基于Transformer的Cox模型和预测模型的损失函数,具体包括:3

1:嵌入层:嵌入过程是通过将分类和数值协变量输入两个独立的全连接层来实现的;其中数值协变量还包括个体的事件时间t、分类协变数量D
c
和数值协变数量D
n
;对于分类协变量通过一个嵌入矩阵在d
k
维空间中表示;对于数值协变量通过一个嵌入矩阵对其进行投影;由于不处理序列数据,所以在嵌入层中不包括位置编码;事件时间t的嵌入t

直接传递到子网络层,其他协变量的嵌入在进入Transformer层之前被拼接起来;3

2:Transformer层为了进一步促进协变量之间的交互以实现高级组合嵌入,叠加了多个Transformer层;在每个Transformer层中,每两个子层之间都添加一个残差连接,然后进行层归一化;对于每个患者,Transformer层的输入如下:x

=[x
′1,x
′2,

,x

D
],其中D为所有输入层协变量的个数,D=D
c
+D
n

1;首先,将Transformer的多头自注意力应用于输入x

来自动学习协变量之间的交互:x

=LayerNorm(MultiHeadAttention(x

)+x

).对于h个平行的注意力头,多头注意力的定义为:MultiHeadAttention=Concat(head1,head2,

,head
h
)W
o
,其中W
o
是可学习的权重矩阵,每个头计算协变量嵌入的所有元素之间的缩放点积注意力;第i个头的缩放点积注意力为:
其中Q
i
、K
i
和V
i
分别表示键矩阵、查询矩阵和值矩阵:Q
i
=x

W
iQ
,K
i
=x

W
iK
,V
i
=x

W
iV
,其中W
iQ
,W
iK
,W
iV
是可学习的权值矩阵,用于将原始嵌入变换成相同的维度;然后,将每个协变量的注意力嵌入x

j
输入到前馈神经网络(FNN)中,再进行残差连接和层归一化,x
″′
j
=LayerNorm(FNN(x

j
)+x

j
),j=1,2,

,D,对于每一层Transformer,嵌入的输出由上述公式导出;可以堆叠L层Transformer,第j个协变量的嵌入表示为:x

j

x
″′
(1,j)

x
...

【专利技术属性】
技术研发人员:滕婧张宏蕾米春琳
申请(专利权)人:华北电力大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1