对话生成模型的训练方法、装置、存储介质及计算机设备制造方法及图纸

技术编号:38320232 阅读:9 留言:0更新日期:2023-07-29 09:02
本发明专利技术公开了一种对话生成模型的训练方法、装置、存储介质及计算机设备,涉及人工智能及智慧医疗技术领域。其中方法包括:获取对话生成模型的训练数据集,其中,训练数据集包括多个预设有场景标签的训练样本,场景标签用于标注训练样本适用的场景;基于场景标签将训练样本划分为多个批次组,其中,每个批次组包含属于同一场景的预设数量的训练样本;对多个批次组进行随机排序,生成用于训练对话生成模型的目标训练数据集;根据目标训练数据集,对预设的神经网络模型进行训练,得到对话生成模型。上述方法能够从历史的对话样本中,生成按照适用场景随机排序的批次组,并基于该训练数据集对对话生成模型进行训练,提升模型的性能。能。能。

【技术实现步骤摘要】
对话生成模型的训练方法、装置、存储介质及计算机设备


[0001]本专利技术涉及人工智能及智慧医疗
,尤其是涉及一种对话生成模型的训练方法、装置、存储介质及计算机设备。

技术介绍

[0002]随着神经网络模型技术的发展,Transformer(编码解码模型)类预训练模型越来越受到各方的关注,其使得部署一个对话生成模型来应对多个不同的对话场景成为可能。特别是在医疗交互领域,对话生成模型可以接收来自不同医疗场景的对话信息,生成适用于不同医疗场景的回复信息。
[0003]当前,对话生成模型进行训练的方式多为随机的选取各场景的历史对话数据作为训练数据集中每个批次层面上的训练数据,但基于该种方式对模型进行训练的过程会因训练批次层面上的训练数据过于分散,导致对话生成模型收敛速度较慢。此外,将多个场景下的训练数据按照场景顺序对神经网络模型进行训练,会造成模型学习新知识后,几乎彻底遗忘掉之前学习的内容,导致在对模型的训练过程中会出现灾难性遗忘的情况,进而导致模型训练的效率大幅降低。

技术实现思路

[0004]有鉴于此,本申请提供了一种对话生成模型的训练方法、装置、存储介质及计算机设备,主要目的在于解决模型训练效率偏低的技术问题。
[0005]根据本专利技术的第一个方面,提供了一种对话生成模型的训练方法,该方法包括:
[0006]获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
[0007]基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
[0008]对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
[0009]根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。
[0010]根据本专利技术的第二个方面,提供了一种对话生成模型的训练装置,该装置包括:
[0011]数据获取模块,同于获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;
[0012]样本分组模块,用于基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;
[0013]数据生成模块,用于对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;
[0014]模型训练模块,用于根据所述目标训练数据集,对预设的神经网络模型进行训练,
得到所述对话生成模型。
[0015]根据本专利技术的第三个方面,提供了一种存储介质,其上存储有计算机程序,所述程序被处理器执行时实现上述对话生成模型的训练方法。
[0016]根据本专利技术的第四个方面,提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述对话生成模型的训练方法。
[0017]本专利技术提供的一种对话生成模型的训练方法、装置、存储介质及计算机设备,能够对训练数据集中适用于不用场景的训练样本点进行分类,将适用于同一场景的训练样本在训练批次层面中进行分场景汇聚,使模型能够在适用于同一场景的训练样本下实现快速的收敛,加快模型的训练速度。同时,对批次组进行随机排序得到目标训练数据集,使基于目标训练数据集对对话生成模型进行训练时,不会出现因为训练数据集中某个场景的训练数据离训练数据集的末端太远,而导致的灾难性遗忘的情况发生,进而有效提高了对对话生成模型的训练效果,提升了对话生成模型的性能。
[0018]上述说明仅是本申请技术方案的概述,为了能够更清楚了解本申请的技术手段,而可依照说明书的内容予以实施,并且为了让本申请的上述和其它目的、特征和优点能够更明显易懂,以下特举本申请的具体实施方式。
附图说明
[0019]此处所说明的附图用来提供对本专利技术的进一步理解,构成本申请的一部分,本专利技术的示意性实施例及其说明用于解释本专利技术,并不构成对本专利技术的不当限定。在附图中:
[0020]图1示出了本专利技术实施例提供的一种对话生成模型的训练方法的流程示意图;
[0021]图2示出了本专利技术实施例提供的一种训练数据集的结构示意图之一;
[0022]图3示出了本专利技术实施例提供的一种训练数据集的结构示意图之二;
[0023]图4示出了本专利技术实施例提供的一种目标训练数据集的结构示意图;
[0024]图5示出了本专利技术实施例提供的一种对话生成模型的训练装置的结构示意图;
[0025]图6示出了本专利技术实施例提供的另一种对话生成模型的训练装置的结构示意图。
具体实施方式
[0026]下文中将参考附图并结合实施例来详细说明本专利技术。需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。
[0027]现有的对话生成模型进行训练的方式多为随机的选取医疗领域各场景的历史对话数据作为训练数据集中每个批次层面上的训练数据,但基于该种方式对模型进行训练的过程会因训练批次层面上的训练数据过于分散,导致对话生成模型收敛速度较慢,同时,将多个场景下的训练数据按照场景顺序对神经网络模型进行训练,会造成模型学习新知识后,几乎彻底遗忘掉之前学习的内容,导致在对模型的训练过程中会出现灾难性遗忘的情况,进而导致模型训练的质量和效率大幅降低。
[0028]针对上述问题,在一个实施例中,如图1所示,提供了一种对话生成模型的训练方法,以该方法应用于计算机设备为例进行说明,包括以下步骤:
[0029]101、获取对话生成模型的训练数据集。
[0030]其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景。
[0031]具体的,可以获取不同用户之间的历史对话数据作为训练数据集,例如,在医疗场景下,可以获取历史上医生和患者的诊断过程的对话数据,并按对话数据的应用场景进行分类,比如可以将医疗场景下的对话场景分为诊断场景、医生问询场景和治疗建议场景等。在获取训练数据集时,可以分别在诊断场景、医生问询场景和治疗建议场景中获取多组对话数据,并在每个对话数据上标注用于区分对话数据所属场景的场景标签,比如在诊断场景中获取的对话数据中标注诊断场景标签,在医生问询场景中获取的对话数据中标注医生问询场景标签,在治疗建议场景中获取的对话数据中标注治疗建议场景标签,得到训练样本。进一步的,如图2所示,可以将多个诊断场景下的训练样本A、多个医生问询场景下的训练样本B和多个治疗建议场景下的训练样本C组合成训练数据集10。应当注意的是,上述场景只是在医疗健康领域训练对话生成模型的常用场景,其他场景同样适用于本实施例。
[0032]进一步的,训练样本可以分为样本对话和标签对话,样本对话为对话数据中医生为得到适用于特定场景的结论与用户在进行沟通的对话,而标签对话为医生在得到特定本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种对话生成模型的训练方法,其特征在于,所述方法包括:获取对话生成模型的训练数据集,其中,所述训练数据集包括多个预设有场景标签的训练样本,所述场景标签用于标注所述训练样本适用的场景;基于所述场景标签将所述训练样本划分为多个批次组,其中,每个所述批次组包含属于同一场景的预设数量的训练样本;对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集;根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型。2.根据权利要求1所述的方法,其特征在于,所述基于所述场景标签将所述训练样本划分为多个批次组,包括:根据所述场景标签,将所述训练样本划分为多个场景组,其中,每个所述场景组包含属于同一场景的全部训练样本;将所述场景组内的多个所述训练样本划分为多个批次组,其中,每个所述批次组包括预设数量的所述训练样本。3.根据权利要求2所述的方法,其特征在于,所述对多个所述批次组进行随机排序,生成用于训练所述对话生成模型的目标训练数据集,包括:执行循环过程直至满足预设条件,其中,所述循环过程包括:从每个所述场景组内选出一个所述批次组,并将选出的多个所述批次组随机组合成综合组,其中,所述综合组内的训练样本适用的场景包含所述训练数据集中全部训练样本适用的场景;所述预设条件为:存在至少一个所述场景组中的全部所述批次组被组成所述综合组;将全部所述综合组与每个所述场景组内未被组成所述综合组的批次组进行随机排列,或对全部所述综合组进行随机排列,得到所述目标训练数据集。4.根据权利要求2所述的方法,其特征在于,所述将所述场景组内的多个所述训练样本划分为多个批次组,包括:将所述场景组内的训练样本排列成样本队列;以所述样本队列的一个端点为起始点,依次获取所述预设数量的训练样本组成一个批次组,得到多个批次组。5.根据权利要求1

4任一项所述的方法,其特征在于,所述根据所述目标训练数据集,对预设的神经网络模型进行训练,得到所述对话生成模型之前,所述方法还包括:获取每个所述批次组内的场景标签;判断同一个所述批次组内的场景标签是否对应同一个场景;若存在所述批次组内的场景标签未对应同一个场景,则发出报警提示信息。6.根据权利要求1

4任一项所述的方法,其特征在于,所述根据所述目标训练数据集...

【专利技术属性】
技术研发人员:刘佳瑞王世朋姚海申孙行智
申请(专利权)人:平安科技深圳有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1