一种基于元学习的概率域泛化学习方法技术

技术编号:24208926 阅读:13 留言:0更新日期:2020-05-20 15:56
本发明专利技术公开了一种基于元学习的概率域泛化学习方法,属于元学习领域,一种基于元学习的概率域泛化学习方法,可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。

A generalization learning method of probability domain based on meta learning

【技术实现步骤摘要】
一种基于元学习的概率域泛化学习方法
本专利技术涉及元学习领域,更具体地说,涉及一种基于元学习的概率域泛化学习方法。
技术介绍
传统的机器学习假设训练数据与测试数据服从相同的数据分布,这个条件在实际应用中很难得到满足。解决这个问题有几种经典方法,包括1)迁移学习:迁移学习的目标是将从一个环境中学到的知识用来帮助新环境中的学习任务;2)域自适应:域自适应学习的重点在于如何克服源域分布和目标域分布不同,实现目标域上的学习任务;3)域泛化:目标域不可知的情况下,使得分布或者模型对未知情况具备良好特性。这几类方法的难度是递增的,在本专利技术中,提出一种基于变分信息瓶颈元学习的概率域泛化学习方法,就是针对域泛化的方法。目前针对域泛化的方法主要有1)基于特征的方法,其只要通过设计跨域不变特征实现域泛化;2)基于分类器的方法,其针对每个数据集也就是源域中的每一个子域对子分类器进行设计,然后将子分类器结合成一个融合分类器来实现;3)信息瓶颈:任何神经网络可以通过隐层与输入和输出变量之间的共享信息(mutualinformation)来量化,深度学习的目标就是在学习的过程中最大化地压缩输入信息,最大化地保留输出信息。信息瓶颈就是通过控制输入和输出变量之间的共享信息达到泛化的目的。域泛化方法的关键是权衡源域到目标域之间的变化。前述方法或者是尽量抽取对域变化不敏感的特征表示或者是通过在每一个域学习得到一个模型,然后选择和目标域相近的模型进行预测。这些方法中,参数的数量会随着源域的增加而线性增加,从而在数据不充分的应用中,很容易出现过拟合现象。
技术实现思路
1.要解决的技术问题针对现有技术中存在的问题,本专利技术的目的在于提供一种基于元学习的概率域泛化学习方法,它可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。2.技术方案为解决上述问题,本专利技术采用如下的技术方案。一种基于元学习的概率域泛化学习方法,包括以下步骤:输入:具有K个源域的训练数据集S,学习率λ,迭代次数Niter;输出:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;分类模型参数ψ;S1、从K个源域随机选取一个作为目标域,其余K-1个作为源域;S2、从每一个源域Ds中选取包含C个类别的M个样本,表示为S3、从目标域Dt中选取N个样本,表示为S4、对源域数据集Ds第c类的每一个样本利用卷积神经网络提取特征如下:S5、对于源域数据集Ds中的每一个类别的样本,利用置换不变的实例池化操作,得到类别的平均特征S6、重复S4-S5,计算所有类别的平均特征S7、将类别平均特征送入推理网络g1中,计算关于该类别分类器参数ψ的分布;S8、从每个类别分类器的概率分布即中抽取样本,最后构成权重向量ψc;S9、对每一个类别重复S7-S8,并按照列排列构成矩阵如下:ψ=[ψ1,ψ2,...,ψC]S10、将类别平均特征送入推理网络g2中,计算关于隐含变量z的分布;S11、从每个类别的隐含变量的概率分布,即中抽取样本,最后构成隐含向量zc;S12、对目标域数据集Dt第c类的每一个样本利用特征提取网络h提取特征如下:S13、将目标域的每一个类别的每一个特征送入推理网络g2中,计算关于目标域的分布;S14、从每个类别的隐含变量的概率分布,即中抽取样本zj,c;S15、计算每个类别的损失函数如下:S16、重复S12-S15,使其覆盖所有类别。S17、按照如下公式迭代更新参数θ,S18、重复S2-S17,到所有K-1结束。进一步的,所述S7中用高斯分布表示每个类别分类器参数的概率分布,即利用推理网络g1,得到类别c的分类器参数分布的均值和方差进一步的,所述S10中用高斯分布表示每个类别隐含变量的概率分布,即利用推理网络g2,得到类别c的隐含变量分布的均值和方差进一步的,所述S13中用高斯分布表示每个类别隐含变量的概率分布,即利用推理网络g2,得到类别c中样本j的分布的均值和方差进一步的,所述元学习方法在训练阶段后进行需进行元学习的测试。进一步的,所述元学习测试的方法为:输入:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;参数ψ,分类模型;待分类的目标域任务;输出:分类结果;步骤1:从目标域T中选取N个样本,表示为步骤2:每一个样本利用特征提取网络h提取特征如下:步骤3:每一个样本的特征送入推理网络g2中,得到样本的分布的均值和方差步骤4:从样本隐含变量的概率分布,即中抽取样本zj;步骤5:利用分类器参数ψ计算分类结果,即ψzj得到的向量中最大维度表示的类别,即为分类结果。3.有益效果相比于现有技术,本专利技术的优点在于:本方案可以实现首次将元学习思想结合到域泛化中,利用元学习框架解决域泛化中随着源域数目增加参数线性增加的问题;首次将变分信息瓶颈思想结合到元学习和域泛化中,可以进一步增加本专利的泛化能力;本方案可以通过元学习解决参数随着源域数目线性增加问题,并通过元学习框架,可以更加精确地获取域不变的特征表示,为了进一步增加本方案的域泛化性能,本方案将变分思想和信息瓶颈结合,将其融入到一个统一的概率框架中,形成一种全新的,并及其有效的基于元学习的概率域泛化学习方法。附图说明图1为本专利技术的元学习中数据/模型关系图;图2为本专利技术的在旋转的MNIST数据库上的10次测试的平均分类准确率数据表;图3为本专利技术的在CLVS四个数据库上的分类准确率数据表。具体实施方式下面将结合本专利技术实施例中的附图;对本专利技术实施例中的技术方案进行清楚、完整地描述;显然;所描述的实施例仅仅是本专利技术一部分实施例;而不是全部的实施例,基于本专利技术中的实施例;本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例;都属于本专利技术保护的范围。在本专利技术的描述中,需要说明的是,术语“上”、“下”、“内”、“外”、“顶/底端”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本专利技术和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本专利技术的限制。此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性。本文档来自技高网
...

【技术保护点】
1.一种基于元学习的概率域泛化学习方法,其特征在于:包括以下步骤:/n输入:具有K个源域的训练数据集S,学习率λ,迭代次数N

【技术特征摘要】
1.一种基于元学习的概率域泛化学习方法,其特征在于:包括以下步骤:
输入:具有K个源域的训练数据集S,学习率λ,迭代次数Niter;
输出:参数θ,包括一个特征提取网络h的参数和两个推理网络g1和g2参数;分类模型参数ψ;
S1、从K个源域随机选取一个作为目标域,其余K-1个作为源域;
S2、从每一个源域Ds中选取包含C个类别的M个样本,表示为



S3、从目标域Dt中选取N个样本,表示为



S4、对源域数据集Ds第c类的每一个样本利用卷积神经网络提取特征如下:



S5、对于源域数据集Ds中的每一个类别的样本,利用置换不变的实例池化操作,得到类别的平均特征



S6、重复S4-S5,计算所有类别的平均特征
S7、将类别平均特征送入推理网络g1中,计算关于该类别分类器参数ψ的分布;
S8、从每个类别分类器的概率分布即中抽取样本,最后构成权重向量ψc;
S9、对每一个类别重复S7-S8,并按照列排列构成矩阵如下:
ψ=[ψ1,ψ2,...,ψC]
S10、将类别平均特征送入推理网络g2中,计算关于隐含变量z的分布;
S11、从每个类别的隐含变量的概率分布,即中抽取样本,最后构成隐含向量zc;
S12、对目标域数据集Dt第c类的每一个样本利用特征提取网络h提取特征如下:



S13、将目标域的每一个类别的每一个特征送入推理网络g2中,计算关于目标域的分布;
S14、从每个类别的隐含变量的概率分布,即中抽取样本zj,c;
S15、计算每个类别的损失函数如下:



...

【专利技术属性】
技术研发人员:甄先通张磊李欣左利云简治平蔡泽涛
申请(专利权)人:广东石油化工学院
类型:发明
国别省市:广东;44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1