【技术实现步骤摘要】
预训练模型的微调方法及其装置
[0001]本公开涉及深度学习领域,尤其涉及模型处理
,可应用在对模型进行微调。
技术介绍
[0002]如今自然语言领域发展走向超大规模模型时代,针对规模较大的模型,在有限算力下,采用知识继承的提示词微调技术,可以在少样本与全样本设定下更有效地微调模型。
[0003]其中,在有限算力下,为了提升超大规模模型在特定下游任务上的精确度,按照预训练阶段与微调阶段是否新增需要微调的参数,可以分为如下三种模型微调方法:
[0004]第一种、预训练阶段与微调阶段均不增加需要微调的参数。
[0005]第二种、仅微调阶段增加需要微调的参数。
[0006]第三种、预训练阶段与微调阶段均增加需要微调的参数。
[0007]其中,上述第一种方法的微调效果较差。第二种方法中在微调阶段增加的参数会由于未初始化而导致模型在小样本任务上表现较差,并且还影响模型的收敛速度。第三种方法中无法兼容生成任务。
技术实现思路
[0008]本公开提供了一种预训练模型的微调方法及 ...
【技术保护点】
【技术特征摘要】
1.一种预训练模型的微调方法,其中,包括:确定预训练模型,所述预训练模型包括N层转换Transformer层以及N层全连接层,每层Transformer层分别连接至一层全连接层,N为正整数,其中,每层全连接层对应的参数基于所述Transformer层的参数确定,所述Transformer层的参数为预训练好的;基于预训练数据库确定目标下游任务对应的提示词;将所述提示词以及所述目标下游任务对应的输入文本输入至所述预训练模型得到输出结果,基于所述输出结果计算损失值,并基于所述损失值调整更新所述N层全连接层对应的参数,以微调所述预训练模型。2.根据权利要求1所述的方法,其中,所述每层全连接层对应的参数基于所述Transformer层的参数确定,包括:每层全连接层对应的参数基于所述全连接层连接的Transformer层的多头注意力模块中的参数确定。3.根据权利要求1或2所述的方法,其中,所述全连接层对应的参数包括:键key向量对应参数W
k
′
,键值value向量对应的参数W
v
′
。4.根据权利要求1所述的方法,其中,所述预训练数据库包括:预训练阶段的预训练任务,以及所述预训练任务对应的提示词。5.根据权利要求4所述的方法,其中,所述基于预训练数据库确定目标下游任务对应的提示词,包括:在所述预训练数数据库的预训练任务中,确定与所述目标下游任务的相似程度高于阈值的特定预训练任务;将所述特定预训练任务对应的提示词确定为所述目标下游任务对应的提示词。6.根据权利要求3所述的方法,其中,所述将所述提示词以及所述目标下游任务对应的输入文本输入至所述预训练模型,包括:将所述提示词以及所述输入文本输入至所述Transformer层;将所述提示词输入至所述全连接层。7.根据权利要求6所述的方法,其中,所述全连接层用于:计算出输入的提示词的各个字符对应的key向量以及value向量,并将所述提示词的各个字符对应的key向量以及value向量输入至与所述全连接层连接的Transformer层。8.根据权利要求6或7所述的方法,其中,所述Transformer层用于:计算出输入的提示词的各个字符对应的查询query向量,并基于所述提示词的各个字符分别对应的query向量、key向量、value向量,确定出所述提示词的各个更新字符,将所述提示词的各个更新字符作为下一层Transformer层和下一层全连接层的输入;以及计算出输入文本的各个字符分别对应的query向量、key向量、以及value向量,并将所述输入文本的各个字符分别对应的query向量、key向量、以及value向量输入至下一层Transformer层。9.根据权利要求8所述的方法,其中,所述Transformer层基于如下公式一计算输入的提示词的各个字符对应的query向量,以及输入文本的各个字符分别对应的query向量、key
向量、以及value向量;公式一:q
i
,k
i
,v
i
=W
q
x
i
,W
k
x
i
,W
v
x
i
q
′
i
=W
q
x
′
i
;其中,x
i
表示输入文本的第i个字符,q
i
表示输入文本的第i个字符对应的query向量,k
i
表示输入文本的第i个字符对应的key向量,v
i
表示输入文本的第i个字符对应的value向量,x
i
′
表示输入至Transformer层的提示词的第i个字符,q
i
′
表示输入至Transformer层的提示词的第i个字符对应的query向量,W
q
、W
k
、W
v
为所述Transformer层预训练好的参数。10.根据权利要求7所述的方法,其中,所述全连接层基于如下公式二计算输入的提示词的各个字符对应的key向量以及value向量;公式二:k
i
′
,v
i
′
=W
k
′
x
i
′
,W
v
′
x
i
′
;其中,x
i
′
表示输入至全连接层的提示词的第i个字符,k
i
′
表示输入至全连接层的提示词的第i个字符对应的key向量,v
i
′
表示输入至全连接层的提示词的第i个字符对应的value向量。11.根据权利要求8所述的方法,其中,所述Transformer层基于如下公式三确定出所述提示词的各个更新字符;公式三:其中,j表示输入至第l
‑
1层全连接层和第l
‑
1层Transformer层的提示词所包括的字符总数,x
i
′
,l
表示第l
‑
1层Transformer层计算出的输入至第l层全连接层和第l层Transformer层的提示词的第i个更新字符,q
i
′
,l
‑1表示第l
‑
1层Transformer层计算出的输入至第l
‑
1层Transformer层的提示词的第i个字符对应的query向量,k
j
′
,l
‑1表示第l
‑
1层全连接层计算出的输入至第l
‑
1层全连接层的提示词的第j个字符对应的key向量,v
′
j,l
‑1表示第l
‑
1层全连接层计算出的输入至第l
‑
1层全连接层的提示词的第j个字符对应的value向量。12.根据权利要求10所述的方法,其中,所述方法还包括:响应于所述预训练模型微调完毕,针对所述目标下游任务保存每层全连接层对应的k
i
′
,v
i
′
。13.一种预训练模型的微调装置,其中,包括:第一确定模块,用于确定预训练模型,所述预训练模型包括N层转换Transformer层以及N层全连接层,每层Transformer层分别连接至一层全连接层,N为正整数,其中,每层全连接层对应的参数基于所述Transformer层的参数确定,所述Transformer层的参数为...
【专利技术属性】
技术研发人员:尚骏远,赵晏彬,丁思宇,王硕寰,孙宇,
申请(专利权)人:北京百度网讯科技有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。