一种预训练模型抽选框架的构建方法及装置制造方法及图纸

技术编号:34125750 阅读:8 留言:0更新日期:2022-07-14 14:09
本发明专利技术提供了一种预训练模型抽选框架的构建方法及装置,该方法包括:选取图像数据集和自监督对比学习框架;根据图像数据集和自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型;选取下游迁移任务和下游迁移数据集;在基于自监督对比学习框架获取的采样空间中筛选符合预设条件的第一模型,基于下游迁移任务和下游迁移数据集计算第一模型与训练好的超网预训练模型的相似度;基于相似度的计算结果,确定与训练好的超网预训练模型共享权重的目标预训练模型,得到预训练模型抽选框架。该方法可以实现高效的下游任务定制化抽取,抽取出的模型具有极佳的泛化能力。的泛化能力。的泛化能力。

【技术实现步骤摘要】
一种预训练模型抽选框架的构建方法及装置


[0001]本专利技术涉及对比自监督学习领域,尤其涉及一种预训练模型抽选框架的构建方法及装置。

技术介绍

[0002]模型自监督预训练是一个重要且具有挑战性的计算机视觉任务,可以通过在相应的提取框架中提取自监督预训练模型实现,在医学图像诊断,图像分割等有标注数据的领域具有广泛的应用。
[0003]在各个应用场景下,基于现有的提取框架进行模型自监督预训练,达到类似COCO数据集的标注数量级,成本太高,难以实现,一般只能获取少量低成本的标注数据,数据受限导致模型训练十分困难,一般都需要选取预训练好的模型在下游数据集微调权重。与此同时,在各种不同的应用场景下,可使用的硬件资源差异十分明显,能够部署的模型也不相同,针对不同模型需要进行单独的预训练,模型的复用性很差,十分浪费硬件资源。

技术实现思路

[0004]本专利技术提供一种种预训练模型抽选框架的构建方法及装置,用以解决现有技术中对现有对比学习自监督方法模型收敛慢,以及下游模型选择耗费过多资源的缺陷,可以提升模型的收敛速度、减少资源浪费和提升模型的复用性。
[0005]第一方面,本专利技术提供了一种预训练模型抽选框架的构建方法,包括:选取图像数据集和自监督对比学习框架;根据所述图像数据集和所述自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型;选取下游迁移任务和下游迁移数据集;在基于所述自监督对比学习框架获取的采样空间中筛选符合预设条件的第一模型,基于所述下游迁移任务和所述下游迁移数据集计算所述第一模型与所述训练好的超网预训练模型的相似度;基于所述相似度的计算结果,确定与所述训练好的超网预训练模型共享权重的目标预训练模型,得到预训练模型抽选框架。
[0006]进一步地,根据所述图像数据集和所述自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型,包括:将所述图像数据集输入所述自监督对比学习框架进行计算,得到构建的超网预训练模型的损失函数;基于所述损失函数对所述构建的超网预训练模型进行训练,得到训练好的超网预训练模型。
[0007]进一步地,所述将所述图像数据集输入所述自监督对比学习框架进行计算,得到构建的超网预训练模型的损失函数,包括:在每一训练轮次中,将所述图片数据集中的图片数据按预设数目的批次进行划分,并将每一批次的每个图片数据分别进行两次数据增强,得到所述每一批次的每张图片数据对应的两组数据增强的图片数据;基于所述自监督对比学习框架的特征提取骨干网络,设定所述采样空间,在所述采样空间中随机选取一个模型结构,基于选取的模型结构对所述监督对比学习框架中的梯度更新分支网络进行修改;将所述每一批次的每张图片数据对应的两组数据增强的图片数据,分别输入所述监督对比学
习框架中的修改的梯度更新分支网络和非梯度更新分支网络进行计算,得到所述损失函数。
[0008]进一步地,所述将每一批次的每个图片数据分别进行两次数据增强,得到所述每一批次的每张图片数据对应的两组数据增强的图片数据,包括:将所述每一批次的每个图片数据分别进行两次缩放、翻转、颜色转化和裁剪,得到所述每一批次的每张图片数据对应的两组数据增强的图片数据。
[0009]进一步地,所述基于所述下游迁移任务和所述下游迁移数据集计算所述第一模型与所述训练好的超网预训练模型的相似度,包括:基于所述下游迁移任务,筛选所述第一模型与所述训练好的超网预训练模型在所述下游迁移数据集中进行推理得到的第一模型中间特征相似图和训练好的超网预训练模型中间特征相似图;对所述第一模型中间特征相似图和所述训练好的超网预训练模型中间特征相似图进行相似度计算,得到所述第一模型与所述训练好的超网预训练模型的相似度。
[0010]进一步地,所述基于所述相似度的计算结果,确定与所述训练好的超网预训练模型共享权重的目标预训练模型,包括:将与所述训练好的超网预训练模型的相似度最大的第一模型作为所述目标预训练模型。
[0011]第二方面,本专利技术还提供例如一种预训练模型抽选框架的构建装置,包括:第一选取模块,用于选取图像数据集和自监督对比学习框架;训练模块,用于根据所述图像数据集和所述自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型;第二选取模块,用于选取下游迁移任务和下游迁移数据集;计算模块,用于在基于所述自监督对比学习框架获取的采样空间中筛选符合预设条件的第一模型,基于所述下游迁移任务和所述下游迁移数据集计算所述第一模型与所述训练好的超网预训练模型的相似度;确定模块,用于基于所述相似度的计算结果,确定与所述训练好的超网预训练模型共享权重的目标预训练模型,得到预训练模型抽选框架。
[0012]第三方面,本专利技术实施例还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如第一方面所述的预训练模型抽选框架的构建方法的步骤。
[0013]第四方面,本专利技术实施例还提供了一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现如第一方面所述的预训练模型抽选框架的构建方法的步骤。
[0014]第五方面,本专利技术实施例还提供了一种计算机程序产品,其上存储有可执行指令,该指令被处理器执行时使处理器实现如第一方面所述的预训练模型抽选框架的构建方法的步骤。
[0015]本专利技术提供的预训练模型抽选框架的构建方法及装置,通过选取图像数据集和自监督对比学习框架;根据图像数据集和自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型;选取下游迁移任务和下游迁移数据集;在基于自监督对比学习框架获取的采样空间中筛选符合预设条件的第一模型,基于下游迁移任务和下游迁移数据集计算第一模型与所述训练好的超网预训练模型的相似度;基于相似度的计算结果,确定与训练好的超网预训练模型共享权重的目标预训练模型,得到预训练模型抽选框架。通过确定目标预训练模型得到的预训练模型抽选框架,可以实现高效的下游任务定制
化抽取,抽取出的模型具有极佳的泛化能力,能够更好的适应下游任务,无需任何下游有监督训练即可进行模型抽取。在实际应用场景中,下游任务有标签数据较多时,相较于常规方法进行模型选取时需要耗费大量GPU资源在下游任务上进行梯度更新训练,耗费过多资源,本方法仅需要模型在下游数据集上进行推理,无需任何GPU的硬件需求即可选出最优模型。
附图说明
[0016]为了更清楚地说明本专利技术或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本专利技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0017]图1为本专利技术提供的预训练模型抽选框架的构建方法实施例的流程示意图;
[0018]图2为本专利技术提供的获取训练好的超网预训练模型的方法实施例的流程示意图;
[0019]图3为本发本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种预训练模型抽选框架的构建方法,其特征在于,包括:选取图像数据集和自监督对比学习框架;根据所述图像数据集和所述自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型;选取下游迁移任务和下游迁移数据集;在基于所述自监督对比学习框架获取的采样空间中筛选符合预设条件的第一模型,基于所述下游迁移任务和所述下游迁移数据集计算所述第一模型与所述训练好的超网预训练模型的相似度;基于所述相似度的计算结果,确定与所述训练好的超网预训练模型共享权重的目标预训练模型,得到预训练模型抽选框架。2.根据权利要求1所述的预训练模型抽选框架的构建方法,其特征在于,根据所述图像数据集和所述自监督对比学习框架对构建的超网预训练模型进行训练,得到训练好的超网预训练模型,包括:将所述图像数据集输入所述自监督对比学习框架进行计算,得到构建的超网预训练模型的损失函数;基于所述损失函数对所述构建的超网预训练模型进行训练,得到训练好的超网预训练模型。3.根据权利要求2所述的预训练模型抽选框架的构建方法,其特征在于,所述将所述图像数据集输入所述自监督对比学习框架进行计算,得到构建的超网预训练模型的损失函数,包括:在每一训练轮次中,将所述图片数据集中的图片数据按预设数目的批次进行划分,并将每一批次的每个图片数据分别进行两次数据增强,得到所述每一批次的每张图片数据对应的两组数据增强的图片数据;基于所述自监督对比学习框架的特征提取骨干网络,设定所述采样空间,在所述采样空间中随机选取一个模型结构,基于选取的模型结构对所述监督对比学习框架中的梯度更新分支网络进行修改;将所述每一批次的每张图片数据对应的两组数据增强的图片数据,分别输入所述监督对比学习框架中的修改的梯度更新分支网络和非梯度更新分支网络进行计算,得到所述损失函数。4.根据权利要求3所述的基于预训练模型抽选框架的构建方法,其特征在于,所述将每一批次的每个图片数据分别进行两次数据增强,得到所述每一批次的每张图片数据对应的两组数据增强的图片数据,包括:将所述每一批次的每个图片数据分别进行两次缩放、翻转、颜色转化和裁剪,得到所述每一批次的每张图片数据对应的两组数据增强的图片数据。...

【专利技术属性】
技术研发人员:张兆翔常清彭君然
申请(专利权)人:中国科学院自动化研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1