System.ArgumentOutOfRangeException: 索引和长度必须引用该字符串内的位置。 参数名: length 在 System.String.Substring(Int32 startIndex, Int32 length) 在 zhuanliShow.Bind()
【技术实现步骤摘要】
本专利技术涉及深度学习,特别涉及扩散模型微调方法、装置、设备及介质。
技术介绍
1、目前扩散模型的微调技术应用场景非常广泛,风格化模型微调、定制化模型微调等需求的应用越来越多。目前常用的微调技术分为两种,一种是基模型微调,得到的微调模型具备较好的泛化能力,但拟合能力相对较弱,且更新参数多,训练和存储成本较大;另外一种是增量模型微调,例如是lora微调,其具备较好的风格拟合能力,但损失了模型较多的泛化能力。
2、综上可见,如何提高扩散模型的拟合能力和泛化能力,并降低计算成本是本领域有待解决的问题。
技术实现思路
1、有鉴于此,本专利技术的目的在于提供一种扩散模型微调方法、装置、设备及介质,能够提高扩散模型的拟合能力和泛化能力,并降低计算成本。其具体方案如下:
2、第一方面,本申请公开了一种扩散模型微调方法,包括:
3、获取待微调扩散模型和文本字符串信息;其中,所述待微调扩散模型包括当前文本编码器和当前unet网络,所述当前文本编码器包括当前切词器和当前嵌入层;
4、利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练,以得到更新后切词器和更新后嵌入层,并基于所述更新后切词器和所述更新后嵌入层获取更新后文本编码器;
5、冻结所述更新后文本编码器,并对所述当前unet网络中交叉注意力层的键向量和值向量进行微调训练,以得到更新后unet网络;
6、基于所述更新后文本编码器和所述更新后unet网络获
7、可选的,所述利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练,以得到更新后切词器和更新后嵌入层,包括:
8、利用所述文本字符串信息对所述当前切词器的词汇表进行微调训练,以得到更新后切词器;其中,所述更新后切词的词汇表中包括所述当前切词器的词汇表中未存储的新单词和所述新单词的新词向量编码;
9、将所述新单词确定为触发词,并基于所述触发词和所述新词向量编码对所述当前嵌入层进行微调训练,以得到更新后嵌入层;
10、相应的,所述利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练的过程中,还包括:
11、冻结所述当前unet网络以及所述当前文本编码器中除所述当前切词器、所述当前嵌入层以外的参数。
12、可选的,所述利用所述文本字符串信息对所述当前切词器的词汇表进行微调训练,以得到更新后切词器,包括:
13、基于所述当前切词器的词汇表对所述文本字符串信息进行分词,以得到各个单词以及各个所述单词的词向量编码,并从各个所述单词中筛选出所述当前切词器的词汇表中未存储的目标单词;
14、将所述目标单词确定为新单词,并将所述目标单词的词向量编码确定为新词向量编码;
15、基于所述新单词和所述新词向量编码更新所述当前切词器,以得到更新后切词器。
16、可选的,所述基于所述触发词和所述新词向量编码对所述当前嵌入层进行微调训练,以得到更新后嵌入层,包括:
17、将与所述触发词对应的词向量进行初始化,以确定768维度为所述词向量的目标维度;
18、在所述目标维度下,基于所述触发词和所述新词向量编码对所述当前嵌入层进行微调训练,以得到更新后嵌入层。
19、可选的,所述对所述当前unet网络中交叉注意力层的键向量和值向量进行微调训练,以得到更新后unet网络,包括:
20、确定所述当前unet网络中交换网络层的交叉注意力层,并确定所述交叉注意力层中与键向量和值向量对应的线性层;
21、对所述线性层进行微调训练,以得到更新后unet网络;
22、相应的,所述对所述线性层进行微调训练,以得到更新后unet网络的过程中,还包括:
23、冻结所述交叉注意力层中除键向量和值向量以外的参数。
24、可选的,所述对所述线性层进行微调训练,以得到更新后unet网络,包括:
25、根据所述交叉注意力层在所述交换网络层的层深将所述线性层进行降秩拆解,以得到线性低秩矩阵,并对所述线性低秩矩阵进行微调训练,以得到更新后unet网络。
26、可选的,所述根据所述交叉注意力层在所述交换网络层的层深将所述线性层进行降秩拆解,以得到线性低秩矩阵,包括:
27、若所述交换网络层为所述交换网络层的浅层网络层,则以第一预设网络秩将所述线性层进行降秩拆解,以得到第一线性低秩矩阵;
28、若所述交换网络层为所述交换网络层的中间层网络层,则以第二预设网络秩将所述线性层进行降秩拆解,以得到第二线性低秩矩阵;
29、若所述交换网络层为所述交换网络层的深层网络层,则以第三预设网络秩将所述线性层进行降秩拆解,以得到第三线性低秩矩阵。
30、第二方面,本申请公开了一种扩散模型微调装置,包括:
31、信息获取模块,用于获取待微调扩散模型和文本字符串信息;其中,所述待微调扩散模型包括当前文本编码器和当前unet网络,所述当前文本编码器包括当前切词器和当前嵌入层;
32、第一微调模块,用于利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练,以得到更新后切词器和更新后嵌入层,并基于所述更新后切词器和所述更新后嵌入层获取更新后文本编码器;
33、第二微调模块,用于冻结所述更新后文本编码器,并对所述当前unet网络中交叉注意力层的键向量和值向量进行微调训练,以得到更新后unet网络;
34、模型获取模块,用于基于所述更新后文本编码器和所述更新后unet网络获取微调后扩散模型。
35、第三方面,本申请公开了一种电子设备,包括:
36、存储器,用于保存计算机程序;
37、处理器,用于执行所述计算机程序,以实现前述公开的扩散模型微调方法的步骤。
38、第四方面,本申请公开了一种计算机可读存储介质,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现前述公开的扩散模型微调方法的步骤。
39、本申请有益效果为:本申请获取待微调扩散模型和文本字符串信息;其中,所述待微调扩散模型包括当前文本编码器和当前unet网络,所述当前文本编码器包括当前切词器和当前嵌入层;利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练,以得到更新后切词器和更新后嵌入层,并基于所述更新后切词器和所述更新后嵌入层获取更新后文本编码器;冻结所述更新后文本编码器,并对所述当前unet网络中交叉注意力层的键向量和值向量进行微调训练,以得到更新后unet网络;基于所述更新后文本编码器和所述更新后unet网络获取微调后扩散模型。由此可见,本申请在第一微调阶段只对当前文本编码器中当前切词器和当前嵌入层进行微调训练,以得新的切词器和新的嵌入层,即得到更新后切词器和更新后嵌入层,以本文档来自技高网...
【技术保护点】
1.一种扩散模型微调方法,其特征在于,包括:
2.根据权利要求1所述的扩散模型微调方法,其特征在于,所述利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练,以得到更新后切词器和更新后嵌入层,包括:
3.根据权利要求2所述的扩散模型微调方法,其特征在于,所述利用所述文本字符串信息对所述当前切词器的词汇表进行微调训练,以得到更新后切词器,包括:
4.根据权利要求2所述的扩散模型微调方法,其特征在于,所述基于所述触发词和所述新词向量编码对所述当前嵌入层进行微调训练,以得到更新后嵌入层,包括:
5.根据权利要求1至4任一项所述的扩散模型微调方法,其特征在于,所述对所述当前UNet网络中交叉注意力层的键向量和值向量进行微调训练,以得到更新后UNet网络,包括:
6.根据权利要求5所述的扩散模型微调方法,其特征在于,所述对所述线性层进行微调训练,以得到更新后UNet网络,包括:
7.根据权利要求6所述的扩散模型微调方法,其特征在于,所述根据所述交叉注意力层在所述交换网络层的层深将所述线性层进行降秩
8.一种扩散模型微调装置,其特征在于,包括:
9.一种电子设备,其特征在于,包括:
10.一种计算机可读存储介质,其特征在于,用于存储计算机程序;其中,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的扩散模型微调方法的步骤。
...【技术特征摘要】
1.一种扩散模型微调方法,其特征在于,包括:
2.根据权利要求1所述的扩散模型微调方法,其特征在于,所述利用所述文本字符串信息对所述当前切词器的词汇表和所述当前嵌入层进行微调训练,以得到更新后切词器和更新后嵌入层,包括:
3.根据权利要求2所述的扩散模型微调方法,其特征在于,所述利用所述文本字符串信息对所述当前切词器的词汇表进行微调训练,以得到更新后切词器,包括:
4.根据权利要求2所述的扩散模型微调方法,其特征在于,所述基于所述触发词和所述新词向量编码对所述当前嵌入层进行微调训练,以得到更新后嵌入层,包括:
5.根据权利要求1至4任一项所述的扩散模型微调方法,其特征在于,所述对所述当前unet...
【专利技术属性】
技术研发人员:刘艺博,张璐,陶明,
申请(专利权)人:上海任意门科技有限公司,
类型:发明
国别省市:
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。