【技术实现步骤摘要】
一种用于图神经网络的训练方法以及相关设备
本申请涉及人工智能领域,尤其涉及一种用于图神经网络的训练方法以及相关设备。
技术介绍
人工智能(ArtificialIntelligence,AI)是利用计算机或者计算机控制的机器模拟、延伸和扩展人的智能。人工智能包括研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。目前,对基于深度学习的图神经网络进行小样本学习是人工智能一种可行的研究方向。小样本学习指的是神经网络在预学习了一定已知类别的大量样本后,对于新的类别,只需要少量的标记样本就能够实现快速学习。但由于小样本学习中采用的新类别的样本较少,通过小样本学习的方式对图神经网络进行训练,在训练结束时,得到的图神经网络的权重参数往往不够优良,训练后的图神经网络的特征表达能力受到限制,进而影响整个图神经网络输出的处理结果的精度。因此,一种能够提高图神经网络的分类精度的小样本学习方法亟待推出。
技术实现思路
本申请实施例提供了一种用于图神经网络的训练方法以及相关设备,在训练阶段增设第一损失函数,可以更加充分利用训练阶段的样本中的信息,且第一损失函数的训练目标为提高测试图像的特征与正确类别的训练图像的图节点中心之间的相似度,能够提高图神经网络的特征表达能力,进而提高整个图神经网络输出的处理结果的精度。为解决上述技术问题,本申请实施例提供以下技术方案:第一方面,本申请实施例提供一种图神经网络的训练方法,可用于人工智能领域的小样本学习领域中。方法包括:训练设备获取训 ...
【技术保护点】
1.一种图神经网络的训练方法,其特征在于,所述方法包括:/n获取训练图像集合,所述训练图像集合中包括测试图像和N类训练图像,所述N为大于或等于1的整数;/n将所述训练图像集合输入所述图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别,所述第一相似度信息指示所述训练图像集合中每个测试图像的特征与所述N类训练图像中每类训练图像的图节点中心之间的相似度,一类训练图像的图节点中心指示一类训练图像的特征;/n根据所述第一相似度信息、第一损失函数、所述生成分类类别和第二损失函数,对所述图神经网络进行训练;/n其中,所述第一损失函数的训练目标为提高测试图像的特征与第一类别的训练图像的图节点中心之间的第一相似度,所述第二损失函数的训练目标为拉近所述生成分类类别和所述第一类别的相似度,所述第一类别为所述N类中测试图像的正确分类类别。/n
【技术特征摘要】
1.一种图神经网络的训练方法,其特征在于,所述方法包括:
获取训练图像集合,所述训练图像集合中包括测试图像和N类训练图像,所述N为大于或等于1的整数;
将所述训练图像集合输入所述图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别,所述第一相似度信息指示所述训练图像集合中每个测试图像的特征与所述N类训练图像中每类训练图像的图节点中心之间的相似度,一类训练图像的图节点中心指示一类训练图像的特征;
根据所述第一相似度信息、第一损失函数、所述生成分类类别和第二损失函数,对所述图神经网络进行训练;
其中,所述第一损失函数的训练目标为提高测试图像的特征与第一类别的训练图像的图节点中心之间的第一相似度,所述第二损失函数的训练目标为拉近所述生成分类类别和所述第一类别的相似度,所述第一类别为所述N类中测试图像的正确分类类别。
2.根据权利要求1所述的方法,其特征在于,所述获取所述第一相似度信息,包括:
通过所述图神经网络计算第二相似度信息,所述第二相似度信息指示所述训练图像集合中任意两个图像的特征之间的相似度;
根据所述第二相似度信息,生成所述第一相似度信息,其中,测试图像的特征与第二类别的训练图像的图节点中心之间的相似度为以下中的任一项:测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的平均值、测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的最大值和测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的最小值,所述第二类别为所述N类中的任一类。
3.根据权利要求2所述的方法,其特征在于,测试图像的特征与第二类别的训练图像中每个训练图像的特征之间相似度的平均值,所述根据所述第二相似度信息,生成所述第一相似度信息,包括:
从所述第二相似度信息中获取第三相似度信息,所述第三相似度信息指示所述训练图像集合中任意测试图像与任意训练图像之间的相似度;
将所述第三相似度信息与第一矩阵相乘,得到第二矩阵,所述第一矩阵包括所述训练图像集合中所有训练图像的独热编码;
将所述第二矩阵与所述N类中每类训练图像的个数相除,得到所述第一相似度信息。
4.根据权利要求1至3任一项所述的方法,其特征在于,所述第一损失函数的训练目标为所述第一相似度与第二相似度之间的差距大于预设阈值,所述第二相似度为测试图像的特征与第三类别的训练图像的图节点中心之间的相似度,所述第三类别为所述N类中测试图像的错误分类类别。
5.根据权利要求4所述的方法,其特征在于,所述第一相似度与所述第二相似度之间的差距大于预设阈值为以下中的任一项:所述第一相似度与所述第二相似度之间的差值大于预设阈值和所述第一相似度与所述第二相似度之间的比值大于预设阈值。
6.根据权利要求4所述的方法,其特征在于,所述方法还包括:
从所述第一相似度信息中获取所述第一相似度和所述第二相似度,并对所述第一相似度进行缩小处理;
根据所述第二相似度和缩小处理后的第一相似度,生成所述第一损失函数的函数值。
7.根据权利要求2所述的方法,其特征在于,所述图神经网络包括一个特征提取网络和至少一个特征更新网络,所述特征提取网络用于对输入的训练图像集合中的图像进行特征提取操作,每个特征更新网络,用于计算所述第二相似度信息,并根据所述每个图像的特征和所述第二相似度信息,进行特征更新操作;
所述通过所述图神经网络计算第二相似度信息,包括:
通过所述至少一个特征更新网络中的第一个特征更新网络计算所述第二相似度信息;
所述根据所述第一相似度信息和第一损失函数,对所述图神经网络进行训练,包括:
根据所述第一相似度信息和第一损失函数,对所述图神经网络中的所述特征提取网络和所述第一个特征更新网络进行训练。
8.根据权利要求1至3任一项所述的方法,其特征在于,所述图神经网络用于进行图像识别或者进行图像分类。
9.一种图像处理方法,其特征在于,所述方法包括:
获取待处理数据,所述待处理数据中包括待分类图像、N类参考图像和所述N类参考图像中每个参考图像的正确分类类别,所述N为大于或等于1的整数;
将所述待处理数据输入图神经网络,以得到所述待分类图像的特征和每个所述参考图像的特征;
根据所述待分类图像的特征、所述参考图像的特征和所述参考图像的正确分类类别,生成所述待分类图像的类别指示信息,所述类别指示信息指示所述待分类图像在所述N类中的生成分类类别;
其中,所述图神经网络为根据第一损失函数和第二损失函数训练得到的,所述第一损失函数的训练目标为提高所述待分类图像的特征与第一类别的参考图像的图节点中心之间的第一相似度,所述第二损失函数的训练目标为拉近所述生成分类类别和所述第一类别的相似度,所述第一类别为所述N类中所述待分类图像的正确分类类别。
10.根据权利要求9所述的方法,其特征在于,待分类图像的特征与第一类别的参考图像的图节点中心之间的第一相似度为以下中的任一项:待分类图像的特征与第一类别的参考图像中每个参考图像的特征之前相似度的平均值、待分类图像的特征与第一类别的参考图像中每个参考图像的特征之前相似度的最大值和待分类图像的特征与第一类别的参考图像中每个参考图像的特征之前相似度的最小值。
11.一种图神经网络的训练装置,其特征在于,所述装置包括:
获取模块,用于获取训练图像集合,所述训练图像集合中包括测试图像和N类训练图像,所述N为大于或等于1的整数;
输入模块,用于将所述训练图像集合输入所述图神经网络,以获取第一相似度信息和与测试图像对应的生成分类类别,所述第一相似度信息指示所述...
【专利技术属性】
技术研发人员:乔宇,王亚立,陈晨,刘健庄,岳俊,
申请(专利权)人:华为技术有限公司,
类型:发明
国别省市:广东;44
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。