模型训练方法、装置、设备及存储介质制造方法及图纸

技术编号:38685273 阅读:8 留言:0更新日期:2023-09-02 22:58
本公开实施例提供了一种模型训练方法、装置、设备及存储介质,用于训练多模态融合网络。获取多模态数据;其中,所述多模态数据包括图像数据、文本数据及音频数据中至少两种模态的数据;将所述多模态数据依次输入所述多模态融合网络,输出多模态数据处理结果;基于所述多模态数据处理结果训练所述多模态适配子网络、模态融合子网络及目标任务子网络中的至少一项,获得训练后的多模态融合网络。本公开实施例提供的模型训练方法,训练多模态融合网络中除预训练多模态子网络外的其他子网络,能够有效降低训练所需内存和显存等资源,同时又能利用预训练好的大模型,可以极大的节省计算资源及时间,从而提高多模态融合网络的训练及部署效率。效率。效率。

【技术实现步骤摘要】
模型训练方法、装置、设备及存储介质


[0001]本公开实施例涉及神经网络
,尤其涉及一种模型训练方法、装置、设备及存储介质。

技术介绍

[0002]目前,神经网络模型越来越大,里面包含的参数量非常大,尤其是多模态神经网络模型,包含的参数量更大。如果想将神经网络模型迁移至下游应用领域,需要大量的计算资源把大模型训练起来,由于模型规模较大,将其加载进显存都非常困难。因此,训练大规模多模态模型通常需要大量的计算资源和时间,影响模型训练和部署效率。

技术实现思路

[0003]本公开实施例提供一种模型训练方法、装置、设备及存储介质,训练多模态融合网络中除预训练多模态子网络外的其他子网络,可以极大的节省计算资源及时间,从而提高多模态融合网络的训练及部署效率。
[0004]第一方面,本公开实施例提供了一种模型训练方法,用于训练多模态融合网络,所述多模态融合网络包括依次连接的预训练多模态子网络、多模态适配子网络、模态融合子网络及目标任务子网络,所述方法包括:
[0005]获取多模态数据;其中,所述多模态数据包括图像数据、文本数据及音频数据中的至少两种模态的数据;
[0006]将所述多模态数据输入所述多模态融合网络,输出多模态数据处理结果;
[0007]基于所述多模态数据处理结果训练所述多模态适配子网络、模态融合子网络及目标任务子网络中的至少一项,获得训练后的多模态融合网络。
[0008]第二方面,本公开实施例还提供了一种模型训练装置,用于训练多模态融合网络,所述多模态融合网络包括依次连接的预训练多模态子网络、多模态适配子网络、模态融合子网络及目标任务子网络,所述装置包括:
[0009]多模态数据获取模块,用于获取多模态数据;其中,所述多模态数据包括图像数据、文本数据及音频数据中的至少两种模态的数据;
[0010]多模态数据处理结果获取模块,用于将所述多模态数据输入所述多模态融合网络,输出多模态数据处理结果;
[0011]多模态融合网络训练模块,用于基于所述多模态数据处理结果训练所述多模态适配子网络、模态融合子网络及目标任务子网络中的至少一项,获得训练后的多模态融合网络。
[0012]第三方面,本公开实施例还提供了一种电子设备,所述电子设备包括:
[0013]一个或多个处理器;
[0014]存储装置,用于存储一个或多个程序,
[0015]当所述一个或多个程序被所述一个或多个处理器执行,使得所述一个或多个处理
是“至少部分地基于”。术语“一个实施例”表示“至少一个实施例”;术语“另一实施例”表示“至少一个另外的实施例”;术语“一些实施例”表示“至少一些实施例”。其他术语的相关定义将在下文描述中给出。
[0033]需要注意,本公开中提及的“第一”、“第二”等概念仅用于对不同的装置、模块或单元进行区分,并非用于限定这些装置、模块或单元所执行的功能的顺序或者相互依存关系。
[0034]需要注意,本公开中提及的“一个”、“多个”的修饰是示意性而非限制性的,本领域技术人员应当理解,除非在上下文另有明确指出,否则应该理解为“一个或多个”。
[0035]本公开实施方式中的多个装置之间所交互的消息或者信息的名称仅用于说明性的目的,而并不是用于对这些消息或信息的范围进行限制。
[0036]可以理解的是,在使用本公开各实施例公开的技术方案之前,均应当依据相关法律法规通过恰当的方式对本公开所涉及个人信息的类型、使用范围、使用场景等告知用户并获得用户的授权。
[0037]例如,在响应于接收到用户的主动请求时,向用户发送提示信息,以明确地提示用户,其请求执行的操作将需要获取和使用到用户的个人信息。从而,使得用户可以根据提示信息来自主地选择是否向执行本公开技术方案的操作的电子设备、应用程序、服务器或存储介质等软件或硬件提供个人信息。
[0038]作为一种可选的但非限定性的实现方式,响应于接收到用户的主动请求,向用户发送提示信息的方式例如可以是弹窗的方式,弹窗中可以以文字的方式呈现提示信息。此外,弹窗中还可以承载供用户选择“同意”或者“不同意”向电子设备提供个人信息的选择控件。
[0039]可以理解的是,上述通知和获取用户授权过程仅是示意性的,不对本公开的实现方式构成限定,其它满足相关法律法规的方式也可应用于本公开的实现方式中。
[0040]可以理解的是,本技术方案所涉及的数据(包括但不限于数据本身、数据的获取或使用)应当遵循相应法律法规及相关规定的要求。
[0041]为了将上游模型迁移到下游任务上,传统的解决方法如下:
[0042]1、微调(full model finetune):将整个预训练的多模态模型微调到新的任务特定数据集上是迁移学习中最常用的方法。虽然有效,但该方法存在多个限制,如高计算成本、灾难性遗忘的风险以及需要大量的任务特定数据等。
[0043]2、特征拼接(feature concatenation):针对多模态融合,一种简单策略是将从不同模态中提取的特征拼接起来,然后再输入到分类器中。然而,这种方法无法捕捉模态之间的复杂交互,通常会导致高维输入,增加模型的复杂性和计算要求。
[0044]3、多模态注意力机制(Multimodal attention mechanism):这种是传统常用的方法,这种也被称为跨模态注意力机制,会根据任务的相关性,为不同模态或其特征分配权重。虽然这种方法可以有效地模拟每个模态的重要性,但在捕捉高阶交互或适应具有少量训练数据的新任务方面可能存在局限。
[0045]图1为本公开实施例所提供的一种模型训练方法的流程示意图,本公开实施例适用于对多模态融合网络进行训练的情形,该方法可以由模型训练装置来执行,该装置可以通过软件和/或硬件的形式实现,可选的,通过电子设备来实现,该电子设备可以是移动终端、PC端或服务器等。
[0046]其中,多模态融合网络包括依次连接的预训练多模态子网络(Pre

trained Multi

modal Model)、多模态适配子网络(Multi

Modal Adapters,MMA)、模态融合子网络(Dynamic Fusion Mechanism,DFM)及目标任务子网络(Task

Specific Fine

Tuning,TSFT)。
[0047]如图1所示,所述方法包括:
[0048]S110,获取多模态数据。
[0049]其中,多模态数据包括图像数据、文本数据及音频数据中至少两种模态的数据。本实施例中,多模态数据可以是成对或者成组的图像

文本

音频数据。示例性的,获取多模态数据的过程可以是:从开源(或者被授权)的电影、电视剧、短视频等进行抽帧处理,从帧中抽取出图像、文本及音频,并将图像、文本及音频进行对齐,同时去除重复的多模态数本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种模型训练方法,用于训练多模态融合网络,其特征在于,所述多模态融合网络包括依次连接的预训练多模态子网络、多模态适配子网络、模态融合子网络及目标任务子网络,所述方法包括:获取多模态数据;其中,所述多模态数据包括图像数据、文本数据及音频数据中至少两种模态的数据;将所述多模态数据依次输入所述多模态融合网络,输出多模态数据处理结果;基于所述多模态数据处理结果训练所述多模态适配子网络、模态融合子网络及目标任务子网络中的至少一项,获得训练后的多模态融合网络。2.根据权利要求1所述的方法,其特征在于,将所述多模态数据输入所述多模态融合网络,输出多模态数据处理结果,包括:基于所述预训练多模态子网络对所述多模态数据进行特征提取,获得多模态特征数据;基于所述多模态适配子网络对所述多模态特征数据进行调整,获得调整后的多模态特征数据;基于所述模态融合子网络对所述调整后的多模态特征数据进行融合,获得融合特征数据;基于所述目标任务子网络对所述融合特征数据进行目标任务的数据处理,获得多模态数据处理结果。3.根据权利要求2所述的方法,其特征在于,所述多模态适配子网络包括多模态适配器及跨模态适配器;其中,所述多模态适配器包括图像适配器、文本适配器及音频适配器中至少两种模态的适配器;基于所述多模态适配子网络对所述多模态特征数据进行调整,获得调整后的多模态特征数据,包括:基于所述多模态适配器分别对对应模态的特征数据进行调整,获得调整后的各模态的特征数据;基于所述跨模态适配器对调整后的各模态的特征数据进行跨模态调整,获得再次调整后的各模态的特征数据。4.根据权利要求3所述的方法,其特征在于,所述图像适配器、文本适配器及音频适配器均包括两个全连接层,且第一个全连接层的输入与第二个全连接层的输出残差连接;所述跨模态适配器包括多头注意力层及前馈层。5.根据权利要求2所述的方法,其特征在在于,所述模态融合子网络包括模态注意力融合模块、上下文融合模块及特征融合模块;基于所述模态融合子网络对所述调整后的多模态特征数据进行融合,获得融合特征数据,包括:基于所述模态注意力融合模块对所述调整后的多模态特征数据按照注意力分数进行融合,获得第一中间融合特征数据;基于所述上下文融合模块对所述第一中间融合特征数据进行上下文融合,获得第二中间融合特征数据;基于所述特征融合模块对所述第一中间融合特征数据和所述第二中间融合特征数据进行叠加,获得融合特征数据。6.根据权利要求5所述的方法,其特征在于,所述模态注意力融合模块包括:前馈层、激
活层及融合层;基于所述...

【专利技术属性】
技术研发人员:杨志雄杨延展
申请(专利权)人:抖音视界有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1