一种基于深度学习模型的文本纠错方法技术

技术编号:33554750 阅读:21 留言:0更新日期:2022-05-26 22:51
一种基于深度学习模型的文本纠错方法,BERT模型使用了Transformer模型的编码器部分,MacBERT用目标单词的相似单词,替代被mask的字符,减轻了预训练和微调阶段之间的差距。并且原始下一个句子预测任务贡献不大,其引入了句子顺序预测任务。基于上两个预训练任务的设置,MacBERT便有了强大的文本建模能力。MacBERT便有了强大的文本建模能力。MacBERT便有了强大的文本建模能力。

【技术实现步骤摘要】
一种基于深度学习模型的文本纠错方法


[0001]本专利技术涉及机器人对话及文本检索领域,具体涉及一种基于深度学习模型的文本纠错方法。

技术介绍

[0002]在检索或者对话场景下,错别字意味着搜索不到内容或者机器人检索不到相关对话,对于用户而言,就是需求无法满足,造成了很差的体验,因此在机器人对话或者检索领域,就很必要去纠错。

技术实现思路

[0003]本专利技术为了克服以上技术的不足,提供了一种使预训练模型自带的tokenzier对输入文本进行编码并输入到模型中,对模型输出的logits解码即得到改正之后的文本的方法。
[0004]本专利技术克服其技术问题所采用的技术方案是:
[0005]一种基于深度学习模型的文本纠错方法,包括如下步骤:
[0006]a)建立模型,该模型由检测网络、软屏蔽网络和纠正网络构成;
[0007]b)将文本转换为能够输入进模型的嵌入;
[0008]c)检测网络输出文本中第i个字符是错别字的概率p
i

[0009]d)软屏蔽网络软屏蔽嵌入本文第i个字符,将其定义为e
i


[0010]e)将e
i

输入纠正网络,纠正网络为基于MacBERT的序列多分类标记模型,检测网络的输出特征作为MacBERT模型12层Transformer模块的输入,将MacBERT模型最后一层的输出与MacBERT模型Input部分的Embedding特征进行残差连接,将残差连接结果作为每个字符最终的特征表示;
[0011]f)模型通过端对端进行学习训练;
[0012]g)将训练完成后的模型通过transformers库加载产生bin文件与txt文件,bin文件为训练完成保存的模型,txt文件为保存的词表;
[0013]h)使用transformers库中的tokenizer对原始文本进行编码,将编码结果输入到训练好的模型中,输出结果为张量tensor,对张量tensor输出取每行的最大值位置下标,使用tokenizer.decode对位置下标进行解码,将解码后的文本作为纠错后的文本。
[0014]进一步的,步骤b)中通过BERT模型的embedding层的输出或word2vec嵌入将将文本转换为能够输入进模型的嵌入。
[0015]进一步的,步骤c)中检测网络由双向门控神经网络Bi

GRU构成,双向门控神经网络Bi

GRU学习输入文本的上下文信息,输出文本每个字符是错别字的概率p
i
。进一步的,步骤d)中软屏蔽网络通过e
i

=p
i
*e
mask
+(1

p
i
)*e
i
计算得到e
i

,式中e
i
为本文第i个字符的输入嵌入,e
mask
为掩码的嵌入。
[0016]进一步的,步骤f)中模型训练时损失函数由检测网络和纠正网络的损失函数加权
得到。
[0017]本专利技术的有益效果是:使用端对端的文本纠错模型,其模型首先要有预测词语的能力。BERT模型使用了Transformer模型的编码器部分,可以理解为BERT旨在学习庞大语料库文本的内部信息。对于BERT模型的升级版MacBERT,其在预训练时策略有所调整,BERT模型的缺点是预训练和微调阶段任务不一致,pretrain有[mask]字符,而finetune没有。MacBERT用目标单词的相似单词,替代被mask的字符,减轻了预训练和微调阶段之间的差距。并且原始下一个句子预测任务贡献不大,其引入了句子顺序预测任务。基于上两个预训练任务的设置,MacBERT便有了强大的文本建模能力。
附图说明
[0018]图1为本专利技术的方法流程图。
具体实施方式
[0019]下面结合附图1对本专利技术做进一步说明。
[0020]一种基于深度学习模型的文本纠错方法,包括如下步骤:
[0021]a)建立模型,该模型由检测网络、软屏蔽网络和纠正网络构成。
[0022]b)将文本转换为能够输入进模型的嵌入。
[0023]c)检测网络输出文本中第i个字符是错别字的概率p
i
,概率p
i
值越大表示该位置出错的可能性越大。
[0024]d)软屏蔽相当于输入嵌入和掩码嵌入的加权和,误差概率p
i
作为权重,软屏蔽网络软屏蔽嵌入本文第i个字符,将其定义为e
i


[0025]e)将e
i

输入纠正网络,纠正网络为基于MacBERT的序列多分类标记模型,检测网络的输出特征作为MacBERT模型12层Transformer模块的输入,将MacBERT模型最后一层的输出与MacBERT模型Input部分的Embedding特征进行残差连接,将残差连接结果作为每个字符最终的特征表示。最终,将每个字符特征通过一层Softmax分类器,从候选词表中输出概率最大的字符认为是每个位置的正确字符。
[0026]f)模型通过端对端进行学习训练。
[0027]g)将训练完成后的模型通过transformers库加载产生bin文件与txt文件,bin文件为训练完成保存的模型,txt文件为保存的词表。
[0028]h)使用transformers库中的tokenizer对原始文本进行编码,将编码结果输入到训练好的模型中,输出结果为张量tensor,对张量tensor输出取每行的最大值位置下标,使用tokenizer.decode对位置下标进行解码,将解码后的文本作为纠错后的文本。
[0029]将文本纠错分为检测网络和纠正网络两部分,训练损失函数以检测损失函数det_loss与纠正损失函数乘比例系数之和作为整体的损失函数。
[0030]具体来看,模型输入是字粒度的embedding,检测网络是由Bi

GRU组成,充分学习输入的上下文表示,输出每个位置i可能是错别字的概率p
i
,值越大表示该位置出错的可能性越大。将每个位置的特征以p
i
的概率乘上masking字符的特征,以(1

p
i
)的概率乘上原始的输入特征,最后两部分相加作为每一个字符的特征输入到纠正网络中。纠正网络是一个基于macbert的序列多分类标记模型。检测网络输出的特征作为MacBERT 12层Transformer
模块的输入,最后一层的输出+Input部分的Embedding特征作为每个字符最终的特征表示。最终,将每个字特征经过一层softmax分类器,从候选词表中输出概率最大的字符认为是每个位置的正确字符。
[0031]使用训练好的本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于深度学习模型的文本纠错方法,其特征在于,包括如下步骤:a)建立模型,该模型由检测网络、软屏蔽网络和纠正网络构成;b)将文本转换为能够输入进模型的嵌入;c)检测网络输出文本中第i个字符是错别字的概率p
i
;d)软屏蔽网络软屏蔽嵌入本文第i个字符,将其定义为e
i

;e)将e
i

输入纠正网络,纠正网络为基于MacBERT的序列多分类标记模型,检测网络的输出特征作为MacBERT模型12层Transformer模块的输入,将MacBERT模型最后一层的输出与MacBERT模型Input部分的Embedding特征进行残差连接,将残差连接结果作为每个字符最终的特征表示;f)模型通过端对端进行学习训练;g)将训练完成后的模型通过transformers库加载产生bin文件与txt文件,bin文件为训练完成保存的模型,txt文件为保存的词表;h)使用transformers库中的tokenizer对原始文本进行编码,将编码结果输入到训练好的模型中,输出结果为张量tensor,对张量tensor输出取每行的最大值位置下标,使用tokenizer.decode对位置下标进行解码,将解码后的文本作为纠错后的...

【专利技术属性】
技术研发人员:李晓瑜冯落落冯卫森李沛
申请(专利权)人:山东新一代信息产业技术研究院有限公司
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1