一种分类模型的训练方法和眼底图像分类方法技术

技术编号:38876872 阅读:12 留言:0更新日期:2023-09-22 14:09
本发明专利技术提供一种分类模型的训练方法和眼底图像分类方法,属于增量学习领域。一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括特征提取网络以及分类器,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络用教师模型初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对用到的每个旧类,获取该旧类对应的伪样本;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型。本发明专利技术可以缓解灾难性遗忘。明可以缓解灾难性遗忘。明可以缓解灾难性遗忘。

【技术实现步骤摘要】
一种分类模型的训练方法和眼底图像分类方法


[0001]本专利技术涉及增量学习领域,具体来说,涉及应用增量学习的医学成像领域,更具体地说,涉及一种分类模型的训练方法和一种眼底图像分类方法。

技术介绍

[0002]深度神经网络(DNNs)在许多机器学习分类任务中表现出色,例如,深度神经网络在医学成像应用中显示出的性能达到了人类水平,以及,最近的研究表明,基于深度学习的模型能够成功应用于视网膜疾病(如糖尿病视网膜病变DR和青光眼)筛查。其中,视网膜疾病筛查中应用的模型是假设在训练前所有类别都是已知的条件下才可实现的。然而,这一假设在医学领域经常被违背。在实际训练过程中,预先准备所有类别的数据是极难实现的,例如,由于一些疾病在不同的发病阶段,疾病的病变程度也会发生细微的变化,因此,疾病的病变程度发生变化的数据进行预先准备是很难实现的。此外,如果一个模型只学会了识别输入样本(例如疾病对应的样本)所属的泛化类别,而不能对所属的具化类别(例如,发病阶段和发病程度不同时的类别)进行准确识别,这会产生极其严重的后果。因为训练后模型的针对输入样本的输出结果会影响后续的应用效果。例如,同一疾病在不同的发病阶段所使用的治疗策略是不同的,而启动正确的治疗方案是取得良好治疗结果的关键。为此,构造机器学习分类任务中的类别增量学习方法,可以识别所属的具化类别(例如,疾病的病变程度与发病阶段对应的类别),从而在后续的应用中,可以在为患者提供持续有效、及时准确的健康检测以及疾病预警方面提供极大帮助。增量学习方法中需要解决的主要挑战是灾难性遗忘。直观地说,灾难性遗忘是由特征空间中新旧类的表示之间的重叠或混淆引起的,在学习新类时,以前类的决策边界可能会发生巨大变化,统一的分类器会有严重的偏差。
[0003]为解决这一挑战(灾难性遗忘),现有方法中包括以下两个不同的突破方向:第一种是将参数偏向于在旧类上学习的方向;第二种是保持一个来自以前任务的小数据缓冲区(这也被称为经验回放)的方向。针对第一种的采用的正则化的策略,在使用多头分类器以及推理时利用可用的任务标识符的场景中是有效的。针对第二种采用的经验回放方法,最为普遍的做法是以保存少数用于后续模型训练的真实数据来实现经验回放。此外,也有一些方法采用了额外的生成模型,如对抗生成网络(GNN)来生成数据实现经验回放。
[0004]然而,正如一些文献所注意到的,解决灾难性遗忘的方法在类别增量学习(CIL)场景下性能较差。目前较为有效的正则化策略是基于知识蒸馏,强制学生模型完全模仿教师模型。具体来说,经过蒸馏的学生模型旨在模仿教师模型在训练样本上模型全连接层输出的logits,以获得与教师模型类似的泛化性能。然而,完全模仿教师模型的输出可能不是最优的,因为教师模型可能会自信地错误预测一些类,这会增加增量过程中错误信息传递的风险。对于经验回放的方法,常常需要大量的内存来重放之前看到的或建模的数据,以避免灾难性的遗忘问题。然而在某些实际场景(例如,物联网应用的设备上或隐私问题)中,由于内存限制,数据存储可能会受到限制。这样使人们专注于增量地合并新信息,而不存储旧知识,这被称为非基于保存样例的增量学习(Non

exemplar

based Incremental Learning)。
[0005]在非基于保存样例的增量学习中,为了进一步从教师模型中学到更多有用的知识,以确定教师模型的哪些知识有助于建立一个更好的学生模型,除了传统的基于logit的蒸馏方法外,基于特征的蒸馏方法也受到了很多关注。这是因为教师模型的特征比基于logit的模型具有更多的信息,使用特征蒸馏可以使学生模型学习到更丰富的信息。然而,大多数基于特征蒸馏的研究都是手动链接教师模型和学生模型的特征,并通过单独的链接进行蒸馏,存在将不正确的中间过程强加给学生的风险。此外,在所有可能的环节中选择少数环节(换而言之,已有的方法中,是在学生模型蒸馏的步骤中使用人为的方式,选择一些人为认为有代表性的环节进行步骤选择操作),也会限制教师模型充分利用自己的全部知识;而且,在大多数知识蒸馏的情况下,学生模型和教师模型的特征具有不同的宽度、高度和渠道,通常是应用卷积层或全连接层来匹配它们的大小。这样使得在特征调整的过程中,教师模型的一些有用信息可能会丢失。
[0006]基于上述问题,将增量学习应用于医学成像领域,其中特别是视网膜异常筛查时,专利CN106022368A伴随的由于类别差异导致的灾难性遗忘问题。

技术实现思路

[0007]因此,本专利技术的目的在于克服上述现有技术的缺陷,提供一种分类模型的训练方法和一种眼底图像分类方法。
[0008]本专利技术的目的是通过以下技术方案实现的:
[0009]根据本专利技术第一方面,提供一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括用于对输入眼底图像提取图像特征的特征提取网络以及用于根据图像特征识别该图像特征对应眼底图像所属的眼底类别的分类器,所述分类器包括全连接层和Softmax层,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络的可训练参数用教师模型中特征提取网络的可训练参数初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对所述预训练中用到的每个旧类,获取该旧类对应的伪样本,所述伪样本是利用教师模型的特征提取网络对属于该旧类的多个眼底图像提取的图像特征生成的;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型,所述总损失根据以下损失加权求和确定:旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失;其中,所述第一特征蒸馏损失根据旧类的伪样本在以下两者上的输出之间的差异进行确定:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出;所述第二特征蒸馏损失根据教师模型中特征提取网络的提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异进行确定,其中,每个教师特征为一个伪样本经教师模型中特征提取网络提取得到的图像特征,每个学生特征为伪样本和属于新类的眼底图像分别经学生模型中特征提取网络提取得到的图像特征。
[0010]在本专利技术的一些实施例中,通过如下步骤计算所述总的差异:基于注意力机制,确定每个教师特征对各个学生特征的注意力值,以及利用该教师特征对各个学生特征的注意力值构成该教师特征对所有学生特征的注意力向量并进行归一化得到归一化后的注意力
向量;基于每个教师特征和每个学生特征的确定空间距离;计算空间距离和归一化后的注意力向量中对应元素的乘积,并将得到的所有乘积进行求和得到总的差异。
[0011]在本专利技术的一些实施例中,基于注意力机制,每个教师特征对各个学生特征的注意力值通过如下步骤计算:将学生特征进行数据转换,得到该学生特征在注意力机制中的一个key;将教师特征进行本文档来自技高网
...

【技术保护点】

【技术特征摘要】
1.一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括用于对输入眼底图像提取图像特征的特征提取网络以及用于根据图像特征识别该图像特征对应眼底图像所属的眼底类别的分类器,所述分类器包括全连接层和Softmax层,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络的可训练参数用教师模型中特征提取网络的可训练参数初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对所述预训练中用到的每个旧类,获取该旧类对应的伪样本,所述伪样本是利用教师模型的特征提取网络对属于该旧类的多个眼底图像提取的图像特征生成的;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型,所述总损失根据以下损失加权求和确定:旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失;其中,所述第一特征蒸馏损失根据旧类的伪样本在以下两者上的输出之间的差异进行确定:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出;所述第二特征蒸馏损失根据教师模型中特征提取网络的提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异进行确定,其中,每个教师特征为一个伪样本经教师模型中特征提取网络提取得到的图像特征,每个学生特征为伪样本和属于新类的眼底图像分别经学生模型中特征提取网络提取得到的图像特征。2.根据权利要求1所述的方法,其特征在于,通过如下步骤计算所述总的差异:基于注意力机制,确定每个教师特征对各个学生特征的注意力值,以及利用该教师特征对各个学生特征的注意力值构成该教师特征对所有学生特征的注意力向量并进行归一化得到归一化后的注意力向量;基于每个教师特征和每个学生特征的确定空间距离;计算空间距离和归一化后的注意力向量中对应元素的乘积,并将得到的所有乘积进行求和得到总的差异。3.根据权利要求2所述的方法,其特征在于,基于注意力机制,每个教师特征对各个学生特征的注意力值通过如下步骤计算:将学生特征进行数据转换,得到该学生特征在注意力机制中的一个key;将教师特征进行数据转换,得到该教师特征在注意力机制中的一个query;计算每个query对各个key的注意力值。4.根据权利要求3所述的方法,其特征在于,通过如下规则进行数据转换:4.根据权利要求3所述的方法,其特征在于,通过如下规则进行数据转换:其中,q
t
表示在注意力机制中的query,表示第t个教师特征,P
HW
(
·
)表示全局平
均池化,表示的线性变换参数,表示的线性变换参数空间矩阵,f
Q
(
·
)表示第一激活函数,k
s
表示在注意力机制中的key,表示第s个学生特征,表示的线性变换参数,表示的线性变换参数空间矩阵,表示的线性变换参数的空间矩阵,f
K
(
·
)表示第二激活函数,d表示线性变换参数空间矩阵的维度。5.根据权利要求4所述的方法,其特征在于,通过以下规则计算每个query对所有key的注意力向量并进行归一化:其中,softmax(.)表示归一化函数,表示q
t
的转置,表示双线性权值,k
t,1
表示对对应的key值,表示第t个教师特征的位置编码,表示第s个学生特征的位置编码,k
t,S
表示对对应的key值,表示和的乘积,表示的转置。6.根据权利要求5所述的方法,其特征在于,通过以下方法计算第二特征蒸馏损失:其中,α
t,s
表示第t个教师特征对第s个学生特征的归一化后的注意力值,表示空间距离,||.||2表示求L2范数,表示通道平均池化层与L2归一化的组合函数v/||v||2,v表示对进行平均池化得到的向量,表示对使用上采样或下采样得到的特征。7.根据权利要求1所述的方法,其特征在于,按照如下步骤获得旧类的伪样本:T1、基于预训练中用到的旧类对应的多个图像特征计算该旧类对应的类均值向...

【专利技术属性】
技术研发人员:谷洋郭帅文世杰马媛杨昭华翁伟宁陈益强
申请(专利权)人:中国科学院计算技术研究所
类型:发明
国别省市:

网友询问留言 已有0条评论
  • 还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。

1