【技术实现步骤摘要】
模型训练方法、装置、设备及存储介质
[0001]本申请实施例涉及人工智能
,尤其涉及一种模型训练方法、装置、设备及存储介质。
技术介绍
[0002]随着人工智能技术的快速发展,各种图神经网络模型应用而生,图神经网络模型的预测精度与训练过程紧密相关,而模型训练与样本的准备相关。例如,当样本空间不够或者样本数量不足的时候会严重影响训练或者导致训练出来的模型泛化程度不够,识别率与准确率不高。
[0003]因此,如何解决样本数量不足造成的模型训练泛化程度不够,成为本领域技术人员亟待解决的技术问题。
技术实现思路
[0004]本申请提供一种模型训练方法、装置、设备及存储介质,训练得到的生成器可以生成用于模型训练的抽样样本,以实现对模型的充分训练。
[0005]第一方面,本申请提供一种模型训练方法,包括:
[0006]获取第一训练图数据和N组超参,所述第一训练图数据包括R个数据对,每个数据对由图网络中的一个中心节点的第一特征信息和所述中心节点的一个邻居图节点的第一特征信息组成,所述N、R均为 ...
【技术保护点】
【技术特征摘要】
1.一种模型训练方法,其特征在于,包括:获取第一训练图数据和N组超参,所述第一训练图数据包括R个数据对,每个数据对由图网络中的一个中心节点的第一特征信息和所述中心节点的一个邻居图节点的第一特征信息组成,所述N、R均为正整数;将所述第一训练图数据分别输入所述N组超参中每一组超参下的生成器中,以使所述生成器学习在给定中心节点的特征条件下所述中心节点的邻居节点的特征概率分布,得到所述N组超参中每一组超参下训练后的生成器;从N组超参下训练后的生成器中,确定出目标生成器,所述目标生成器用于生成第二抽样样本,所述第二抽样样本用于训练预设的图神经网络模型。2.根据权利要求1所述的方法,其特征在于,所述从N组超参下训练后的生成器中,确定出目标生成器,包括:针对所述N组超参中的每一组超参,使用该组超参下训练后的生成器进行抽样,得到该组超参下训练后的生成器输出的第一抽样样本,并使用该组超参下训练后的生成器输出的第一抽样样本,对预设的第一预测模型进行训练,得到该组超参对应的训练后的第一预测模型;分别确定每一组超参对应的训练后的第一预测模型的预测准确度;将预测准确度最高的第一预测模型所对应的一组超参下训练后的生成器,确定为所述目标生成器。3.根据权利要求1所述的方法,其特征在于,所述将所述第一训练图数据分别输入所述N组超参中每一组超参下的生成器中,以使所述生成器学习在给定中心节点的特征条件下所述中心节点的邻居节点的特征概率分布,得到所述N组超参中每一组超参下训练后的生成器,包括:针对所述N组超参中的每一组超参下的生成器,根据预设的批batch大小,从所述第一训练图数据中获取第i个batch,所述第i个batch包括至少一个数据对,所述i为从1到M的正整数,所述M为预设的生成器的训练次数;使用所述第i个batch对该组超参下的生成器进行训练,以使所述生成器学习在给定所述第i个batch中的各中心节点的特征条件下各中心节点的邻居节点的特征概率分布,得到所述第i个batch训练后的生成器;使用所述第i个batch训练后的生成器进行抽样,得到所述生成器输出的所述第i个batch对应的第三抽样样本;将所述第i个batch对应的第三抽样样本输入预设的第二预测模型中,得到所述第二预测模型输出的所述第i个batch对应的预测结果;根据所述第二预测模型输出的所述第i个batch对应的预测结果,确定所述第i个batch对应的不确定性分数;将M个batch对应的不确定分数中最高不确定分数对应的batch所训练后的生成器,确定为该组超参下训练后的生成器。4.根据权利要求3所述的方法,其特征在于,所述生成器为条件变分自编码器,所述条件变分自编码器包括编码模块和解码模块,所述使用第i个batch对该组超参下的生成器进行训练,以使所述生成器学习在给定所述第i个batch中的中心节点的特征条件下所述中心
节点的邻居节点的特征概率分布,得到所述第i个batch训练后的生成器,包括:针对所述第i个batch中的每一个数据对,将所述数据对输入所述编码模块中,得到所述编码模块输出的第一方差和第一均值;将所述第一方差和所述第一均值通过重参数化技巧转换为第一隐变量,并将所述第一隐变量与所述数据对中的中心节点的第一特征信息,输入所述解码模块中,得到所述解码模块输出的增广特征向量;使用所述增广特征向量对所述编码模块和所述解码模块进行训练,得到所述第i个batch训练后的所述编码模块和所述解码模块。5.根据权利要求4所述的方法,其特征在于,所述使用所述第i个batch训练后的生成器进行抽样,得到所述生成器输出的第三抽样样本,包括:选取K个节点的第一特征信息,所述K为正整数;为所述K个节点中的每一个节点随机选取一个第二方差和第二均值;针对所述K个节点中的每一个节点,将所述节点对应的第二方差和第二均值通过重参数化技巧转换为第二隐变量,并将所述第二隐变量与所述节点的第一特征信息输入所述第i个batch训练后的所述解码模块中,得到所述解码模块输出的第三抽样样本。6.根据权利要求3所述的方法,其特征在于,所述方法还包括:确定所述第i个batch对应的不确定性分数是否为前i个batch对应的不确定性分数中的最高不确定性分数...
【专利技术属性】
技术研发人员:刘松涛,李蓝青,
申请(专利权)人:腾讯科技深圳有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。