当前位置: 首页 > 专利查询>中南大学专利>正文

基于主题注意力的深度学习文本分类模型训练方法技术

技术编号:34604194 阅读:22 留言:0更新日期:2022-08-20 09:08
本公开实施例中提供了一种基于主题注意力的深度学习文本分类模型训练方法,属于计算技术领域,具体包括:根据原始文本构建文本数据集;得到文本的数字化表示、文本的掩盖序列、文本的数字标签;得到样本,并将样本分为训练集和验证集;初始化前向网络中变量;得到表征文本的一组词向量;得到原始注意力矩阵;得到目标注意力矩阵;根据目标注意力矩阵,得到概率矩阵;计算注意力头输出;得到注意力输出;计算主题输出;计算主题概率向量;计算交叉熵损失;计算前向网络变量的梯度;更新网络变量;迭代计算交叉熵损失以及梯度;当迭代达到预设次数或模型损失趋于稳定,迭代停止。通过本公开的方案提高了模型的并行性、稳定性、可视性和准确率。准确率。准确率。

【技术实现步骤摘要】
基于主题注意力的深度学习文本分类模型训练方法


[0001]本公开实施例涉及计算
,尤其涉及一种基于主题注意力的深度学习文本分类模型训练方法。

技术介绍

[0002]目前,计算机以及互联网行业蓬勃发展,网络用户迅速增长,促进互联网企业以及网络用户更多的内容制作以及内容输出,并产生了大量的互联网数据。互联网数据包含大量文本数据,表现为内容繁多,形式多样。随着文本数据规模日趋庞大,相关企业处理面临的挑战也日益严峻。
[0003]文本规模的迅速增长,对文本处理工作提出了较高的要求。与传统的数据相比,网络中的文本数据具有许多新的特点,如数据量大、高度重复、高度冗余等。完全依靠人工处理这些信息的代价过大。文本分类是文本处理一项最为基础的任务,使用计算机快速高效的完成文本分类,有利于缓解信息高速增长带来的信息处理问题。
[0004]文本分类经历了从专家系统到机器学习算法再到深度学习算法的跨越。深度学习是机器学习中一种基于对数据进行表征学习的方法,其侧重于利用深度的神经网络,将模型处理得更为复杂,从而使模型对数据的理解更加深入。
[0005]深度学习文本分类模型目前主要以人工神经网络、卷积神经网络、循环神经网络为基础。这些网络搭建的模型为黑箱模型,其参数的解释性不高,不利于网络的优化以及实际的使用。同时,基于传统神经网络的文本分类模型在并发性、稳定性、训练速度、准确率等方面还有改进空间。
[0006]可见,亟需一种并发性、可解释性、稳定性、训练速度和准确率更高的基于主题注意力的深度学习文本分类模型训练方法。

技术实现思路

[0007]有鉴于此,本公开实施例提供一种基于主题注意力的深度学习文本分类模型训练方法,至少部分解决现有技术中存在并发性、可解释性、稳定性、训练速度和准确率较差的问题。
[0008]本公开实施例提供了一种基于主题注意力的深度学习文本分类模型训练方法,包括:
[0009]步骤1,获取原始文本,并根据所述原始文本构建文本数据集;
[0010]步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;
[0011]步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;
[0012]步骤4,初始化前向网络中变量,包括词嵌入表、主题向量以及其它全连接网络层权重;
[0013]步骤5,根据所述数字化表示,得到表征文本的一组词向量;
[0014]步骤6,根据所述词向量组和主题向量组,得到原始注意力矩阵;
[0015]步骤7,根据所述掩盖序列,掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵;
[0016]步骤8,根据所述目标注意力矩阵,得到概率矩阵;
[0017]步骤9,根据所述概率矩阵和值向量,计算注意力头输出;
[0018]步骤10,将不同头部的注意力头输出拼接并将拼接结果进行线性转化,得到注意力输出;
[0019]步骤11,根据所述注意力输出,计算主题输出;
[0020]步骤12,根据所述主题输出和主题向量,计算主题概率向量;
[0021]步骤13,将所述数字标签转化为one

hot编码形式后根据所述主题概率向量,计算交叉熵损失;
[0022]步骤14,根据所述交叉熵损失,计算前向网络变量的梯度;
[0023]步骤15,根据所述梯度,更新网络变量;
[0024]步骤16,依次从所述训练集中取出一定样本送入前向网络中,不断计算交叉熵损失以及梯度,更新网络变量;
[0025]步骤17,当迭代达到预设次数或模型损失趋于稳定,迭代停止。
[0026]根据本公开实施例的一种具体实现方式,所述前向网络包括词嵌入、主题嵌入,多头注意力模块、线性映射层、前馈网络模块、残差结构、标准化模块。
[0027]根据本公开实施例的一种具体实现方式,所述主题向量和查询向量之间、所述词向量和键向量之间、所述词向量和所述值向量之间,以及,所述注意力头输出和所述主题输出之间均设置有一个全连接层,主题输出和主题概率向量之间设置有多个全连接层。
[0028]根据本公开实施例的一种具体实现方式,所述步骤5具体包括:
[0029]将所述数字化表示中的数字序号依次取出,通过序号查询词嵌入表,取出序号对应行数的向量,将取出的向量按序拼接成矩阵,并根据所述矩阵得到所述词向量。
[0030]根据本公开实施例的一种具体实现方式,所述原始注意力矩阵Score,计算方法如下:
[0031]令Q为查询矩阵,K为键矩阵,V为值矩阵,n为类别数,l为文本最大长度,d
emb
为词向量维度,则:
[0032]Q=(q1,q2,

,q
n
),K=(k1,k2,

,k
l
,),V=(v1,v2,

,v
l
,)
[0033]将Q矩阵和K的转置矩阵做矩阵乘法,并进行缩放,公式如下:
[0034][0035]Score
i,j
表示文本中第j个字符对第i个主题的贡献。
[0036]根据本公开实施例的一种具体实现方式,所述步骤7具体包括:
[0037]步骤7.1,将所述查询向量、键向量、值向量投影到低纬度上,计算每个头独立的注意力;
[0038]步骤7.2,根据所述掩盖序列和每个头独立的注意力掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵。
[0039]根据本公开实施例的一种具体实现方式,所述概率矩阵的计算公式如下:
[0040][0041]Prob
i
=(Prob
i,1
,Prob
i,2
,
……
,Prob
i,l
)
[0042]Prob=Softmax(Score)=(Prob1,Prob2,
……
,Prob
n
)。
[0043]根据本公开实施例的一种具体实现方式,所述注意力头输出的计算公式如下:
[0044][0045][0046]根据本公开实施例的一种具体实现方式,所述主题概率向量包含多个主题概率,其中,所述主题概率由主题向量和主题输出进行点积运算得出或者使用单节点全连接网络计算得出。
[0047]根据本公开实施例的一种具体实现方式,所述步骤13之前,所述方法还包括:
[0048]使用softmax函数对所述主题概率向量进行归一化。
[0049]本公开实施例中的基于主题注意力的深度学习文本分类模型训练方案,包括:步骤1,获取原始文本,并根据所述原始文本构建文本数据集;步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;步骤4,初始化前向网络中变量,本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种基于主题注意力的深度学习文本分类模型训练方法,其特征在于,包括:步骤1,获取原始文本,并根据所述原始文本构建文本数据集;步骤2,根据所述文件数据集,得到文本的数字化表示、文本的掩盖序列、文本的数字标签;步骤3,根据所述数字化表示,得到样本,并将样本分为训练集和验证集;步骤4,初始化前向网络中变量,包括词嵌入表、主题向量以及其它全连接网络层权重;步骤5,根据所述数字化表示,得到表征文本的一组词向量;步骤6,根据所述词向量组和主题向量组,得到原始注意力矩阵;步骤7,根据所述掩盖序列,掩盖原始注意力矩阵中的无效部分,得到目标注意力矩阵;步骤8,根据所述目标注意力矩阵,得到概率矩阵;步骤9,根据所述概率矩阵和值向量,计算注意力头输出;步骤10,将不同头部的注意力头输出拼接并将拼接结果进行线性转化,得到注意力输出;步骤11,根据所述注意力输出,计算主题输出;步骤12,根据所述主题输出和主题向量,计算主题概率向量;步骤13,将所述数字标签转化为one

hot编码形式后根据所述主题概率向量,计算交叉熵损失;步骤14,根据所述交叉熵损失,计算前向网络变量的梯度;步骤15,根据所述梯度,更新网络变量;步骤16,依次从所述训练集中取出一定样本送入前向网络中,不断计算交叉熵损失以及梯度,更新网络变量;步骤17,当迭代达到预设次数或模型损失趋于稳定,迭代停止。2.根据权利要求1所述的方法,其特征在于,所述前向网络包括词嵌入、主题嵌入,多头注意力模块、线性映射层、前馈网络模块、残差结构、标准化模块。3.根据权利要求1所述的方法,其特征在于,所述主题向量和查询向量之间、所述词向量和键向量之间、所述词向量和所述值向量之间,以及,所述注意力头输出和所述主题输出之间均设置有一个全连接层,主题输出和主题概率向量之间设置有多个全连接层。4.根据权利要求1所述的方法,其特征在于,所述步骤5具体包括:将所述数字化表示中的数字序号依次取出,通过序号查询词嵌入表,取出序号对应行数的向量,将取出的向量按序拼接成矩阵,并根据所述矩阵得到所...

【专利技术属性】
技术研发人员:张祖平彭杰龙哲
申请(专利权)人:中南大学
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1