System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind() 基于知识蒸馏的模型训练方法、装置、设备及介质制造方法及图纸_技高网

基于知识蒸馏的模型训练方法、装置、设备及介质制造方法及图纸

技术编号:39995497 阅读:7 留言:0更新日期:2024-01-09 02:43
本申请公开了一种基于知识蒸馏的模型训练方法、装置、设备及介质,属于人工智能技术领域。该方法包括:获取第一训练数据集和第二训练数据集;基于第一训练数据集训练第一深度学习模型,得到第一教师模型;在学生模型的训练过程中,根据第一类指标获取知识蒸馏强度;其中,知识蒸馏强度用于反映知识蒸馏过程中传递知识的程度;第一类指标包括学生模型的训练状态和模型性能、知识蒸馏过程中的温度参数以及模型训练参数中的至少一种;在知识蒸馏强度的约束下,基于第二训练数据集和第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到学生模型。本申请能够提高学生模型的训练效率和性能。

【技术实现步骤摘要】

本申请涉及人工智能,特别涉及一种基于知识蒸馏的模型训练方法、装置、设备及介质


技术介绍

1、知识蒸馏(knowledge distillation,kd)作为一种模型压缩技术,其核心思想是将大型神经网络(也被称为通用模型或教师模型)学习到的知识迁移到小型神经网络(也被称为垂直模型或学生模型)。换言之,知识蒸馏的目标是将通用模型学习到的知识迁移到垂直模型上,以使垂直模型在保持较高性能的同时具有较低的计算复杂性。

2、目前知识蒸馏技术已应用于众多领域,比如自然语言处理领域等。而无论应用于何种领域,如何实现更高效的知识迁移和更优的垂直模型性能一直是本领域的一个研究热点。即,如何通过一种新的模型训练方法来提高垂直模型的训练效率和性能是本领域的一个关注焦点。


技术实现思路

1、本申请实施例提供了一种基于知识蒸馏的模型训练方法、装置、设备及介质,提高了学生模型的训练效率和性能。所述技术方案如下所示。

2、一方面,提供了一种基于知识蒸馏的模型训练方法,所述方法包括如下步骤。

3、获取第一训练数据集和第二训练数据集;

4、基于所述第一训练数据集训练第一深度学习模型,得到第一教师模型;

5、在学生模型的训练过程中,根据第一类指标获取知识蒸馏强度;其中,所述知识蒸馏强度用于反映知识蒸馏过程中传递知识的程度;所述第一类指标包括所述学生模型的训练状态和模型性能、知识蒸馏过程中的温度参数以及模型训练参数中的至少一种;

6、在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型。

7、在一种可能的实现方式中,所述第一教师模型用于执行多种深度学习任务;所述第二训练数据集中的训练数据来源于所述第一训练数据集;所述第一训练数据集中的训练数据未被标注;所述第二训练数据集中的训练数据已被标注;或,

8、所述第一训练数据集和所述第二训练数据集为与所述目标深度学习任务匹配的同一数据集。

9、在一种可能的实现方式中,所述方法还包括:

10、在所述学生模型的训练过程中,根据第二类指标获取知识蒸馏率;其中,所述知识蒸馏率用于控制知识蒸馏的速度;所述第二类指标包括所述学生模型的训练进度、模型性能和模型训练参数中的至少一种;

11、所述在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

12、在所述知识蒸馏强度和所述知识蒸馏率的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

13、在一种可能的实现方式中,所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

14、基于所述第一教师模型的输出概率分布和所述第二深度学习模型预测的输出概率,构建第一损失函数;

15、基于所述第二训练数据集和所述第二深度学习模型预测的输出概率,构建第二损失函数;

16、获取所述第一损失函数的第一权重和所述第二损失函数的第二权重;

17、基于所述第一权重和所述第二权重,对所述第一损失函数和所述第二损失函数进行加权,得到目标损失函数;

18、通过所述目标损失函数迭代获取损失值,直至满足训练停止条件,得到所述学生模型。

19、在一种可能的实现方式中,所述获取所述第一损失函数的第一权重和所述第二损失函数的第二权重,包括:

20、周期性获取所述学生模型在指定数据集上的性能变化;根据所述性能变化,确定所述第一权重和所述第二权重;其中,所述指定数据集为验证数据集或测试数据集;或,

21、基于所述模型训练参数,确定所述第一权重和所述第二权重;或,

22、在模型训练过程中配置权重参数;基于所述权重参数,在模型训练过程中确定所述第一权重和所述第二权重。

23、在一种可能的实现方式中,所述训练状态包括所述学生模型的训练进度和损失变化情况;所述根据第一类指标获取知识蒸馏过程中的知识蒸馏强度,包括:

24、根据所述第一类指标中的每一项分别获取知识蒸馏强度,得到与所述第一类指标中包含的指标项数匹配的多个知识蒸馏强度;

25、获取所述第一类指标中的每一项对知识蒸馏强度的影响权重;

26、根据所述第一类指标中每一项对应的知识蒸馏强度和影响权重,确定当前的知识蒸馏强度。

27、在一种可能的实现方式中,所述在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

28、获取所述第一教师模型的中间层特征表示和所述第二深度学习模型的中间层特征表示;

29、在所述知识蒸馏强度的约束下,基于所述第二训练数据集、所述第一教师模型的输出概率分布和中间层特征表示、所述第二深度学习模型的中间层特征表示,训练所述第二深度学习模型,得到所述学生模型。

30、在一种可能的实现方式中,所述方法还包括:

31、构建元学习任务,所述元学习任务包括支持集和查询集;其中,所述支持集包括多个子任务,每个子任务配置有不同的第一类指标;所述查询集包括用于训练所述子任务的任务样本;

32、基于所述支持集训练所述元学习模型以及基于所述查询集进行模型性能评估,得到训练好的元学习模型;

33、在所述学生模型的训练过程中,调用训练好的元学习模型,基于所述目标深度学习任务的任务特点或所述第二训练数据集的数据分布,获取知识蒸馏过程中的知识蒸馏强度。

34、在一种可能的实现方式中,所述方法还包括:

35、获取多个训练数据集;

36、基于所述多个训练数据集训练多个深度学习模型,得到多个第二教师模型;

37、所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

38、根据所述学生模型的训练进度和模型性能,获取所述第一教师模型和所述多个第二教师模型中每个模型的权重;

39、基于获取到的每个模型的权重,对所述第一教师模型和所述多个第二教师模型的输出概率进行加权,得到融合后的输出概率分布;

40、基于所述第二训练数据集和所述融合后的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

41、在一种可能的实现方式中,所述方法还包括:

42、在新增训练数据的情况下,基于新增的训练数据对所述第一教师模型进行模型微调,得到更新后的第一教师模型;

43、所述基于所述第本文档来自技高网...

【技术保护点】

1.一种基于知识蒸馏的模型训练方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述第一教师模型用于执行多种深度学习任务;所述第二训练数据集中的训练数据来源于所述第一训练数据集;所述第一训练数据集中的训练数据未被标注;所述第二训练数据集中的训练数据已被标注;或,

3.根据权利要求1所述的方法,其特征在于,所述方法还包括:

4.根据权利要求1所述的方法其特征在于,所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

5.根据权利要求4所述的方法,其特征在于,所述获取所述第一损失函数的第一权重和所述第二损失函数的第二权重,包括:

6.根据权利要求1所述的方法,其特征在于,所述训练状态包括所述学生模型的训练进度和损失变化情况;所述根据第一类指标获取知识蒸馏过程中的知识蒸馏强度,包括:

7.根据权利要求1所述的方法,其特征在于,所述在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

8.根据权利要求1所述的方法,其特征在于,所述方法还包括:

9.根据权利要求1所述的方法,其特征在于,所述方法还包括:

10.根据权利要求1所述的方法,其特征在于,所述方法还包括:

11.根据权利要求10所述的方法,其特征在于,所述方法还包括:

12.根据权利要求1所述的方法,其特征在于,所述方法还包括:

13.一种基于知识蒸馏的模型训练装置,其特征在于,所述装置包括:

14.一种计算机设备,其特征在于,所述设备包括处理器和存储器,所述存储器中存储有至少一条程序代码,所述至少一条程序代码由所述处理器加载并执行以实现如权利要求1至12中任一项权利要求所述的基于知识蒸馏的模型训练方法。

15.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条程序代码,所述至少一条程序代码由处理器加载并执行以实现如权利要求1至12中任一项权利要求所述的基于知识蒸馏的模型训练方法。

...

【技术特征摘要】

1.一种基于知识蒸馏的模型训练方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述第一教师模型用于执行多种深度学习任务;所述第二训练数据集中的训练数据来源于所述第一训练数据集;所述第一训练数据集中的训练数据未被标注;所述第二训练数据集中的训练数据已被标注;或,

3.根据权利要求1所述的方法,其特征在于,所述方法还包括:

4.根据权利要求1所述的方法其特征在于,所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

5.根据权利要求4所述的方法,其特征在于,所述获取所述第一损失函数的第一权重和所述第二损失函数的第二权重,包括:

6.根据权利要求1所述的方法,其特征在于,所述训练状态包括所述学生模型的训练进度和损失变化情况;所述根据第一类指标获取知识蒸馏过程中的知识蒸馏强度,包括:

7.根据权利要求1所述的方法,其特征在于,所述在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教...

【专利技术属性】
技术研发人员:陈孝良涂贤玲李良斌常乐
申请(专利权)人:北京声智科技有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1