一种基于深度迁移学习的图片分类方法技术

技术编号:20682220 阅读:40 留言:0更新日期:2019-03-27 19:18
本发明专利技术请求保护一种基于深度迁移学习的图片分类方法,其中,所述的领域适应至少包含两个领域的数据,分别为源域和目标域,并且源域数据为已标记的样本数据.所述方法主要包括以下步骤:步骤1)数据准备阶段.准备源域数据和目标域数据,确定目标类别集合.步骤2)特征提取模型构建阶段.使用ResNet和自注意力网络构建基础特征提取模型.步骤3)领域对抗模型构建阶段.使用领域对抗模型预测样本类别和样本领域;步骤4)训练阶段.对源域和目标域样本进行领域标记,设置基于样本迁移权重的损失函数.步骤5)预测阶段.对目标域数据进行预测,将类别预测结果作为最终结果.本发明专利技术降低标记成本,达到知识迁移的目的。

【技术实现步骤摘要】
一种基于深度迁移学习的图片分类方法
本专利技术属于计算机信息处理
,具体涉及人工智能及大数据处理相关领域。
技术介绍
随着数据规模和计算资源的日益增长,大数据处理技术也经历着高速的发展.机器学习作为大数据处理技术的有效工具之一,在大数据处理技术中发挥着关键作用,监督学习是机器学习中的一个重要分支,其特点是包含标记信息的学习,在现实生活中,某些任务的标记往往是难以获得的,例如图片数据,需要大量的人力进行标记,现有最大的已标记的图片数据集为ImageNet,包含1500万张已有标记的图片数据.其标记任务由167个国家的48940位工作人员,使用了2年时间完成.因此如何降低标记成本是监督学习的一个主要挑战,使用其他领域已标记的样本数据向目标领域进行领域适应是降低标记成本的有效方法,因此,领域适应也是近期较为火热的技术之一。在电商系统中,图片是展示产品的主要方法,但是对于电商图片的分类需要大量的标记工作,如何不依赖额外的标记工作,利用其他领域已标记的数据,对电商图片进行分类,达到无标记学习的目的,正是本专利技术将要解决的问题,另外,本专利技术不仅仅限于电商图片的分类,对于有相同目标类别本文档来自技高网...

【技术保护点】
1.一种基于深度迁移学习的图片分类方法,其特征在于,包括以下步骤:步骤1、数据准备阶段:获取目标域数据,并确定相应的目标类别集合,获取已标记有目标类别的数据,作为源域数据;步骤2、特征提取模型构建阶段:使用ResNet构建模型基础网络,提取基础特征;将ResNet的输出连接到一个自注意力网络,增强图片的结构信息,减少颜色及纹理对类别的影响;将步骤1准备的数据输入到步骤2的特征提取模型,提取图片的特征,输出特征向量;步骤3、领域对抗模型构建阶段:将步骤2输出的特征输入到步骤3的领域对抗模型,输出预测的领域及该图片的类别,将特征提取模型的输出连接到一个预测目标类别的全连接层,得到预测的类别信息,将...

【技术特征摘要】
1.一种基于深度迁移学习的图片分类方法,其特征在于,包括以下步骤:步骤1、数据准备阶段:获取目标域数据,并确定相应的目标类别集合,获取已标记有目标类别的数据,作为源域数据;步骤2、特征提取模型构建阶段:使用ResNet构建模型基础网络,提取基础特征;将ResNet的输出连接到一个自注意力网络,增强图片的结构信息,减少颜色及纹理对类别的影响;将步骤1准备的数据输入到步骤2的特征提取模型,提取图片的特征,输出特征向量;步骤3、领域对抗模型构建阶段:将步骤2输出的特征输入到步骤3的领域对抗模型,输出预测的领域及该图片的类别,将特征提取模型的输出连接到一个预测目标类别的全连接层,得到预测的类别信息,将预测的类别信息与特征提取模型的输出进行拼接,输入至领域判别的全连接层,并通过梯度取反逆向更新领域判别全连接层的参数;步骤4、训练阶段:将源域和目标域数据进行领域标记,使用源域、目标域的领域预测结果分别计算出迁移权重和迁移损失,使用源域的类别预测结果计算出分类损失,将迁移损失和分类损失的和作为反向传播的损失更新模型;步骤5、预测阶段:将目标域的数据输入模型,将类别输出作为最终的预测结果。2.根据权利要求1所述的一种基于深度迁移学习的图片分类方法,其特征在于,所述步骤1数据准备阶段:获取目标域数据,并确定相应的目标类别集合,获取已标记有目标类别的数据,作为源域数据,具体包括:确定待分类的服装商品图片样本集作为目标域样本数据集Dtarget,确定将要分为的类别集合C,其中包含的主要类别为<sweet,wet,shirt,...>,下载ImageNet数据集,并选择其中含有Clothing标记的图片样本集合,作为已有标记的源域样本数据集Dsource,并将Dsource按照样本的类别标记为文件夹名称分别保存。3.根据权利要求1所述的一种基于深度迁移学习的图片分类方法,其特征在于,所述步骤2特征提取模型构建阶段的具体操作包括:使用预训练的含有101层的ResNet模型作为基础的特征提取模型,去除ResNet最后一层全连接层,替换为一个包含2048个神经元的自注意力网络,用来增强图片的全局结构特征,注意力网络的注意力权重计算公式如下,其中,x为图片样本,f(x)为ResNet的输出,i、j表示坐标位置,表示第i个位置对于第j个位置的关注度;其中特征提取模型的输出为注意力网络加权在ResNet模型上的输出,其计算公式如下;同样,f(x)为ResNet的输出,β为上式所计算出的关注度参数,得到的F(x)为特征提取模型的输出.F(x)out=f(x)·β模型训练参数设置为(batch_size:64,iteration:50000),优化器参数设置为(learnin...

【专利技术属性】
技术研发人员:王进王科李林洁杨俏孙开伟刘彬
申请(专利权)人:重庆邮电大学
类型:发明
国别省市:重庆,50

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1