一种机器学习优化方法以及装置制造方法及图纸

技术编号:25990817 阅读:22 留言:0更新日期:2020-10-20 18:59
本申请公开了人工智能领域的一种机器学习优化方法以及装置,用于对机器学习模型进行优化,得到鲁棒性更优的机器学习模型。该方法包括:获取机器学习模型;通过预设的数据集对机器学习模型进行至少一次迭代更新,得到符合预设条件的更新后的机器学习模型;其中,任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从多个损失值中选取大于平均损失值的至少一个损失值,根据大于平均损失值的至少一个损失值以及至少一个损失值中每个损失值对应的任务对上一次迭代更新后的机器学习模型进行更新,若当前次更新后的机器学习模型不符合预设条件,则对当前次更新后的机器学习模型进行下一次迭代更新。

【技术实现步骤摘要】
一种机器学习优化方法以及装置
本申请涉及人工智能领域,尤其涉及一种机器学习优化方法以及装置。
技术介绍
机器学习是人工智能的一个分支。人们通过构建一个参数化的模型来表示规律,一些端到端的机器学习方法在样本数据中获取规律,应用规律进行预测。以小样本(fewshot)机器学习任务为例,现有方案在进行小样本机器学习时,首选选定一个机器学习模型,然后在每轮学习的训练过程中,获取一组小样本机器学习任务,计算机器学习模型进行所有任务的损失函数的损失值。用求平均的方法将损失值汇聚作为最终的损失值,即平均损失值,依照此平均损失值以及对应的所有任务对机器学习模型进行反向传播的参数更新。然而,若使用平均损失值对小样本机器学习模型进行更新,将导致机器学习模型的鲁棒性差。例如,若小样本机器学习模型进行的实际新任务与训练时使用的平均任务存在较大差别,则机器学习模型并不能适应新任务,导致机器学习模型的输出结果不准确。
技术实现思路
本申请提供一种机器学习优化方法以及装置,用于对机器学习模型进行优化,得到鲁棒性更优的机器学习模型。第一方面,本申请提供一种机器学习优化方法,包括:获取机器学习模型;通过预设的数据集对机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,更新后的机器学习模型符合预设条件;其中,至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从多个损失值中选取大于平均损失值的至少一个损失值,根据大于平均损失值的至少一个损失值以及至少一个损失值中每个损失值对应的任务对上一次迭代更新后的机器学习模型进行更新,以得到当前次迭代更新后的机器学习模型,若当前次更新后的机器学习模型不符合预设条件,则对当前次更新后的机器学习模型进行下一次迭代更新。因此,本申请实施方式中,在机器学习模型的每一轮迭代更新的过程中,使用大于平均损失值的至少一个损失值以及对应的任务来更新机器学习模型,从而提高机器学习模型在此大于平均损失值的至少一个损失值的任务中的表现,提高机器学习模型输出此大于平均损失值的至少一个损失值的任务的结果的准确性。并且,相对于在每一轮迭代更新中使用所有任务对机器学习模型进行更新,本申请实施方式仅使用大于平均损失值的至少一个损失值以及对应的任务来个更新机器学习模型,减少了更新流程,提高得到符合预设条件的机器学习模型的效率。在一种可能的实施方式中,在通过预设的数据集对机器学习模型进行至少一次迭代更新之前,上述方法还可以包括:获取鲁棒性要求,鲁棒性要求用于表示在每一次迭代更新过程中,对机器学习模型进行更新的任务的数量占机器学习模型执行的任务的比例;根据鲁棒性要求确定对机器学习模型的每一次迭代更新所执行的任务的数量m,m为大于1的正整数。因此,本申请实施方式中,可以通过鲁棒性要求确定对机器学习模型的每一轮更新中执行的任务的数量,以及对机器学习模型进行更新所使用的任务的数量所占的比例,从而可以提高机器学习模型在对机器学习模型进行更新所使用的任务中的输出结果的准确性,提高机器学习模型的鲁棒性。在一种可能的实施方式中,该鲁棒性要求为接收用户的输入数据得到,从而使用户可以通过输入数据调整机器学习模型的鲁棒性,提高用户体验。在一种可能的实施方式中,在至少一次迭代更新的任意一次迭代更新中,获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,可以包括:对预设的数据集进行采样,得到m个任务;获取上一次迭代更新后的机器学习模型执行m个任务对应的m个损失值。本申请实施方式中,可以根据鲁棒性要求确定每一轮更新所进行的任务的数量,并在预设的数据集中采集任务,进而完成对机器学习模型的更新。在一种可能的实施方式中,获取上一次迭代更新后的机器学习模型执行m个任务对应的多个损失值,可以包括:获取与m个任务对应的多个数据集,多个数据集中的每个数据集包括训练子集和验证子集;通过每个数据集包括的训练子集对上一次迭代更新后的机器学习模型进行训练,得到训练后的机器学习模型;将每个数据集包括的验证子集作为训练后的机器学习模型的输入,得到训练后的机器学习模型的多个输出结果;根据多个输出结果得到m个损失值。本申请实施方式中,可以基于选取的m个任务对应的数据,分别对机器学习模型进行训练和验证,从而完成对m个任务的机器学习。在一种可能的实施方式中,在至少一次迭代更新的任意一次迭代更新中,从多个损失值中选取大于平均损失值的至少一个损失值,可以包括:从多个损失值中选取一个值最大的损失值;根据大于平均损失值的至少一个损失值以及至少一个损失值中每个损失值对应的任务对上一次迭代更新后的机器学习模型进行更新,包括:根据值最大的损失值以及值最大的损失值对应的任务,对上一次迭代更新后的机器学习模型基于反向传播进行更新,得到当前次迭代更新后的机器学习模型。本申请实施方式中,可以选择值最大的损失值以及对应的任务对机器学习模型进行更新,从而可以提高机器学习模型在该值最大的任务中的表现,使得机器学习模型在1-1/m的置信度下,优化损失函数最大1/m部分的任务。在一种可能的实施方式中,机器学习模型可以包括以下至少一种:多层前馈神经网络、卷积神经网络、循环神经网络或者图神经网络。因此,本申请提供的机器学习优化方法,可以适用于多种场景,满足多种需求,提高用户体验。在一种可能的实施方式中,前述的多个任务包括小样本学习任务。因此,本申请实施方式中,针对小样本学习任务,可以进行快速地学习,从而快速得到机器学习模型。第二方面,本申请提供一种机器学习优化装置,包括:获取模块,用于获取机器学习模型;更新模块,用于通过预设的数据集对机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,更新后的机器学习模型符合预设条件;其中,至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从多个损失值中选取大于平均损失值的至少一个损失值,根据大于平均损失值的至少一个损失值以及至少一个损失值中每个损失值对应的任务对上一次迭代更新后的机器学习模型进行更新,以得到当前次迭代更新后的机器学习模型,若当前次更新后的机器学习模型不符合预设条件,则对当前次更新后的机器学习模型进行下一次迭代更新。第二方面及第二方面任一种可能的实施方式产生的有益效果可参照第一方面及第一方面任一种可能实施方式的描述。在一种可能的实施方式中,获取模块,还用于在通过预设的数据集对机器学习模型进行至少一次迭代更新之前,获取鲁棒性要求,鲁棒性要求用于表示在每一次迭代更新过程中,对机器学习模型进行更新的任务的数量占机器学习模型执行的任务的比例;根据鲁棒性要求确定对机器学习模型的每一次迭代更新所执行的任务的数量m,m为大于1的正整数。在一种可能的实施方式中,更新模块,具体用于在至少一次迭代更新的任意一次迭代更新中,对预设的数据集进行采样,得到m个任务;获取上一次迭代更新后的机器学习模型执行m个任务对应的m个损失值。在一本文档来自技高网...

【技术保护点】
1.一种机器学习优化方法,其特征在于,包括:/n获取机器学习模型;/n通过预设的数据集对所述机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,所述更新后的机器学习模型符合预设条件;/n其中,所述至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从所述多个损失值中选取大于平均损失值的至少一个损失值,根据所述大于平均损失值的至少一个损失值以及所述至少一个损失值中每个损失值对应的任务对所述上一次迭代更新后的机器学习模型进行更新,以得到当前次迭代更新后的机器学习模型,若所述当前次更新后的机器学习模型不符合所述预设条件,则对所述当前次更新后的机器学习模型进行下一次迭代更新。/n

【技术特征摘要】
1.一种机器学习优化方法,其特征在于,包括:
获取机器学习模型;
通过预设的数据集对所述机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,所述更新后的机器学习模型符合预设条件;
其中,所述至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从所述多个损失值中选取大于平均损失值的至少一个损失值,根据所述大于平均损失值的至少一个损失值以及所述至少一个损失值中每个损失值对应的任务对所述上一次迭代更新后的机器学习模型进行更新,以得到当前次迭代更新后的机器学习模型,若所述当前次更新后的机器学习模型不符合所述预设条件,则对所述当前次更新后的机器学习模型进行下一次迭代更新。


2.根据权利要求1所述的方法,其特征在于,在所述通过预设的数据集对所述机器学习模型进行至少一次迭代更新之前,所述方法还包括:
获取鲁棒性要求,所述鲁棒性要求用于表示在所述每一次迭代更新过程中,对所述机器学习模型进行更新的任务的数量占所述机器学习模型执行的任务的比例;
根据所述鲁棒性要求确定对所述机器学习模型的每一次迭代更新所执行的任务的数量m,所述m为大于1的正整数。


3.根据权利要求2所述的方法,其特征在于,在所述至少一次迭代更新的任意一次迭代更新中,所述获取上一次迭代更新后的所述机器学习模型执行多个任务对应的多个损失值,包括:
对预设的数据集进行采样,得到m个任务;
获取所述上一次迭代更新后的所述机器学习模型执行所述m个任务对应的m个损失值。


4.根据权利要求3所述的方法,其特征在于,所述获取所述上一次迭代更新后的所述机器学习模型执行所述m个任务对应的多个损失值,包括:
获取与所述m个任务对应的多个数据集,所述多个数据集中的每个数据集包括训练子集和验证子集;
通过所述每个数据集包括的训练子集对所述上一次迭代更新后的机器学习模型进行训练,得到训练后的机器学习模型;
将所述每个数据集包括的验证子集作为所述训练后的机器学习模型的输入,得到所述训练后的机器学习模型的多个输出结果;
根据所述多个输出结果得到所述m个损失值。


5.根据权利要求1-4中任一项所述的方法,其特征在于,在所述至少一次迭代更新的任意一次迭代更新中,
所述从所述多个损失值中选取大于平均损失值的至少一个损失值,包括:
从所述多个损失值中选取一个值最大的损失值;
所述根据所述大于平均损失值的至少一个损失值以及所述至少一个损失值中每个损失值对应的任务对所述上一次迭代更新后的机器学习模型进行更新,包括:
根据所述值最大的损失值以及所述值最大的损失值对应的任务,对所述上一次迭代更新后的机器学习模型基于反向传播进行更新,得到所述当前次迭代更新后的机器学习模型。


6.根据权利要求1-5中任一项所述的方法,其特征在于,所述机器学习模型包括以下至少一种:多层前馈神经网络、卷积神经网络、循环神经网络或者图神经网络。


7.根据权利要求1-6中任一项所述的方法,其特征在于,所述多个任务包括小样本学习任务。


8.一种机器学习优化装置,其特征在于,包括:
获取模块,用于获取机器学习模型;
更新模块,用于通过预设的数据集对所述机...

【专利技术属性】
技术研发人员:吴霜谢传龙田光见
申请(专利权)人:华为技术有限公司
类型:发明
国别省市:广东;44

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1