基于卷积辅助自注意力的胸部疾病分类网络
Classification network for chest disease based on convolution-assisted self-attention
通讯作者:
收稿日期: 2024-03-1
基金资助: |
|
Received: 2024-03-1
Fund supported: | 国家自然科学基金资助项目(62071323);超声医学工程国家重点实验室开放课题资助项目(2022KFKT004);天津市自然科学基金资助项目(22JCZDJC00220). |
作者简介 About authors
张自然(1998—),男,硕士生,从事深度学习图像处理的研究.orcid.org/0009-0008-2472-5280.E-mail:
针对胸部X光影像中的病变大小不一,纹理复杂,且存在相互影响等问题,提出基于卷积辅助窗口自注意力的胸部X光影像疾病分类网络CAWSNet. 使用Swin Transformer作为骨干网络,以窗口自注意力建模长距离视觉依赖关系,通过引入卷积辅助,在弥补其缺陷的同时,强化局部特征提取能力. 引入图像相对位置编码,通过有向相对位置的动态计算,帮助网络更好地建模像素间的位置关系. 使用类别残差注意力,根据疾病类别来调整分类器关注的区域,突出有效信息,提高多标签分类能力. 提出动态难度损失函数,解决不同疾病分类的难度差异大,数据集中正负样本不平衡的问题. 在公开数据集ChestX-Ray14、CheXpert和MIMIC-CXR-JPG上的实验结果表明,提出CAWSNet的AUC分数分别达到0.853、0.898和0.819,表明该网络在胸部X光影像疾病诊断中的有效性和鲁棒性.
关键词:
A chest disease classification network based on convolution-assisted window self-attention was proposed, called CAWSNet, aiming at the issues of varying lesion sizes, complex textures, and mutual interference in chest X-ray images. The Swin Transformer was utilized as the backbone, employing window self-attention to model long-range visual dependencies. Convolution was introduced to enhance local feature extraction capability while compensating for the deficiencies of window self-attention. Image relative position encoding was used to dynamically calculate directed relative positions, helping the network better model pixel-wise spatial relationships. Class-specific residual attention was employed, and the classifier’s focus area was adjusted based on disease categories in order to highlight effective information and enhance multi-label classification capability. Dynamic difficulty loss function was proposed to alleviate the problem of large differences in disease classification difficulty and the imbalance of positive and negative samples in the dataset. The experimental results on the public datasets ChestX-Ray14, CheXpert and MIMIC-CXR-JPG demonstrate that proposed CAWSNet achieves AUC scores of 0.853, 0.898 and 0.819, respectively, confirming the effectiveness and robustness of the network in diagnosing chest diseases through X-ray images.
Keywords:
本文引用格式
张自然, 李锵, 关欣.
ZHANG Ziran, LI Qiang, GUAN Xin.
胸部疾病的诊断是多标签分类问题[4],由于CXR图像通常包含多个病理标签,这些病变部位在不同的阶段拥有不同的大小、形状和纹理表现[5-7]. 为了提高分类性能,Li等[8-9]基于卷积神经网络(convolut-ional neural networks,CNN)作出了很多尝试. 胸部疾病的病理间常常存在交叉重叠和互相影响的情况. 卷积运算存在难以捕捉长距离视觉和语义信息的缺点,不能完全满足CXR图像分析的需要. 对于这一缺陷,自注意力机制[10-11] 可以通过动态计算相关像素间的关系来自适应地关注不同的区域,捕获更多的信息特征. 全局自注意力的计算复杂度与图像大小为二次关系,计算成本很高,局部特征提取能力较弱.
针对以上问题,本文提出基于卷积辅助窗口自注意力(convolution-assisted window self-attention,CAWS)的胸部X光影像疾病分类网络,命名为CAWSNet. 针对卷积运算容易丢失CXR图像中的长距离视觉语义信息的缺点,选择以Swin Transformer[12]为骨干网络,利用窗口自注意力,实现对较大区域的关注,建模长距离依赖关系. 在窗口自注意力中以轻量的方法引入卷积,提出全新的卷积辅助窗口自注意力模块,强化网络对CXR图像局部纹理、轮廓的特征提取能力,弥补窗口自注意力的缺陷. 引入通过有向映射计算相对位置,利用动态变化的图像相对位置编码(image relative position encoding,IRPE)[13]配合自注意力的分窗计算. 使用类别残差注意力(class-specific residual attention, CSRA)[14]优化分类器,提高针对多种病理的分类能力. 此外,针对疾病样本不均匀,分类难度差异较大的问题,提出动态难度损失函数,优化网络对计算资源的分配. 在ChestX-ray14[15]、CheXpert[16]和MIMIC-CXR-JPG[17]3个公开数据集上的实验结果以及与多种先进方法的评估分析,验证了所提出网络的性能. 工作代码已经公开,代码链接为:https://github.com/ZhangZr11/CAWSNet.git.
1. 相关工作
随着深度学习技术的发展,计算机视觉领域取得了许多突破性的进展,其中包括各种医学图像处理任务. 公开数据集ChestX-ray14、CheXpert和MIMIC-CXR-JPG的发表,使得越来越多的研究者将目光投向胸部疾病分类这一多标签分类任务上. Wang等[15]使用AlexNet、ResNet、VGGNet和GooGLeNet 4个经典的CNN架构,在ChestX-ray14数据集上进行胸部疾病分类的研究,其中ResNet的分类效果最突出. Chen等[18]提出双不对称特征学习网络DualCheXNet,结合基于ResNet和DenseNet的2个非对称子网络,以便从原始CXR图像中不对称地学习互补特征,用于多标签胸部疾病分类. Wang等[19]提出三重注意力学习网络A3Net,使用预训练的DenseNet-121作为骨干网络进行特征提取,将通道、像素和尺度3种注意力集成在统一的框架中,分别关注特征图的通道、病变区域和不同尺度,更好地完成疾病分类任务. Chen等[20]提出基于金字塔卷积模块和Shuffle注意力模块的残差网络. 其中,金字塔卷积模块用于提取病理异常的多尺度判别特征,而Shuffle注意力模块通过分组整合空间和通道注意力,显著提升了对病变区域的聚焦效果. Chen等[21]提出新的语义相似图嵌入框架,该网络根据批量CXR图像的语义标签生成相似性图,以此为依据,使用图卷积网络自适应地重新校准从CNN网络提取的视觉特征,提高多标签CXR图像分类的性能.
在目前主流的方法中,特征提取的任务大多是由卷积神经网络来完成. 随着视觉Transformer的兴起,Jiang等[22]提出金字塔视觉Transformer的新变体实现胸部疾病分类,该网络通过自注意力捕获长距离视觉信息,使用下采样空间缩减注意力,减少使用全局自注意力的资源消耗. Liu等[12]提出新的视觉网络Swin Transformer,使用窗口自注意力处理特征图,令计算复杂度与图像大小成线性关系,提出移位窗口方法,弥补分窗计算带来的信息丢失,在图像分类和目标检测领域取得了优秀的效果. 上述方法虽然在胸部疾病分类任务中取得了优秀的效果,但都无法在特征提取阶段很好地兼顾局部特征和长距离视觉依赖关系,导致有效信息的丢失. 以Swin Transformer为骨干网络,引入卷积作为辅助,提出基于卷积辅助窗口自注意力的胸部X光影像疾病分类网络CAWSNet. 使用自注意力和卷积,互补地提取CXR图像特征,获得更加优秀的分类性能.
2. 研究方法
2.1. 结构概述
CAWSNet的整体架构如图1所示. 网络使用补丁分割层和线性编码层,实现了对图像的编码. 前者将CXR图像分割成大小为
图 1
2.2. CAWS Transformer块和窗口移位方法
CAWS Transformer块遵循了Swin Transformer的设计方式,结构如图2所示. 该块主要由相邻的CAWS模块和MLP层构成,其间使用残差链接相连,并在每个CAWS模块和MLP层前加1个LN (Layer Norm)层. 在网络中,该块以两两一组的方式堆叠使用,且在第2个块中,对CAWS模块使用窗口移位操作.
图 2
自注意力的分窗计算会导致窗口间缺乏信息交互,使用Swin Transformer中提出的窗口移位操作来建立窗口间的链接. 原理是通过改变窗口划分方式,使2次窗口自注意力计算关注不同的区域,以窗口数为4举例,实现方法如图3所示. 其中,图3(a)显示了2个块中分窗方式的不同:第1个块直接从左上角开始均匀分窗,第2个块在第1次分窗方式的基础上,将窗口向右方和下方分别滑动
图 3
2.3. 卷积辅助窗口自注意力
卷积辅助窗口自注意模块是使用窗口自注意力和卷积并行处理特征图的混合模块,结构如图4所示. 卷积的引入有以下3种作用. 1)使用卷积方法作为并行支路处理输入的特征图,可以强化网络的局部特征提取能力. 2)利用卷积支路输出的通道信息,可以强化窗口自注意力支路的通道建模能力. 3)模块输出端将卷积和自注意力两支路的计算结果加权求和,可以利用卷积结果建立自注意力窗口间的链接,弥补分窗的影响.
图 4
受ACmix方法[23]的启发,该模块在实现卷积和窗口自注意力的并行计算时,采取通过共享权重来节约计算资源的方法,将计算拆解为2个阶段. 假设输入和输出特征图分别为
对于卷积运算,第1阶段是使用
式中:
式中:
对于窗口自注意力计算,第1阶段使用
式中:
第2阶段是对投影后的特征图进行分窗,计算每个窗口的多头自注意力:
式中:
窗口自注意力在空间维度上动态计算权重,通道之间缺乏信息交互,而卷积运算的通道间存在充分的信息交互与整合. 模块利用卷积支路的输出,生成通道权重,并将其作用于窗口自注意力支路,强化通道建模能力. 对于通道权重的生成,采取类似挤压激励块[24]的设计:使用平均池化层,对特征图的全局空间信息进行压缩;通过2个
式中:
模块的具体计算流程如下:对于输入特征图
式中:
2.4. 图像相对位置的编码
Swin Transformer中使用的相对位置偏置只计算了像素间的相对距离,没有考虑像素间的相对方向,在处理像素高度结构化的CXR图像时表现不理想. 通过有向映射计算相对位置,使用位置权重与输入进行交互的图像相对位置编码,使得网络更好地捕捉CXR图像中像素间的位置关系.
对于2个像素间相对位置的计算,在x轴和y轴上分别计算相对距离,并映射成为有限集中的整数.
式中:
式中:
采用3种方法通过位置权重
后2种方式使用位置权重与输入的
2.5. 类别残差注意分类器
胸部疾病存在并发性,因此CXR图像中常出现多种疾病,不同疾病的病变区域不同. 多标签胸部疾病分类可以视为多个单一疾病的二元分类任务,为了在每类疾病的识别中更好地关注当前类别疾病所处的区域,提高识别的准确率,在分类阶段使用类别残差注意分类器,结构如图5所示.
图 5
对于要分类的特征向量
式中:
以平均池化结果为主要特征,以空间池化结果为残差特征,计算分类分数:
式中:
2.6. 动态难度损失函数
在胸部疾病分类任务中,各种疾病区域的大小、轮廓、纹理特征均呈现多样化,且数据集中的样本分布不均匀,导致分类难度具有较大的差异,阻碍了多标签分类任务准确率的提高. 在一张CXR图像中,往往只有1~3种疾病标签为正标签,其余均为负,这意味着每类疾病的正样本数量远低于负样本. 为了解决以上问题,在焦点损失函数[25]的基础上,加入基于AUC分数动态调节的平衡系数,提出动态难度损失函数(dynamic difficulty loss,DDL). 对于每种疾病,计算公式如下:
式中:
3. 实验结果与分析
3.1. 实验数据集
采用3个公开可用的数据集作为评估基准:美国国立卫生研究院发布的ChestX-Ray14、斯坦福大学研究人员发布的CheXpert和麻省理工大学发布的MIMIC-CXR-JPG.
MIMIC-CXR-JPG[17](在后续实验中简称为MIMIC-CXR)是包含377 110幅图像和227 835项成像报告的大型数据集,数据来自2011—2016年在美国哈佛医学院贝斯以色列女执事医疗中心急诊部就诊的65 379名患者. 每份成像报告对应1幅或多幅CXR图像,通常为正面或侧面视图. 数据集包含12种胸部疾病类别以及2种非疾病类别“无发现”和“支持设备”,每个类别的观测值被指定为正(1)、负(0)或不确定(−1). 为了公平起见,实验中使用的数据集按照官方公布的方式进行分割.
3.2. 评价指标
接受者操作特性(receiver operating characteristics, ROC)曲线表示算法对每种病理的识别能力,通过计算ROC曲线下面积(area under ROC curve, AUC),对算法能力进行定量分析和比较. 在ROC曲线中,FPR为在所有阴性类别中被错误地认为是阳性类别的阴性类别的百分比. TPR为在所有阳性类别中被正确识别的阳性类别的比例. TPR和FPR的计算如下:
式中:
3.3. 实验细节
实验在Pytorch[26]框架上实现. 对于训练,使用Adam优化器对网络进行优化,batch size为32,训练轮次为20. 初始学习率为0.000 1,每2个轮次学习率乘以0.9. 为了提高网络的收敛速度和学习能力,实验中的骨干网络将在ImageNet上进行预训练. 当验证集上的损失不再减少或开始增加时,训练将停止.
在数据预处理阶段,对输入CXR图像执行数据增强的具体方法和步骤如下.
1)将图像大小调整为
2)随机裁剪图像至
3)以50%的概率水平翻转图像.
4)在[−5°,5°]随机旋转图像.
5)设置图像的对比度、饱和度和色调为90%~110%.
6)将图像转化为向量格式,并进行归一化处理.
3.4. 与现有SOTA方法的比较
为了验证网络在胸部疾病分类任务上的有效性和准确性,将提出的CAWSNet在ChestX-Ray14、CheXpert和MIMIC 3个数据集上进行实验,并与现有的SOTA方法进行比较. 利用DCNN[27]和MXT[21]2种方法,对网络结构进行优化. TransDD[28]、PCAN[29]和PCSANet[19]3种方法使用注意力机制帮助网络关注重要的信息,提高分类效果. SSGE[20]、LCT[30]和 CheXGAT[31]通过对病理标签间的相关性进行建模,辅助网络进行分类. MAE使用掩膜自动编码器,在CXR数据上预训练ViT进行分类[32]. ML-LGL利用临床知识杠杆选择函数生成异常递增的课程,通过课程学习来训练DNN模型[33]. MVCNet在特征和决策层面融合正面和侧面2种视图的CXR图像来辅助分类[34]. MMBT[35]和MedCLIP[36]利用医学报告中的文本信息作为辅助,使用多模态学习的方法完成疾病的识别. 在ChestX-Ray14测试集、CheXpert验证集和MIMIC-CXR测试集上,每种病理的ROC曲线和AUC分数如图6~8所示,可以看出所提方法的分类性能(为了使标记清晰,每条曲线都由间隔取点的40个数据点绘制). 如表1~3所示分别为CAWSNet与其他SOTA方法在3个数据集上的比较结果.
表 2 不同胸部疾病分类网络在CheXpert验证集上的结果比较
Tab.2
疾病类别 | AUC | ||||||
U-Ignore | U-Zeros | U-Ones | PCAN | DCNN | MAE | CAWSNet | |
肺不张 | 0.818 | 0.811 | 0.858 | 0.848 | 0.825 | 0.827 | 0.835 |
心脏肿大 | 0.828 | 0.840 | 0.832 | 0.865 | 0.855 | 0.835 | 0.856 |
肺实变 | 0.938 | 0.932 | 0.899 | 0.908 | 0.937 | 0.925 | 0.917 |
水肿 | 0.934 | 0.929 | 0.941 | 0.912 | 0.930 | 0.938 | 0.953 |
胸膜增厚 | 0.928 | 0.931 | 0.934 | 0.940 | 0.923 | 0.941 | 0.928 |
平均值 | 0.889 | 0.889 | 0.893 | 0.895 | 0.894 | 0.893 | 0.898 |
图 6
图 6 ChestX-ray14测试集上胸部疾病的ROC曲线和AUC值
Fig.6 ROC curves and AUC values of chest diseases on ChestX-ray14 test set
图 7
图 7 CheXpert验证集上胸部疾病的ROC曲线和AUC值
Fig.7 ROC curve and AUC value of chest disease on CheXpert validation set
图 8
图 8 MIMIC-CXR测试集上胸部疾病的ROC曲线和AUC值
Fig.8 ROC curve and AUC value of chest disease on MIMIC-CXR test set
表 1 不同胸部疾病分类网络在ChestX-Ray14测试集上的结果比较
Tab.1
疾病类别 | AUC | ||||||||
MXT | TransDD | PCAN | PCSANet | LCT | CheXGAT | SSGE | ML-LGL | CAWSNet | |
肺不张 | 0.798 | 0.791 | 0.785 | 0.807 | 0.789 | 0.787 | 0.792 | 0.782 | 0.829 |
心脏肿大 | 0.896 | 0.885 | 0.897 | 0.910 | 0.889 | 0.879 | 0.892 | 0.904 | 0.918 |
积液 | 0.842 | 0.842 | 0.837 | 0.879 | 0.842 | 0.837 | 0.840 | 0.835 | 0.892 |
浸润 | 0.719 | 0.715 | 0.706 | 0.698 | 0.694 | 0.699 | 0.714 | 0.707 | 0.726 |
肿块 | 0.856 | 0.837 | 0.834 | 0.824 | 0.843 | 0.839 | 0.848 | 0.853 | 0.857 |
结节 | 0.809 | 0.803 | 0.786 | 0.750 | 0.803 | 0.793 | 0.812 | 0.779 | 0.784 |
肺炎 | 0.758 | 0.745 | 0.730 | 0.750 | 0.742 | 0.741 | 0.733 | 0.739 | 0.782 |
气胸 | 0.879 | 0.885 | 0.871 | 0.850 | 0.896 | 0.879 | 0.885 | 0.889 | 0.903 |
肺实变 | 0.759 | 0.753 | 0.763 | 0.802 | 0.757 | 0.755 | 0.753 | 0.771 | 0.820 |
水肿 | 0.849 | 0.859 | 0.849 | 0.888 | 0.858 | 0.851 | 0.848 | 0.866 | 0.906 |
肺气肿 | 0.906 | 0.944 | 0.921 | 0.890 | 0.944 | 0.945 | 0.948 | 0.949 | 0.935 |
纤维化 | 0.847 | 0.849 | 0.817 | 0.812 | 0.863 | 0.842 | 0.827 | 0.846 | 0.827 |
胸腔积液 | 0.800 | 0.803 | 0.791 | 0.768 | 0.799 | 0.794 | 0.795 | 0.787 | 0.817 |
疝气 | 0.913 | 0.924 | 0.943 | 0.915 | 0.915 | 0.931 | 0.932 | 0.907 | 0.939 |
平均值 | 0.830 | 0.831 | 0.824 | 0.825 | 0.831 | 0.827 | 0.830 | 0.830 | 0.853 |
表 3 不同胸部疾病分类网络在MIMIC-CXR测试集上的结果比较
Tab.3
疾病类别 | AUC | |||
MVCNet | MMBT | MedCLIP | CAWSNet | |
肺不张 | 0.818 | 0.758 | — | 0.841 |
心脏肿大 | 0.848 | 0.826 | — | 0.824 |
实变 | 0.829 | 0.771 | — | 0.833 |
水肿 | 0.919 | 0.843 | — | 0.900 |
心纵膈扩大 | 0.725 | 0.743 | — | 0.771 |
骨折 | 0.665 | 0.729 | — | 0.660 |
肺部异常 | 0.740 | 0.759 | — | 0.804 |
肺不透明 | 0.757 | 0.715 | — | 0.748 |
无发现 | 0.842 | 0.831 | — | 0.867 |
胸膜增厚 | 0.947 | 0.886 | — | 0.922 |
胸膜其他疾病 | 0.825 | 0.869 | — | 0.858 |
肺炎 | 0.715 | 0.752 | — | 0.758 |
气胸 | 0.899 | 0.880 | — | 0.861 |
平均值 | 0.810 | 0.797 | 0.804 | 0.819 |
在ChestX-Ray14测试集上,与其他SOTA方法相比,CAWSNet取得了最好的整体分类效果,14种胸部疾病的平均AUC为0.853. 从表1的对比结果可以得出以下结论. 1)在分类实验中,CAWSNet对于10种胸部疾病的诊断达到了最佳的效果. 与其他网络相比,CAWSNet能够在关注长距离依赖关系的同时兼顾局部特征的提取,在大多数疾病的识别上取得了较好的效果. 2)与大多数方法一样,网络对结节(0.784)和浸润(0.726)2种疾病的识别能力需要改进. “浸润”在影像学上呈斑块状,边缘模糊,其诊断需要高度精准的边缘和纹理特征提取能力,CAWSNet在这一方面不突出. 结节是小病变,容易受到无关特征的影响,识别相对困难. 3)与同样使用Transformer架构的MXT、TransDD和LCT网络相比,CAWSNet使用的CAWS模块拥有更好的局部特征提取能力,在大多数疾病诊断中取得了更好的效果,而针对纤维化和结节这2种分部较广的疾病,MXT网络使用的全局自注意力发挥了其优势. 4)对病理标签相关性进行建模的方法在结节和肺气肿的识别上具有很大的优势,CAWSNet虽然能够通过自注意力捕捉到图像中的语义信息,但未对标签间的依赖关系进行针对性的学习,这是今后的改进方向.
针对CheXpert数据集,Irvin等[16]提出3种策略来处理其中的不确定标签:U-Ignore、U-Ones和U-Zeros,即将不确定标签去除、视为患病和视为非患病. 当CAWSNet采取U-Ones策略时取得了最佳的分类效果,故在CheXpert验证集的实验中均将不确定标签视为患病. 从表2的对比结果可以得到以下结论. 1)相较于其他的SOTA方法,CAWSNet对于5种疾病分类的平均AUC为0.898,取得了最好的整体分类效果. 2)在单一疾病分类上,网络对于水肿(0.953)的诊断效果达到了最先进的水平,对于胸腔积液(0.928)和心脏肿大(0.856)2种疾病的诊断效果接近最优.
在MIMIC-CXR测试集上,对12种胸部疾病和“无发现”共13个类别进行分类实验,采取U-Zeros策略对不确定标签进行处理. 从表3的对比结果可以得到以下结论. 1)与其他SOTA方法相比,CAWSNet取得了最好的整体分类效果,平均AUC为0.819,并在5类疾病和“无发现”的识别中达到最先进的水平. 2)由于识别难度较高和样本量较少,各网络对骨折的识别能力均较差,而使用多模态学习的MMBT网络表现出明显优势,表明使用医学报告与图像进行融合学习,能够为疾病分类提供更多的有效信息. 3)CAWSNet在不同数据集上都取得了较好的分类效果,这证明该网络能够较好地完成胸部疾病分类任务,具有一定的鲁棒性.
3.5. 相对位置编码生成方式对网络的影响
为了验证3种不同的位置编码生成方式对网络分类能力的影响,分别在ChestX-ray14、CheXpert和MIMIC-CXR数据集上设置对比实验,实验结果如表4所示. 从实验结果可知,2种与输入交互生成方式的分类效果均优于偏置方法,这说明相对动态的位置编码能够更好地适用于自注意力的分窗计算. 当位置编码与
表 4 位置编码生成方式对网络分类效果的影响
Tab.4
生成方式 | 平均AUC | ||
ChestX-ray14 | CheXpert | MIMIC-CXR | |
偏置 | 0.849 | 0.890 | 0.813 |
与输入k交互 | 0.853 | 0.898 | 0.819 |
与输入q和k交互 | 0.850 | 0.894 | 0.816 |
3.6. 通道增强加权位置对网络的影响
为了验证CAWS模块中通道增强加权位置对网络的影响,分别在ChestX-ray14、CheXpert和MIMIC-CXR数据集上设置对比实验,实验结果如表5所示.
表 5 通道增强加权位置对网络分类效果的影响
Tab.5
加权位置 | 平均AUC | ||
ChestX-ray14 | CheXpert | MIMIC-CXR | |
对输入q加权 | 0.852 | 0.895 | 0.817 |
对输入k加权 | 0.851 | 0.896 | 0.817 |
对输入v加权 | 0.853 | 0.898 | 0.819 |
从实验结果可知,通道增强的加权位置对网络分类效果产生了一定的影响,当对输入
3.7. 消融实验
3.7.1. 模块消融实验
为了验证CAWS模块、CSRA分类器和图像相对位置编码的有效性,分别在ChestX-ray14、 CheXpert和MIMIC-CXR数据集上进行消融实验. 以CAWSNet作为基准,通过删除相应的模块来验证其对分类精度的影响,实验结果如表6所示.
表 6 不同模块对网络分类效果的影响
Tab.6
CAWS | CSRA | IRPE | 平均AUC | ||
ChestX-ray14 | CheXpert | MIMIC-CXR | |||
✔ | ✔ | ✔ | 0.853 | 0.898 | 0.819 |
✘ | ✔ | ✔ | 0.844 | 0.886 | 0.811 |
✔ | ✘ | ✔ | 0.849 | 0.894 | 0.814 |
✔ | ✔ | ✘ | 0.850 | 0.891 | 0.816 |
对于CAWS模块,将其替换为原始的窗口自注意力,网络在ChestX-ray14、CheXpert和MIMIC-CXR数据集上的AUC分数分别下降了0.9%、1.2%和0.8%. 可见,引入卷积与窗口自注意力组成的混合模块,可以有效地提高网络的分类能力. 对于类别残差注意分类器,将其替换为最常见的全连接层平均池化分类器,网络在3个数据集上的AUC分数分别下降了0.4%、0.4%和0.5%,表明针对类别的空间注意力能够有效地提升多标签分类精度. 针对图像相对位置编码,使用Swin Transformer中的相对位置偏置进行替代,网络在3个数据集上的AUC分数分别下降了0.3%、0.7%和0.3%,可见有向且动态的位置编码能够更好地帮助网络捕捉CXR图像中像素间的位置关系.
3.7.2. 损失函数消融实验
为了验证提出的动态难度损失函数的效果,分别在ChestX-ray14、CheXpert和MIMIC-CXR数据集上进行消融实验,结果如表7所示.
表 7 不同损失函数对网络分类效果的影响
Tab.7
损失函数 | 平均AUC | ||
ChestX-ray14 | CheXpert | MIMIC-CXR | |
交叉熵损失函数 | 0.846 | 0.889 | 0.810 |
焦点损失函数 | 0.850 | 0.892 | 0.816 |
动态难度损失函数 | 0.853 | 0.898 | 0.819 |
从表7可以看出,在3个数据集上动态难度损失函数均取得最优的分类效果. 由此表明,利用动态难度损失函数,可以有效地解决各疾病分类难度不一、正负样本不平衡的问题,提高网络对CXR图像中各疾病的分类准确度.
3.8. 多重交叉验证与显著性检验
为了验证网络的鲁棒性和网络效果的真实性,在ChestX-ray14数据集上进行Stratified KFold交叉验证实验. 在该实验中,整个数据集被随机平分成5份互斥子集,保证各子集中的疾病类别比例大致相同. 每次随机地选择4份作为训练集,剩下的1份作为测试集. 依照此方式,分别对CAWSNet和Swin Transformer骨干网络进行5次训练. 实验结果如表8所示. 可以看出,使用随机划分的训练集和测试集会影响网络的分类效果,CAWSNet在随机划分数据集上的分类表现虽然不如官方划分数据集,但取得了0.838的平均AUC分数,证明网络具有一定的鲁棒性,对不同的数据分布都有较好的效果.
表 8 Stratified KFold交叉验证结果
Tab.8
数据集划分 | 平均AUC | |
CAWSNet | Swin Transformer | |
2、3、4、5训练集,1测试集 | 0.841 | 0.832 |
1、3、4、5训练集,2测试集 | 0.836 | 0.830 |
1、2、4、5训练集,3测试集 | 0.834 | 0.829 |
1、2、3、5训练集,4测试集 | 0.840 | 0.832 |
1、2、3、4训练集,5测试集 | 0.841 | 0.833 |
平均值 | 0.838 | 0.831 |
以5个不同划分的数据集作为5个不同的样本,使用配对样本t检验的方式对CAWSNet进行显著性检验. 备择假设设置为CAWSNet的AUC分数高于Swin Transformer,t统计量为
3.9. 计算复杂度分析
计算复杂度是计算机辅助诊断算法在实际应用中需要考虑的一个要素,计算CAWSNet在数据集ChestX-Ray14的测试阶段单张图片平均消耗的每秒浮点运算次数FLOPs及推理时间tinf,与骨干网络和给出相关指标的方法进行对比,结果如表9所示. 浮点计算量实验中的输入图像尺寸统一设置为
表 9 不同胸部疾病分类网络的计算复杂度比较
Tab.9
从表9可以看出,CAWSNet的推理时间较短,可以快速地处理单张CXR图像,FLOPs较小,与Swin Transformer骨干网络相比只增加0.15×109,但性能有显著提升. 利用CAWS模块共享部分权重来实现卷积和窗口自注意力并行计算的方法,可以有效地节约计算资源. 综合来说,网络较好地平衡了分类准确度和计算复杂度,在胸部疾病诊断算法中具有较强的竞争力.
3.10. 可视化分析
加权梯度类激活映射[38](Grad-CAM)利用梯度信息,给特征映射的各个通道赋予权重,生成热图,可以显示网络关注的病灶区域. 为了验证网络识别的准确性,通过Grad-CAM在一些ChestX-Ray14中的CXR图像上生成热图,与专业医生提供的病变标记图进行比较. 对一部分CXR图像进行疾病预测得分的可视化,直观地表现网络的分类效果.
图 9
图 9 医生标记病变区域(左)与Grad-CAM热图(右)
Fig.9 Doctor's marked lesion area (left) and Grad-CAM heat map (right)
图 10
4. 结 语
本文提出基于卷积辅助窗口自注意力的胸部X光影像疾病分类网络CAWSNet,能够端到端地自动学习不同疾病的病理特征,完成常见胸部疾病的分类. 该网络结合窗口自注意力和卷积2种特征处理方式,捕捉CXR图像中的长距离视觉关系和局部特征. 引入图像相对位置编码,能够更好地提取CXR图像中与位置相关的特征. 使用类别残差注意分类器,在分类阶段根据疾病类别关注不同的区域. 提出动态难度损失函数,有效地解决了各疾病分类难度不一、正负样本不平衡带来的问题. CAWSNet在ChestX-Ray14、CheXpert和MIMIC-CXR-JPG数据集上的平均AUC分别为0.853、0.898和0.819,在多标签胸部疾病分类任务中具有有效性. 在未来的工作中,将研究重点放在对疾病标签间相关性的建模和利用医学报告信息进行多模态学习2个方面,以进一步提高网络的整体性能并增强可解释性.
参考文献
al. Portable chest X-ray in coronavirus disease-19 (COVID-19): a pictorial review
[J].
al. The COVID-19 epidemic analysis and diagnosis using deep learning: a asystematic literature review and future directions
[J].DOI:10.1016/j.compbiomed.2021.105141 [本文引用: 1]
医学影像计算机辅助检测与诊断系统综述
[J].
Survey on medical image computer aided detection and diagnosis systems
[J].
al. Medical image analysis using convolutional neural networks: a review
[J].DOI:10.1007/s10916-018-1088-1 [本文引用: 1]
Generative adversarial network in medical imaging: a review
[J].DOI:10.1016/j.media.2019.101552
al. Deep reinforcement learning in medical imaging: a literature review
[J].DOI:10.1016/j.media.2021.102193 [本文引用: 1]
Multi-level residual feature fusion network for thoracic disease classification in chest x-ray images
[J].
可形变Transformer辅助的胸部X光影像疾病诊断模型
[J].
Chest X-ray imaging disease diagnosis model assisted by deformable Transformer
[J].
al. MIMIC-CXR, a de-identified publicly available database of chest radiographs with free-text reports
[J].DOI:10.1038/s41597-019-0322-0 [本文引用: 2]
al. DualCheXNet: dual asymmetric feature learning for thoracic disease classification in chest X-rays
[J].DOI:10.1016/j.bspc.2019.04.031 [本文引用: 1]
al. Triple attention learning for classification of 14 thoracic diseases using chest radiography
[J].
Thorax disease classification based on pyramidal convolution shuffle attention neural network
[J].DOI:10.1109/ACCESS.2022.3198958 [本文引用: 2]
al. Multi-label chest X-ray image classification via semantic similarity graph embedding
[J].DOI:10.1109/TCSVT.2021.3079900 [本文引用: 3]
al. MXT: a new variant of pyramid vision Transformer for multi-label chest X-ray image classification
[J].DOI:10.1007/s12559-022-10032-4 [本文引用: 1]
al. Focal loss for dense object detection
[J].DOI:10.1109/TPAMI.2018.2858826 [本文引用: 1]
al. Interpreting chest X-rays via CNNs that exploit hierarchical disease dependencies and uncertainty labels
[J].DOI:10.1016/j.neucom.2020.03.127 [本文引用: 1]
TransDD: a transformer-based dual-path decoder for improving the performance of thoracic diseases classification using chest X-ray
[J].
al. PCAN: pixel-wise classification and attention network for thoracic disease classification and weakly supervised localization
[J].DOI:10.1016/j.compmedimag.2022.102137 [本文引用: 2]
Label correlation transformer for automated chest X-ray diagnosis with reliable interpretability
[J].DOI:10.1007/s11547-023-01647-0 [本文引用: 1]
CheXGAT: a disease correlation-aware network for thorax disease diagnosis from chest X-ray images
[J].DOI:10.1016/j.artmed.2022.102382 [本文引用: 1]
Multi-label local to global learning: a novel learning paradigm for chest x-ray abnormality classification
[J].DOI:10.1109/JBHI.2023.3281466 [本文引用: 1]
Label co-occurrence learning with graph convolutional networks for multi-label chest x-ray image classification
[J].DOI:10.1109/JBHI.2020.2967084 [本文引用: 1]
al. Grad-CAM: visual explanations from deep networks via gradient-based localization
[J].DOI:10.1007/s11263-019-01228-7 [本文引用: 1]
/
〈 |
|
〉 |
