一种用于图神经网络的训练方法以及相关设备技术

技术编号:26172839 阅读:23 留言:0更新日期:2020-10-31 13:52
本申请涉及人工智能领域中的小样本学习技术,公开了一种用于图神经网络的训练方法以及相关设备。方法包括:将包括测试图像和N类训练图像的训练图像集合输入图神经网络,以获取第一相似度信息和测试图像的生成分类类别,第一相似度信息指示测试图像的特征与N类中每类训练图像的特征之间的相似度;根据第一损失函数和第二损失函数对图神经网络进行训练;第一损失函数的目标为提高测试图像的特征与正确分类类别的训练图像的特征之间的相似度,第二损失函数的目标为拉近生成分类类别和正确分类类别的相似度,增设第一损失函数的约束,更加充分利用训练阶段的样本中的信息,提高图神经网络的特征表达能力,以提高图神经网络输出结果的精度。

【技术实现步骤摘要】
一种用于图神经网络的训练方法以及相关设备
本申请涉及人工智能领域,尤其涉及一种用于图神经网络的训练方法以及相关设备。
技术介绍
人工智能(ArtificialIntelligence,AI)是利用计算机或者计算机控制的机器模拟、延伸和扩展人的智能。人工智能包括研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。目前,对基于深度学习的图神经网络进行小样本学习是人工智能一种可行的研究方向。小样本学习指的是神经网络在预学习了一定已知类别的大量样本后,对于新的类别,只需要少量的标记样本就能够实现快速学习。但由于小样本学习中采用的新类别的样本较少,通过小样本学习的方式对图神经网络进行训练,在训练结束时,得到的图神经网络的权重参数往往不够优良,训练后的图神经网络的特征表达能力受到限制,进而影响整个图神经网络输出的处理结果的精度。因此,一种能够提高图神经网络的分类精度的小样本学习方法亟待推出。
技术实现思路
本申请实施例提供了一种用于图神经网络的训练方法以及相关设备,在训练阶段增设第一损失函数,可以更加充分利用训练阶段的样本中的信息,且第一损失函数的训练目标为提高测试图像的特征与正确类别的训练图像的图节点中心之间的相似度,能够提高图神经网络的特征表达能力,进而提高整个图神经网络输出的处理结果的精度。为解决上述技术问题,本申请实施例提供以下技术方案:第一方面,本申请实施例提供一种图神经网络的训练方法,可用于人工智能领域的小样本学习领域中。方法包括:训练设备获取训练图像集合,训练图像集合中包括h个测试图像和N类训练图像,每类训练图像中包括至少一个训练图像,N为大于或等于1的整数。训练设备将训练图像集合输入图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别。其中,第一相似度信息指示训练图像集合中每个测试图像的特征与N类训练图像中每类训练图像的图节点中心之间的相似度,一类训练图像的图节点中心指示一类训练图像的特征;第一相似度信息具体可以表现为一个h乘N的矩阵,矩阵中每个元素指代一个测试图像的特征与一类训练图像的图节点中心的相似度。若训练图像集合中包括一个测试图像,则输出前述一个测试图像的生成分类类别;若训练图像集合中包括多个测试图像,则输出前述多个测试图像中每个测试图像的生成分类类别。训练设备根据第一相似度信息、第一损失函数、生成分类类别和第二损失函数,对图神经网络进行训练。其中,第一损失函数的训练目标为提高测试图像的特征与第一类别的训练图像的图节点中心之间的第一相似度,第二损失函数的训练目标为拉近生成分类类别和第一类别的相似度,第一类别为N类中测试图像的正确分类类别。本实现方式中,在第二损失函数的基础上增加了第一损失函数,来进一步约束图神经网络的训练过程,可以更加充分利用训练阶段的样本中的信息;此外,由于若测试图像的特征与正确分类类别的特征的相似度越高,则证明图神经网络的特征表达能力越强,且整个图神经网络越容易将测试图像分类至正确类别,从而提高整个图神经网络输出的处理结果的精度。在第一方面的一种可能实现方式中,训练设备获取第一相似度信息,包括:训练设备通过图神经网络计算第二相似度信息。其中,第二相似度信息指示训练图像集合中任意两个图像的特征之间的相似度;N类训练图像中包括s个训练图像,也即训练图像集合中包括(h+s)个图像,则第二相似度信息可以为(h+s)乘(h+s)的相似度矩阵,相似度矩阵中的每一行代表一个图像的特征和其他图像的特征之间的相似度。训练设备根据第二相似度信息,生成第一相似度信息。其中,测试图像的特征与第二类别的训练图像的图节点中心之间的相似度为以下中的任一项:测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的平均值、测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的最大值和测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的最小值,第二类别为N类中的任一类。本实现方式中,提供了第一相似度信息的三种表达形式,提高了本方案的实现灵活性;且不需要实际计算每类训练图像的图节点中心,而是直接根据第二相似度信息来计算第一相似度信息,充分利用了图神经网络计算过程中的信息,有利于提高训练阶段的效率。在第一方面的一种可能实现方式中,测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的平均值。训练设备根据第二相似度信息,生成第一相似度信息,包括:训练设备从第二相似度信息中获取第三相似度信息。其中,第三相似度信息指示训练图像集合中任意测试图像与任意训练图像之间的相似度,第三相似度信息具体可以表现为一个h乘s的矩阵,该矩阵中每一行指示一个测试图像与每个训练图像之间的相似度。训练设备将第三相似度信息与第一矩阵相乘,得到第二矩阵。其中,第一矩阵包括训练图像集合中所有训练图像的独热编码,一个训练图像的独热编码具体可以表现为一个包括N个元素的向量,用于指示训练图像在N类类别中的正确分类类别。训练设备将第二矩阵与N类中每类训练图像的个数相除,得到第一相似度信息;N类中每类训练图像的个数具体可以表现为一个包含N个元素的向量,该向量中一个元素代表N类中一类训练图像中的图像个数。本实现方式中,公开了第一相似度信息的具体生成方式,提高了本方案与具体应用场景的结合程度;此外,一次性计算所有测试图像与N类中每类训练图像的图节点中心的相似度,也即并非逐个计算第一相似度信息中每个元素的数值,而是一次性计算整个第一相似度信息,提高了训练阶段的效率。在第一方面的一种可能实现方式中,第一损失函数的训练目标为第一相似度与第二相似度之间的差距大于预设阈值,第二相似度为测试图像的特征与第三类别的训练图像的图节点中心之间的相似度,第三类别为N类中测试图像的错误分类类别。其中,第一损失函数具体表现为交叉熵损失函数或三元组损失函数。本实现方式中,将第一损失函数的目标设定为第一相似度与第二相似度之间的差距大于预设阈值,也即不仅需要测试图像的特征与正确分类类别的图节点中心之间的相似度大于测试图像的特征与错误分类类别的图节点中心之间的相似度,而且两者之间的差距要大于预设阈值,以进一步提高训练后的图神经网络的特征表达能力,进而提高训练后的图神经网络的处理结果的精度;用户可以结合实际情况灵活设定预设阈值的取值,提高了本方案的实现灵活性。在第一方面的一种可能实现方式中,第一相似度与第二相似度之间的差距大于预设阈值为以下中的任一项:第一相似度与第二相似度之间的差值大于预设阈值和第一相似度与第二相似度之间的比值大于预设阈值。本实现方式中,提供了第一相似度与第二相似度之间的差距的两种比较方式,提高了本方案的实现灵活性。在第一方面的一种可能实现方式中,方法还包括:训练设备从第一相似度信息中获取第一相似度和第二相似度,并对第一相似度进行缩小处理;具体的,训练设备将第一相似度与预设阈值相减,或者,训练设备将第一相似度与预设阈值相除,得到缩小后的第一相似度。训练设备根据第二相似度和缩小处理后的第一相似度,生成第一损失函数的函数值。本实现方式本文档来自技高网...

【技术保护点】
1.一种图神经网络的训练方法,其特征在于,所述方法包括:/n获取训练图像集合,所述训练图像集合中包括测试图像和N类训练图像,所述N为大于或等于1的整数;/n将所述训练图像集合输入所述图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别,所述第一相似度信息指示所述训练图像集合中每个测试图像的特征与所述N类训练图像中每类训练图像的图节点中心之间的相似度,一类训练图像的图节点中心指示一类训练图像的特征;/n根据所述第一相似度信息、第一损失函数、所述生成分类类别和第二损失函数,对所述图神经网络进行训练;/n其中,所述第一损失函数的训练目标为提高测试图像的特征与第一类别的训练图像的图节点中心之间的第一相似度,所述第二损失函数的训练目标为拉近所述生成分类类别和所述第一类别的相似度,所述第一类别为所述N类中测试图像的正确分类类别。/n

【技术特征摘要】
1.一种图神经网络的训练方法,其特征在于,所述方法包括:
获取训练图像集合,所述训练图像集合中包括测试图像和N类训练图像,所述N为大于或等于1的整数;
将所述训练图像集合输入所述图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别,所述第一相似度信息指示所述训练图像集合中每个测试图像的特征与所述N类训练图像中每类训练图像的图节点中心之间的相似度,一类训练图像的图节点中心指示一类训练图像的特征;
根据所述第一相似度信息、第一损失函数、所述生成分类类别和第二损失函数,对所述图神经网络进行训练;
其中,所述第一损失函数的训练目标为提高测试图像的特征与第一类别的训练图像的图节点中心之间的第一相似度,所述第二损失函数的训练目标为拉近所述生成分类类别和所述第一类别的相似度,所述第一类别为所述N类中测试图像的正确分类类别。


2.根据权利要求1所述的方法,其特征在于,所述获取所述第一相似度信息,包括:
通过所述图神经网络计算第二相似度信息,所述第二相似度信息指示所述训练图像集合中任意两个图像的特征之间的相似度;
根据所述第二相似度信息,生成所述第一相似度信息,其中,测试图像的特征与第二类别的训练图像的图节点中心之间的相似度为以下中的任一项:测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的平均值、测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的最大值和测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的最小值,所述第二类别为所述N类中的任一类。


3.根据权利要求2所述的方法,其特征在于,测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的平均值,所述根据所述第二相似度信息,生成所述第一相似度信息,包括:
从所述第二相似度信息中获取第三相似度信息,所述第三相似度信息指示所述训练图像集合中任意测试图像与任意训练图像之间的相似度;
将所述第三相似度信息与第一矩阵相乘,得到第二矩阵,所述第一矩阵包括所述训练图像集合中所有训练图像的独热编码;
将所述第二矩阵与所述N类中每类训练图像的个数相除,得到所述第一相似度信息。


4.根据权利要求1至3任一项所述的方法,其特征在于,所述第一损失函数的训练目标为所述第一相似度与第二相似度之间的差距大于预设阈值,所述第二相似度为测试图像的特征与第三类别的训练图像的图节点中心之间的相似度,所述第三类别为所述N类中测试图像的错误分类类别。


5.根据权利要求4所述的方法,其特征在于,所述第一相似度与所述第二相似度之间的差距大于预设阈值为以下中的任一项:所述第一相似度与所述第二相似度之间的差值大于预设阈值和所述第一相似度与所述第二相似度之间的比值大于预设阈值。


6.根据权利要求4所述的方法,其特征在于,所述方法还包括:
从所述第一相似度信息中获取所述第一相似度和所述第二相似度,并对所述第一相似度进行缩小处理;
根据所述第二相似度和缩小处理后的第一相似度,生成所述第一损失函数的函数值。


7.根据权利要求2所述的方法,其特征在于,所述图神经网络包括一个特征提取网络和至少一个特征更新网络,所述特征提取网络用于对输入的训练图像集合中的图像进行特征提取操作,每个特征更新网络,用于计算所述第二相似度信息,并根据所述每个图像的特征和所述第二相似度信息,进行特征更新操作;
所述通过所述图神经网络计算第二相似度信息,包括:
通过所述至少一个特征更新网络中的第一个特征更新网络计算所述第二相似度信息;
所述根据所述第一相似度信息和第一损失函数,对所述图神经网络进行训练,包括:
根据所述第一相似度信息和第一损失函数,对所述图神经网络中的所述特征提取网络和所述第一个特征更新网络进行训练。


8.根据权利要求1至3任一项所述的方法,其特征在于,所述图神经网络用于进行图像识别或者进行图像分类。


9.一种图像处理方法,其特征在于,所述方法包括:
获取待处理数据,所述待处理数据中包括待分类图像、N类参考图像和所述N类参考图像中每个参考图像的正确分类类别,所述N为大于或等于1的整数;
将所述待处理数据输入图神经网络,以得到所述待分类图像的特征和每个所述参考图像的特征;
根据所述待分类图像的特征、所述参考图像的特征和所述参考图像的正确分类类别,生成所述待分类图像的类别指示信息,所述类别指示信息指示所述待分类图像在所述N类中的生成分类类别;
其中,所述图神经网络为根据第一损失函数和第二损失函数训练得到的,所述第一损失函数的训练目标为提高所述待分类图像的特征与第一类别的参考图像的图节点中心之间的第一相似度,所述第二损失函数的训练目标为拉近所述生成分类类别和所述第一类别的相似度,所述第一类别为所述N类中所述待分类图像的正确分类类别。


10.根据权利要求9所述的方法,其特征在于,待分类图像的特征与第一类别的参考图像的图节点中心之间的第一相似度为以下中的任一项:待分类图像的特征与第一类别的参考图像中每个参考图像的特征之前相似度的平均值、待分类图像的特征与第一类别的参考图像中每个参考图像的特征之前相似度的最大值和待分类图像的特征与第一类别的参考图像中每个参考图像的特征之前相似度的最小值。


11.一种图神经网络的训练装置,其特征在于,所述装置包括:
获取模块,用于获取训练图像集合,所述训练图像集合中包括测试图像和N类训练图像,所述N为大于或等于1的整数;
输入模块,用于将所述训练图像集合输入所述图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别,所述第一相似度信息指示所述...

【专利技术属性】
技术研发人员:乔宇王亚立陈晨刘健庄岳俊
申请(专利权)人:华为技术有限公司
类型:发明
国别省市:广东;44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1