一种基于深度嵌入聚类元学习的流行病预测方法技术

技术编号:35353642 阅读:35 留言:0更新日期:2022-10-26 12:26
本发明专利技术公开了一种基于深度嵌入聚类元学习的流行病预测方法,包括以下步骤:S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。对流行病发展进行预测。对流行病发展进行预测。

【技术实现步骤摘要】
一种基于深度嵌入聚类元学习的流行病预测方法


[0001]本专利技术涉及流行病预测
,更具体的说是涉及一种基于深度嵌入聚类元学习的流行病预测方法。

技术介绍

[0002]目前,用于预测流感或其他时间序列数据的机器/深度学习主要分为两类。首先,一些研究人员专注于寻找有效的“特征”。例如,搜索引擎查询数据用于预测Google Flu Trends1中的流感。Twitter数据也用于其他研究论文。然而,这些模型通常受到来自互联网搜索等大量信息的不可靠来源的困扰。例如,谷歌的算法很容易过度拟合与流感无关的季节性术语,比如“高中篮球”。这个例子也证明了模型可解释性的重要性。其次,其他研究人员专注于寻找有效的“模型”,例如RF、Gradient Boosting、Multilayer Perceptron(MLP)、长短期记忆(LSTM)、变压器(TFR)等。基于深度学习的方法,例如Transformer因其准确性而受到更多关注,而它们中的大多数都因可解释性差而受苦。此外,统计模型和动态分析模型被认为是用于模拟流感感染模式的易于访问的工具,例如SI、SIS、SIR模型及其变体。然而,它们的参数会发生变化,并且参数的近似是困难的,例如基本再生数R0、人口流动性等。DEFSI将深度神经网络方法与因果模型相结合,以解决高分辨率ILI发病率预测。然而,这些模型中的大多数都严重依赖外部数据来提高准确性,例如经度和纬度以及气候信息
[0003]因此,提供一种基于深度嵌入聚类元学习的流行病预测方法,基于历史数据,针对疫情新爆发地区,利用少量初期数据,预测未来疫情发展情况是本领域技术人员亟需解决的问题。

技术实现思路

[0004]有鉴于此,本专利技术提供了一种基于深度嵌入聚类元学习的流行病预测方法;利用多个地区疫情传播的时间序列片段学习细粒度的传播模式,并可将学习到的传播模式用于新爆发疫情且仅存在少量历史数据地区的未来预测,仅需要很少的领域知识去构建元学习任务,并具有很好的可解释性;采用基于MAML的无监督元学习方法,将疾病传播模型从疫情传播稳定的地区迁移到疫情处于早期阶段的另一个地区。
[0005]为了实现上述目的,本专利技术采用如下技术方案:
[0006]一种基于深度嵌入聚类元学习的流行病预测方法,包括以下步骤:
[0007]S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;
[0008]S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;
[0009]S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;
[0010]S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。
[0011]优选的,所述步骤S1具体包括:
[0012]获取目标地区i长度为T的已知历史时间序列信息x
i
,将时间序列信息x
i
切分为多个长度为ω+ΔT的时间序列片段集合
[0013][0014]其中,M为地区的数量,T
i
为地区i的历史时序数据总长度,为地区i在时刻t的时间序列片段,为时间序列片段在t时刻前的ω个数据,即历史片段部分,其与目标地区i的已知观测数据对齐,为时间序列片段在t时刻后的ΔT个数据,即未来片段部分,与待预测数据对齐。
[0015]优选的,所述步骤S2具体包括:
[0016]S21、分别对历史片段部分和未来片段部分进行标准化:
[0017][0018][0019]其中,分别为时间序列片段的历史片段部分和未来片段部分的均值,分别为时间序列片段的历史片段部分和未来片段部分的方差,将时间序列片段标准化到0和1之间;
[0020]S22、对于时间序列片段基于CNN和RNN提取其序列局部特征和时序特征,时间序列片段中的历史片段部分对应已知数据的特征所在,因此时间序列的片段的嵌入表示仅从该部分特征中学习,将时间序列片段集合示仅从该部分特征中学习,将时间序列片段集合投影到嵌入空间Z中,生成时间序列片段的特征集合
[0021][0022]其中,ξ(
·
)为特征编码器,其由CNN和RNN两部分组成为CNN特征提取操作,用于提取时间序列片段的局部特征,为RNN特征提取操作,用于提取时间序列片段的时序特征,θ
c
,θ
r
分别为CNN模型参数和RNN模型参数。
[0023]优选的,所述步骤S3具体包括:
[0024]S31、对时间序列片段进行聚类,并学习他们的嵌入,基于深度聚类模型IDEC,采
用聚类损失来实现对给定输入进行聚类:
[0025][0026]其中,q
ij
表示由学生t分布测量的时间序列片段z
i
与聚类中心μ
j
的相似度,p
ij
是聚类的目标分布;
[0027]按时间序列片段特征集合进行聚类,得到时间序列片段数据集合的一个划分每个聚类都是多个时间序列片段特征的集合,聚类操作定义为:
[0028][0029][0030]其中,l为所有类别的总数,P
i
为第i个聚类簇,|P
i
|表示第i个聚类簇中元素的个数,z为P
i
中的元素,为l个类别的中心点,||
·
||为二范数;
[0031]S32、采样p个聚类构建元训练任务集M
train
={D1,D2,

,D
p
}表示为p种传播模式,每个聚类D
i
分为Query
i
和Support
i
两部分,并对应一个预测任务其中,Support
i
用于任务的学习适应,即用于基础学习器更新,Query
i
用于更新元学习器参数;
[0032]采用最小均方误差作为预测损失:
[0033][0034]其中,y为真实流行病确诊病例数,为模型预测结果。
[0035]基学习器学习阶段,每个任务对应一个基学习器,基于Support
i
数据,基学习器计算损失利用梯度下降最小化损失,找到使损失最小化的最优参数集:
[0036][0037]其中,θ'
i
为任务i的最优参数,θ为模型初始参数,α为超参数,为任务i的梯度;
[0038]元学习阶段,使用Query
i
数据,基于基学习器学到的最优参数θ'
i
,元学习器计算相对于这些最优参数θ'
i
的梯度,更新随机初始化的参数θ,即元知识本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,包括以下步骤:S1、获取历史数据,将历史数据切分为与目标地区数据长度相匹配的多个时间序列片段,每个时间序列片段均包括历史片段部分和未来片段部分;S2、针对每个时间序列片段,分别对其历史片段部分和未来片段部分进行标准化,并获取时间序列片段的特征集合;S3、基于无监督聚类模型对时间序列片段进行聚类,获得多个类,采样p个类构造元训练集,并获取元知识,基于元知识对新任务模型参数初始化,并通过元训练集对初始化后的新任务模型进行训练;S4、获取预测模型,初始化参数,通过多步梯度下降进行适应优化,进而针对元测试集中的新任务,对流行病发展进行预测。2.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S1具体包括:获取目标地区i长度为T的已知历史时间序列信息x
i
,将时间序列信息x
i
切分为多个长度为ω+ΔT的时间序列片段集合度为ω+ΔT的时间序列片段集合其中,M为地区的数量,T
i
为地区i的历史时序数据总长度,为地区i在时刻t的时间序列片段,为时间序列片段在t时刻前的ω个数据,即历史片段部分,其与目标地区i的已知观测数据对齐,为时间序列片段在t时刻后的ΔT个数据,即未来片段部分,与待预测数据对齐。3.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S2具体包括:S21、分别对历史片段部分和未来片段部分进行标准化:进行标准化:其中,分别为时间序列片段的历史片段部分和未来片段部分的均值,分别为时间序列片段的历史片段部分和未来片段部分的方差,将时间序列片段标准化到0和1之间;S22、对于时间序列片段基于CNN和RNN提取其序列局部特征和时序特征,时间序列片段中的历史片段部分对应已知数据的特征所在,因此时间序列的片段的嵌入表示仅从
该部分特征中学习,将时间序列片段集合该部分特征中学习,将时间序列片段集合投影到嵌入空间Z中,生成时间序列片段的特征集合的特征集合其中,ξ(
·
)为特征编码器,其由CNN和RNN两部分组成为CNN特征提取操作,用于提取时间序列片段的局部特征,为RNN特征提取操作,用于提取时间序列片段的时序特征,θ
c
,θ
r
分别为CNN模型参数和RNN模型参数。4.根据权利要求1所述的一种基于深度嵌入聚类元学习的流行病预测方法,其特征在于,所述步骤S3具体包括:S31、对时间序列片段进行聚类,并学习他们的嵌入,基于深度聚类模型IDEC,采用聚类损失来实现对给定输入进行聚类:其中,q
ij
表示由学生t分布测量的时间序列片段z
...

【专利技术属性】
技术研发人员:赵学臣张中苗金凤杨福强
申请(专利权)人:山东女子学院
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1