【技术实现步骤摘要】
一种机器学习优化方法以及装置
本申请涉及人工智能领域,尤其涉及一种机器学习优化方法以及装置。
技术介绍
机器学习是人工智能的一个分支。人们通过构建一个参数化的模型来表示规律,一些端到端的机器学习方法在样本数据中获取规律,应用规律进行预测。以小样本(fewshot)机器学习任务为例,现有方案在进行小样本机器学习时,首选选定一个机器学习模型,然后在每轮学习的训练过程中,获取一组小样本机器学习任务,计算机器学习模型进行所有任务的损失函数的损失值。用求平均的方法将损失值汇聚作为最终的损失值,即平均损失值,依照此平均损失值以及对应的所有任务对机器学习模型进行反向传播的参数更新。然而,若使用平均损失值对小样本机器学习模型进行更新,将导致机器学习模型的鲁棒性差。例如,若小样本机器学习模型进行的实际新任务与训练时使用的平均任务存在较大差别,则机器学习模型并不能适应新任务,导致机器学习模型的输出结果不准确。
技术实现思路
本申请提供一种机器学习优化方法以及装置,用于对机器学习模型进行优化,得到鲁棒性更优的机器学习模型。第一方面,本申请提供一种机器学习优化方法,包括:获取机器学习模型;通过预设的数据集对机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,更新后的机器学习模型符合预设条件;其中,至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从多个损失值中选取大于平均损失值的至少一个损失值,根据大于平均损失值的至少一个损失值以及至少 ...
【技术保护点】
1.一种机器学习优化方法,其特征在于,包括:/n获取机器学习模型;/n通过预设的数据集对所述机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,所述更新后的机器学习模型符合预设条件;/n其中,所述至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从所述多个损失值中选取大于平均损失值的至少一个损失值,根据所述大于平均损失值的至少一个损失值以及所述至少一个损失值中每个损失值对应的任务对所述上一次迭代更新后的机器学习模型进行更新,以得到当前次迭代更新后的机器学习模型,若所述当前次更新后的机器学习模型不符合所述预设条件,则对所述当前次更新后的机器学习模型进行下一次迭代更新。/n
【技术特征摘要】
1.一种机器学习优化方法,其特征在于,包括:
获取机器学习模型;
通过预设的数据集对所述机器学习模型进行至少一次迭代更新,得到更新后的机器学习模型,所述更新后的机器学习模型符合预设条件;
其中,所述至少一次迭代更新的任意一次迭代更新包括:获取上一次迭代更新后的机器学习模型执行多个任务对应的多个损失值,从所述多个损失值中选取大于平均损失值的至少一个损失值,根据所述大于平均损失值的至少一个损失值以及所述至少一个损失值中每个损失值对应的任务对所述上一次迭代更新后的机器学习模型进行更新,以得到当前次迭代更新后的机器学习模型,若所述当前次更新后的机器学习模型不符合所述预设条件,则对所述当前次更新后的机器学习模型进行下一次迭代更新。
2.根据权利要求1所述的方法,其特征在于,在所述通过预设的数据集对所述机器学习模型进行至少一次迭代更新之前,所述方法还包括:
获取鲁棒性要求,所述鲁棒性要求用于表示在所述每一次迭代更新过程中,对所述机器学习模型进行更新的任务的数量占所述机器学习模型执行的任务的比例;
根据所述鲁棒性要求确定对所述机器学习模型的每一次迭代更新所执行的任务的数量m,所述m为大于1的正整数。
3.根据权利要求2所述的方法,其特征在于,在所述至少一次迭代更新的任意一次迭代更新中,所述获取上一次迭代更新后的所述机器学习模型执行多个任务对应的多个损失值,包括:
对预设的数据集进行采样,得到m个任务;
获取所述上一次迭代更新后的所述机器学习模型执行所述m个任务对应的m个损失值。
4.根据权利要求3所述的方法,其特征在于,所述获取所述上一次迭代更新后的所述机器学习模型执行所述m个任务对应的多个损失值,包括:
获取与所述m个任务对应的多个数据集,所述多个数据集中的每个数据集包括训练子集和验证子集;
通过所述每个数据集包括的训练子集对所述上一次迭代更新后的机器学习模型进行训练,得到训练后的机器学习模型;
将所述每个数据集包括的验证子集作为所述训练后的机器学习模型的输入,得到所述训练后的机器学习模型的多个输出结果;
根据所述多个输出结果得到所述m个损失值。
5.根据权利要求1-4中任一项所述的方法,其特征在于,在所述至少一次迭代更新的任意一次迭代更新中,
所述从所述多个损失值中选取大于平均损失值的至少一个损失值,包括:
从所述多个损失值中选取一个值最大的损失值;
所述根据所述大于平均损失值的至少一个损失值以及所述至少一个损失值中每个损失值对应的任务对所述上一次迭代更新后的机器学习模型进行更新,包括:
根据所述值最大的损失值以及所述值最大的损失值对应的任务,对所述上一次迭代更新后的机器学习模型基于反向传播进行更新,得到所述当前次迭代更新后的机器学习模型。
6.根据权利要求1-5中任一项所述的方法,其特征在于,所述机器学习模型包括以下至少一种:多层前馈神经网络、卷积神经网络、循环神经网络或者图神经网络。
7.根据权利要求1-6中任一项所述的方法,其特征在于,所述多个任务包括小样本学习任务。
8.一种机器学习优化装置,其特征在于,包括:
获取模块,用于获取机器学习模型;
更新模块,用于通过预设的数据集对所述机...
【专利技术属性】
技术研发人员:吴霜,谢传龙,田光见,
申请(专利权)人:华为技术有限公司,
类型:发明
国别省市:广东;44
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。