一种基于迁移学习的时序数据预测方法、装置及存储介质制造方法及图纸

技术编号:32819112 阅读:25 留言:0更新日期:2022-03-26 20:16
本申请公开一种基于迁移学习的时序数据预测方法、装置及存储介质,涉及机器学习技术领域和数据处理技术领域,能够提高迁移后的模型的训练效率和时序数据预测的准确度。该方法包括:获取目标域数据;目标域数据是时序数据,目标域数据包括至少一个带标签的目标域数据和多个未带标签的目标域数据;将源域特征提取模型的网络参数,迁移到待训练的目标域特征提取模型;源域特征提取模型包括:特征提取层和元学习层;根据至少一个带标签的目标域数据,对待训练的目标域特征提取模型中的元学习层进行训练,得到目标域特征提取模型;根据多个未带标签的目标域数据、目标域特征提取模型和标签预测模型,进行时序数据预测。进行时序数据预测。进行时序数据预测。

【技术实现步骤摘要】
一种基于迁移学习的时序数据预测方法、装置及存储介质


[0001]本申请涉及机器学习
和数据处理
,尤其涉及一种基于迁移学习的时序数据预测方法、装置及存储介质。

技术介绍

[0002]目前,基于时序数据的机器学习问题受到越来越多的关注。在时序数据的预测中,存在这样一个难题:若目标场景下的标注数据较少,且难于获取,那么如何基于目标场景下少量的标注数据,来进行模型训练。
[0003]现有技术常使用迁移学习的方法来解决上述问题,通过利用相似场景下训练好的模型,进行模型迁移和模型微调等,来加速目标场景下的模型的训练。但是,现有技术提供的方法存在以下不足:在对迁移后的模型进行模型微调时,往往需要对整个模型中的网络参数进行调整,加大了模型训练过程的复杂度,导致模型训练效率较低。

技术实现思路

[0004]本申请提供一种基于迁移学习的时序数据预测方法、装置及存储介质,能够提高迁移后的模型的训练效率和时序数据预测的准确度。
[0005]第一方面,本申请提供一种基于迁移学习的时序数据预测方法,包括:获取目标域数据;其中,目标域数据是时序数据,目标域数据包括至少一个带标签的目标域数据和多个未带标签的目标域数据;将源域特征提取模型的网络参数,迁移到待训练的目标域特征提取模型;其中,源域特征提取模型与目标域特征提取模型的网络结构相同,源域特征提取模型包括:特征提取层和元学习层,特征提取层用于提取时序数据的特征,元学习层用于学习特征提取层提取特征的能力;根据至少一个带标签的目标域数据,对待训练的目标域特征提取模型中的元学习层进行训练,得到目标域特征提取模型;根据多个未带标签的目标域数据、目标域特征提取模型和标签预测模型,进行时序数据预测。
[0006]基于本申请提供的技术方案,至少可以产生以下有益效果:通过构建包括元学习层的源域特征提取模型,经过模型迁移和模型训练(即对迁移后的模型进行模型微调),得到目标域特征提取模型,然后根据目标域特征提取模型和标签预测模型,对目标域中未带标签的时序数据进行标签预测。由于本申请实施例提供的元学习层能够学习源域特征提取模型中特征提取层提取特征的能力,因此,在进行模型训练时,只训练元学习层的网络参数即可,相比现有技术中需要对整个模型中的网络参数进行调整的技术方案,能够减少迁移后的模型所需调整的网络参数的数量,让迁移后的模型在目标域上快速收敛,提高了迁移后的模型的训练效率。此外,通过对元学习层进行训练和优化,可以使元学习层提取的特征比特征提取层提取的特征更加准确,能够提高时序数据预测的准确度。
[0007]可选的,上述元学习层的网络结构为残差网络,该残差网络中包括卷积层;上述根据至少一个带标签的目标域数据,对待训练的目标域特征提取模型中的元学习层进行训练,包括:根据至少一个带标签的目标域数据,对元学习层中卷积层的卷积参数进行训练。
[0008]可选的,上述目标域特征提取模型中除元学习层以外的其他网络层的网络参数,与源域特征提取模型中的其他网络层的网络参数相同。
[0009]可选的,待训练的源域特征提取模型还包括线性Linear层;Linear层用于在模型训练的过程中进行标签预测;上述方法还包括:获取源域数据;源域数据是时序数据;源域数据包括至少一个带标签的源域数据;将带标签的源域数据输入特征提取层,得到带标签的源域数据的第一特征;将带标签的源域数据的第一特征输入元学习层,得到带标签的源域数据的第二特征;将带标签的源域数据的第二特征输入Linear层,得到带标签的源域数据的预测标签;根据带标签的源域数据的预测标签和带标签的源域数据的真实标签,确定模型的损失值;根据模型的损失值进行模型优化,得到源域特征提取模型。
[0010]可选的,上述根据多个未带标签的目标域数据、目标域特征提取模型和标签预测模型,进行时序数据预测,包括:将多个未带标签的目标域数据输入目标域特征提取模型,得到多个未带标签的目标域数据的特征;将多个未带标签的目标域数据的特征,输入标签预测模型,预测出多个未带标签的目标域数据的标签。
[0011]可选的,上述标签预测模型为机器学习模型。由于现有技术中,通常在特征提取模型之后添加复杂的神经网络层或者单层的全连接层来进行标签预测,即将特征提取层提取的特征输入复杂的神经网络层或者单层的全连接层,得到预测标签。这种情况,在目标域的数据量较多的情况下,可以通过建立大量的训练样本进行模型训练,得到较为准确的标签预测模型;但是,在目标域的数据量较少的情况下,训练样本较少,会使得复杂神经网络产生过拟合现象,或者,会造成单层的全连接层的预测结果不准确。因此,本申请提供的标签预测模型采用机器学习模型,利用了机器学习模型仅根据少量训练集,即可完成模型训练的特点,使得在目标域数据量较少的情况下,通过少量的训练样本,训练得到相比于现有技术提供的复杂的神经网络层或者单层的全连接层来说,更加准确、泛化能力更强的标签预测模型。
[0012]第二方面,本申请提供一种基于迁移学习的时序数据预测装置,包括:获取模块,用于获取目标域数据;其中,目标域数据是时序数据,目标域数据包括至少一个带标签的目标域数据和多个未带标签的目标域数据;迁移模块,用于将源域特征提取模型的网络参数,迁移到待训练的目标域特征提取模型;其中,源域特征提取模型与目标域特征提取模型的网络结构相同,源域特征提取模型包括:特征提取层和元学习层,特征提取层用于提取时序数据的特征,元学习层用于学习特征提取层提取特征的能力;训练模块,用于根据至少一个带标签的目标域数据,对待训练的目标域特征提取模型中的元学习层进行训练,得到目标域特征提取模型;预测模块,用于根据多个未带标签的目标域数据、目标域特征提取模型和标签预测模型,进行时序数据预测。
[0013]可选的,上述元学习层的网络结构为残差网络,该残差网络中包括卷积层;上述训练模块,具体用于根据至少一个带标签的目标域数据,对元学习层中卷积层的卷积参数进行训练。
[0014]可选的,上述目标域特征提取模型中除元学习层以外的其他网络层的网络参数,与源域特征提取模型中的其他网络层的网络参数相同。
[0015]可选的,待训练的源域特征提取模型还包括线性Linear层;Linear层用于在模型训练的过程中进行标签预测;上述获取模块,还用于获取源域数据;源域数据是时序数据;
源域数据包括至少一个带标签的源域数据;上述训练模块,还用于将带标签的源域数据输入特征提取层,得到带标签的源域数据的第一特征;将带标签的源域数据的第一特征输入元学习层,得到带标签的源域数据的第二特征;将带标签的源域数据的第二特征输入Linear层,得到带标签的源域数据的预测标签;根据带标签的源域数据的预测标签和带标签的源域数据的真实标签,确定模型的损失值;根据模型的损失值进行模型优化,得到源域特征提取模型。
[0016]可选的,上述预测模块,具体用于将多个未带标签的目标域数据输入目标域特征提取模型,得到多个未带标签的目标域数据的特征;将多个未带标签的目标域数据的特征,输入标签预本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于迁移学习的时序数据预测方法,其特征在于,包括:获取目标域数据;其中,所述目标域数据是时序数据,所述目标域数据包括至少一个带标签的目标域数据和多个未带标签的目标域数据;将源域特征提取模型的网络参数,迁移到待训练的目标域特征提取模型;其中,所述源域特征提取模型与所述目标域特征提取模型的网络结构相同,所述源域特征提取模型包括:特征提取层和元学习层,所述特征提取层用于提取特征,所述元学习层用于学习所述特征提取层提取特征的能力;根据所述至少一个带标签的目标域数据,对所述待训练的目标域特征提取模型中的所述元学习层进行训练,得到所述目标域特征提取模型;根据所述多个未带标签的目标域数据、所述目标域特征提取模型和标签预测模型,进行时序数据预测。2.根据权利要求1所述的方法,其特征在于,所述元学习层的网络结构为残差网络,所述残差网络中包括卷积层;所述根据所述至少一个带标签的目标域数据,对所述待训练的目标域特征提取模型中的所述元学习层进行训练,包括:根据所述至少一个带标签的目标域数据,对所述元学习层中卷积层的卷积参数进行训练。3.根据权利要求2所述的方法,其特征在于,所述目标域特征提取模型中除所述元学习层以外的其他网络层的网络参数,与所述源域特征提取模型中的所述其他网络层的网络参数相同。4.根据权利要求1所述的方法,其特征在于,待训练的源域特征提取模型还包括线性Linear层;所述Linear层用于在模型训练的过程中进行标签预测;所述方法还包括:获取源域数据;所述源域数据是时序数据;所述源域数据包括至少一个带标签的源域数据;将所述带标签的源域数据输入所述特征提取层,得到所述带标签的源域数据的第一特征;将所述带标签的源域数据的第一特征输入所述元学习层,得到所述带标签的源域数据的第二特征;将所述带标签的源域数据的第二特征输入所述Linear层,得到所述带标签的源域数据的预测标签;根据所述带标签的源域数据的所述预测标签和所述带标签的源域数据的真实标签,确定模型的损失值;根据所述模型的损失值进行模型优化,得到所述源域特征提取模型。5.根据权利要求1所述的方法,其特征在于,所述根据所述多个未带标签的目标域数据、所述目标域特征提取模型和标签预测模型,进行时序数据预测,包括:将所述多个未带标签的目标域数据输入所述目标域特征提取模型,得到所述多个未带标签的目标域数据的特征;将所述多个未带标签的目标域数据的特征,输入所述标签预测模型,预测出所述多个
未带标签的目标域数据的标签。6.根据权利要求1所述的方法,其特征在于,所述标签预测模型为机器学习模型。7.一种基于迁移学习的时序数据预测装置,其特征在于,包括:获取模块,用于获取目标域数据;其中,所述目标域数据是时序数据,所述目标域数据包括至少一个带标签的目标域数据和多个未带标签的目标域数据;迁移模块,用于将源域特征提取模型的网络参数...

【专利技术属性】
技术研发人员:姜伟浩曹进
申请(专利权)人:杭州海康威视数字技术股份有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1