一种基于GAN和多任务学习的癌症生存分析方法技术

技术编号:36216094 阅读:24 留言:0更新日期:2023-01-04 12:13
本发明专利技术属于医疗信息技术领域,尤其涉及一种基于GAN和多任务学习的癌症生存分析方法,本发明专利技术使用GAN网络进行数据增强,将出现结局的癌症患者的特征、生存时间和结局类型输入到GAN网络中进行训练,生成大量的非删失生存数据;构建基于软参数共享的多任务学习癌症生存分析模型,多个不同的任务分别预测患者在未来一段时间每个时刻出现不同结局的概率;将所需分析的癌症患者特征输入已构建好的生存分析模型,输出未来不同结局的概率。输出未来不同结局的概率。输出未来不同结局的概率。

【技术实现步骤摘要】
一种基于GAN和多任务学习的癌症生存分析方法


[0001]本专利技术属于医疗信息
,尤其涉及一种基于GAN和多任务学习的癌症生存分析方法。

技术介绍

[0002]对癌症患者的精准预后预测有利于医生优化治疗措施、改善患者预后和降低患者的疾病负担。在医学上,预后通常指的是使用患者的特征预测其在一段时间内出现结局的概率。结局往往是指死亡、复发或病情加重等。生存分析是癌症预后预测中经常使用的分析方法。生存分析的一个关键是删失数据的存在,删失表明患者在研究期间没有发生结局事件。生存分析模型不直接对患者的生存时间进行预测,而是预测患者生存时间的概率分布。
[0003]传统上经常使用Cox比例风险(CPH)来进行癌症生存分析研究。CPH有两个假设:1)比例风险假设:不同患者之间的风险比是一个定值,不会随着时间的变化而变化。2)对数线性假设:患者的特征与患者风险的对数是线性相关的。然而,真实的生存数据很难满足线性比例风险条件。近年来随着深度学习的不断发展,越来越多的学者将全连接神经网络、卷积神经网络、循环神经网络和图神经网络等结构运用在癌症生存分析研究中。除此之外,一些学者还将半监督、自监督、主动学习和多任务学习等方法应用于癌症生存分析领域。
[0004]目前,现有的癌症生存分析方法存在以下不足。第一:癌症生存分析研究中经常会有患者发生删失的情况,但现有生存分析方法无法处理高度删失的情况。第二:使用多任务学习的癌症生存分析方法都是基于硬参数共享,其主要适合任务之间联系紧密的场景。但在癌症生存分析中,不同任务之间的差异性很大,任务与任务之间甚至可能是冲突的。第三:已有生存分析模型对癌症患者的短期结局发生率预测较为准确,但对患者的长期结局发生率预测能力还有待提升。

技术实现思路

[0005]为了解决上述现有技术中存在的技术问题,本专利技术提供了一种基于GAN和多任务学习的癌症生存分析方法,拟解决目前生存分析方法不能处理高度删失的问题。
[0006]本专利技术采用的技术方案如下:
[0007]一种基于GAN和多任务学习的癌症生存分析方法,包括以下步骤:
[0008]步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
[0009]步骤2:基于训练好的Survival

GAN模型对训练集中的生存数据进行数据增强;
[0010]步骤3:搭建基于多任务学习的癌症生存分析模型,并基于增强后的训练集数据对癌症生存分析模型进行训练;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
[0011]步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
[0012]本专利技术基于Survival

GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
[0013]优选的,所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
[0014]优选的,所述步骤2包括以下步骤:
[0015]步骤2.1:根据获取的生存数据是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
[0016]步骤2.2:基于出现结局的生存数据训练Survival

GAN模型;
[0017]步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival

GAN模型的最优超参数,并用最优超参数重新训练Survival

GAN模型;
[0018]步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival

GAN模型中,生成K个出现结局的生存数据;
[0019]步骤2.5:N2自增K,即N2=N2+K,N2表示出现结局的生存数据个数;
[0020]由于每一轮的Survival

GAN模型训练后均会产生K个生存数据,因此经过一轮训练后的生存数据等于K加上输入时的生存数据;即N2=N2+K。
[0021]步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
[0022]优选的,Survival

GAN模型包括生成器和判别器;
[0023]所述生成器包括全连接网络,全连接网络的全连接层的层数和每层神经元的个数均为超参数;
[0024]所述判别器为多任务全连接网络,判别器的全连接层的层数和每层神经元的个数均为超参数;
[0025]所述判别器包括三个任务,第一个任务用于判断输入的患者特征是真的还是判别器生成的;第二个任务基于生存数据预测结局类型;第三个任务基于生存数据预测生存时间。
[0026]优选的,所述Survival

GAN模型的训练步骤如下:
[0027]设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器;
[0028]设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器;
[0029]设置其余超参数:训练轮数和batch_size,batch_size为一次训练所抓取的训练样本的数量;
[0030]数据拼接:从标准正态分布中随机获取m个噪声数据,输入的m个真实生存数据的标签经过Embedding层编码后与噪声数据进行拼接,得到数据C
i

[0031]计算生成器的总损失:
[0032]L
G
=L
G1
+L
G2
+L
G3

[0033]式中:L
G
表示生成器的总损失,L
G1
、L
G2
和L
G3
均表示损失函数;
[0034]生成器训练参数的更新:基于生成器总的损失函数以及预设的学习率对生成器的训练参数进行更新;
[0035]计算判别器的总损失:
[0036]L
D
=L
D1
+L
D2
+L
D3

[0037]式中:L
D
表示判别器的总损失,L
D1
、L
D2
和L
D3
均表示损失函数;
[0038]判别器的训练参数更新:基本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,包括以下步骤:步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;步骤2:基于训练好的Survival

GAN模型对训练集中的生存数据进行数据增强;步骤3:构建基于多任务学习的癌症生存分析模型,基于增强后的训练集数据训练癌症生存分析模型;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。2.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。3.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述步骤2包括以下步骤:步骤2.1:根据获取的生存数据是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;步骤2.2:基于出现结局的生存数据训练Survival

GAN模型;步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival

GAN模型的最优超参数,并用最优超参数重新训练Survival

GAN模型;步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival

GAN模型中,生成K个出现结局的生存数据;步骤2.5:N2自增K,即N2=N2+K,N2表示出现结局的生存数据个数;步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。4.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,Survival

GAN模型包括生成器和判别器;所述生成器包括全连接网络,全连接网络的全连接层的层数和每层神经元的个数均为超参数;所述判别器为多任务全连接网络,判别器的全连接层的层数和每层神经元的个数均为超参数;所述判别器包括三个任务,第一个任务用于判断输入的患者特征是真的还是生成器生成的;第二个任务基于生存数据预测结局类型;第三个任务基于生存数据预测生存时间。5.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述Survival

GAN模型的训练步骤如下:设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器;设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器;设置其余超参数:训练轮数和batch_size,batch_size为一次训练所抓取的训练样本
的数量;数据拼接:从标准正态分布中随机获取m个噪声数据,输入的m个真实生存数据的标签经过Embedding层编码后与噪声数据进行拼接,得到数据C
i
;计算生成器的总损失:L
G
=L
G1
+L
G2
+L
G3
;式中:L
G
表示生成器的总损失,L
G1
、L
G2
和L
G3
均表示损失函数;生成器训练参数的更新:基于生成器总的损失函数以及预设的学习率对生成器的训练参数进行更新;计算判别器的总损失:L
D
=L
D1
+L
D2
+L
D3
;式中:L
D
表示判别器的总损失,L
D1
、L
D2
和L
D3
均表示损失函数;判别器的训练参数更新:基于判别器的总损失函数以及预设的学习率对判别器的训练参数进行更新;生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。6.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络。7.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症生存分析模型的训练步骤如下所述:A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;B.预设batch_size的值为m,患者的结局类型一共有K种;每个批次的训练过程中,将m个患者的生存数据输入到癌症生存分析模型中进行训练;C.计算癌症生存分析模型的损失:癌症生存分析模型的总损失函数L
s
表示为:L
s
=λ1·
L
s1
+λ2·
L
s2
+λ3·
L
s3
+λ4·
L
s4
;式中:...

【专利技术属性】
技术研发人员:邱航阳旭菻杨萍
申请(专利权)人:电子科技大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1