浙江大学学报(工学版), 2023, 57(4): 744-752 doi: 10.3785/j.issn.1008-973X.2023.04.012

自动化技术、计算机技术

基于距离度量损失框架的半监督学习方法

刘半藤,, 叶赞挺, 秦海龙, 王柯,, 郑启航, 王章权

1. 浙江树人学院 信息科技学院,浙江 杭州 310015

2. 常州大学 计算机与人工智能学院,江苏 常州 213164

3. 浙江绿城未来数智科技有限公司,浙江 杭州 311121

4. 浙江大学 工业控制技术国家重点实验室,浙江 杭州 310027

Semi-supervised learning method based on distance metric loss framework

LIU Ban-teng,, YE Zan-ting, QIN Hai-long, WANG Ke,, ZHENG Qi-hang, WANG Zhang-quan

1. College of Information Science and Technology, Zhejiang Shuren University, Hangzhou 310015, China

2. College of Computer Science and Artificial Intelligence, Changzhou University, Changzhou 213164, China

3. Zhejiang Lvcheng Future Digital Intelligence Technology Limited Company, Hangzhou 311121, China

4. State Key Laboratory of Industrial Control Technology, Zhejiang University, Hangzhou 310027, China

通讯作者: 王柯, 男, 讲师. orcid.org/0000-0002-1926-4634. E-mail: wangke1992@zju.edu.cn

收稿日期: 2022-04-7  

基金资助: 浙江省“领雁”研发攻关计划资助项目(2022C03122);浙江省公益技术应用研究资助项目(LGF22F020006,LGF21F010004);浙江大学工业控制技术国家重点实验室开放课题资助项目(ICT2022B34)

Received: 2022-04-7  

Fund supported: 浙江省“领雁”研发攻关计划资助项目(2022C03122);浙江省公益技术应用研究资助项目(LGF22F020006,LGF21F010004);浙江大学工业控制技术国家重点实验室开放课题资助项目(ICT2022B34)

作者简介 About authors

刘半藤(1984—),男,教授,从事复合无损技术的研究.orcid.org/0000-0001-8472-0061.E-mail:hupo3@sina.com , E-mail:hupo3@sina.com

摘要

为了解决半监督学习方法训练过程中因损失函数类型不同、损失尺度不统一而导致的损失权重难以调节, 模型优化方向不统一与泛化能力不足的问题, 提出基于距离度量损失框架的半监督学习方法. 该方法从距离度量损失的角度出发, 提出统一损失框架函数, 实现了半监督任务中不同损失函数之间的损失权重调节. 针对损失框架中嵌入向量的目标区域问题, 引入自适应相似度权重,以避免传统度量学习损失函数优化方向的冲突, 提高模型的泛化性能. 为了验证方法的有效性, 分别采用CNN13网络和ResNet18网络,在CIFAR-10、CIFAR-100、SVHN、STL-10标准图像数据集和医疗肺炎数据集Pneumonia Chest X-ray上,构建半监督学习模型与常用半监督方法进行比较. 实验结果表明, 在同等标签数目的条件下, 提出方法具有最优的分类准确度.

关键词: 半监督学习 ; 度量学习 ; 损失函数 ; 损失框架 ; 分类

Abstract

A semi-supervised learning method based on the distance metric loss framework was proposed in order to solve the problems of different types of loss functions and inconsistent loss scales in the training process of semi-supervised learning methods, which make it difficult to adjust the loss weights, inconsistent optimization directions and insufficient generalization ability. A unify loss framework function was proposed from the perspective of distance metric loss, and the adjustment of loss weights between different loss functions in semi-supervised tasks was achieved. Adaptive similarity weights were introduced for the target region problem of embedding vectors in the loss framework in order to avoid the conflict of optimization directions of traditional metric learning loss functions and improve the generalization performance of the model. CNN13 and ResNet18 networks were used to construct semi-supervised learning models on CIFAR-10, CIFAR-100, SVHN, STL-10 standard image dataset and medical pneumonia dataset Pneumonia Chest X-ray, respectively, for comparison with commonly used semi-supervised methods in order to validate the effectiveness of the method. Results show that the method has the optimal classification accuracy under the condition of the same number of labels.

Keywords: semi-supervised learning ; metric learning ; loss function ; loss framework ; classification

PDF (1146KB) 元数据 多维度评价 相关文章 导出 EndNote| Ris| Bibtex  收藏本文

本文引用格式

刘半藤, 叶赞挺, 秦海龙, 王柯, 郑启航, 王章权. 基于距离度量损失框架的半监督学习方法. 浙江大学学报(工学版)[J], 2023, 57(4): 744-752 doi:10.3785/j.issn.1008-973X.2023.04.012

LIU Ban-teng, YE Zan-ting, QIN Hai-long, WANG Ke, ZHENG Qi-hang, WANG Zhang-quan. Semi-supervised learning method based on distance metric loss framework. Journal of Zhejiang University(Engineering Science)[J], 2023, 57(4): 744-752 doi:10.3785/j.issn.1008-973X.2023.04.012

神经网络通过学习标签样本训练智能模型,在多个领域取得了较好的应用[1-3]. 神经网络的学习性能依赖于样本集的完整性. 在实际应用中,获取大量的标签样本数据集需要高昂的人工成本. 为了降低人工标注的成本,提高对无标签数据的利用率,需要建立自动标记机制,以降低模型对训练集大小的依赖性.

半监督学习(semi-supervised learning, SSL)方法有效利用标签和无标签样本,改善模型的分类准确性和泛化性,成为近年来模式识别领域中关注和研究的重点之一.

针对SSL损失函数的改进,相关学者作了大量的研究,提出3种类型的损失函数:1)分类交叉熵损失[4-7];2)熵最小化损失[8];3)一致性损失[9-10]. Wang等[8,11-12]通过提升特征间的类内紧凑性和类间可分离性,更新权重来改进分类交叉熵损失函数. Tarvainen等[5,13]通过加入预测熵的信息,优化熵最小化损失函数. Xie等[6-7,9,14-15]通过改进扰动量计算方式和提升数据增强策略,加强一致性损失. 以上方法未考虑不同任务损失函数间尺度差异的问题,导致整体优化进程被某一个任务所主导,其他任务损失无法影响网络共享层的学习过程[16].

为了防止半监督学习任务被单一损失主导,当前多采用损失权重分配方法,主流方法有关联函数法与人工先验调整法. 利用人工先验方法计算权值相对比较简单,但流程较复杂. Laine等[9]在π-model模型中通过多次实验,将最优检测准确率的参数设置为损失的权值. Lee[4]提出Pseudo Label方法,在平衡熵最小化损失和分类交叉熵损失的过程中,提出调节系数α作为熵最小化损失的权值. Tarvainen等[5]提出Meanteacher方法,通过多次实验确定一致性约束权重. Laine等[9,16]提出利用关联函数的方法进行动态调节权值参数,如无监督数据增强 (unsupervised data augmentation, UDA)方法. 然而人工调整权重系数的方法过于依赖先验知识,基于关联函数的方法需要对各个约束项间的关联进行建模, 流程较复杂.

本文提出基于距离度量损失框架的半监督学习方法,采用余弦距离度量[17-20]中的proxy-based损失,将分类损失、熵最小化损失转换为距离度量的损失形式,在一致性损失的基础上提出伪标签一致性损失函数,将三者均转换为余弦距离相似度的形式,构成统一的半监督损失框架. 本文对统一后的半监督损失函数框架进行优化,提出新的特征向量和嵌入向量映射方法来消除冲突,通过加入间距系数、缩放因子2类参数来提升模型的泛化性.

1. 统一的半监督损失函数框架

根据分类交叉熵损失、熵最小化损失、一致性损失函数在优化目标上基于距离度量差值的共性,将各个损失函数转换为基于余弦距离度量的半监督损失函数. 提出基于伪标签的一致性损失,在一致性损失函数中加入标签作为评价指标,与其他损失函数在计算形式上达成一致,合并得到统一的半监督损失函数. 传统的分类交叉熵损失、熵最小化损失、一致性损失公式如下所示:

$ \mathop { L}\nolimits_{{\text{cla}}} = - \ln \frac{{\exp\; ({{{\boldsymbol{w}}}_{y_i}} \cdot {{{\boldsymbol{f}}}_{y_i}})}}{{\exp\; ({{{\boldsymbol{w}}}_{y_i}} \cdot {{{\boldsymbol{f}}}_{y_i}})+\displaystyle\sum\limits_{j = 1,j \ne \mathop y\nolimits_i }^C {\exp\; \mathop {({{\boldsymbol{w}}}}\nolimits_j \cdot {{{\boldsymbol{f}}}_j})} }}. $

式(1)为分类交叉熵损失函数. 式中:f为模型 $ p(y|x,\theta ) $的嵌入向量,wj为类别对应的分类向量,yi为标签类别所对应的索引.

$ L_{{\text{ent}}} = - \ln \frac{{\exp\; ( {{\boldsymbol{w}}}_{ y_i^{'} } \cdot {{\boldsymbol{f}}}_{ y_i^{'} } )}}{{\exp\; ( {{\boldsymbol{w}}}_{ y^{'}_i } \cdot {{\boldsymbol{f}}}_{ y_i^{'} } )+\displaystyle\sum\limits_{j = 1,j \ne { y_i{'}} }^C {\exp\; {({{\boldsymbol{w}}}}_j \cdot {{\boldsymbol{f}}}_j )} }}. $

式(2)为熵最小化损失函数. 式中:y'为无标签数据对应的伪标签,

$ {y_i}{'}=\left\{ {\begin{array}{l} 1,\quad i=\text{argmax}\;{{ (\boldsymbol{w}}}_{{{y}}_{i}}\cdot {{ \boldsymbol{f}}}_{{{y}}_{i}}),\;i\in [1,C];\\ 0,\quad \text{其他}. \end{array}} \right. $

$ L_{\text{con}} = {\rm{distance}}\;({{\boldsymbol{w}}} \cdot {{\boldsymbol{f}}},\;{{\boldsymbol{w}}} \cdot {\hat {\boldsymbol{f}}}). $

式(4)为一致性损失函数. 式中: $ {\hat {\boldsymbol{f}}} $为添加随机扰动后对应的输出.

1.1. 基于余弦距离度量的半监督损失函数构建

余弦距离度量采用夹角余弦值评估数据间的相似性,表达式为

$ {\text{sim}}({{{\boldsymbol{X}}}_1},{{{\boldsymbol{X}}}_2}) = \cos \;\theta = \frac{{{{{\boldsymbol{X}}}_1} \cdot {{{\boldsymbol{X}}}_2}}}{{\left\| {{{{\boldsymbol{X}}}_1}} \right\| \cdot \left\| {{{{\boldsymbol{X}}}_2}} \right\|}}. $

式中:X1X2为2个不同的向量. 与传统距离度量相比,余弦距离度量方式更加注重2个向量在方向上的差异,数学意义满足本文统一函数的构造要求,可以实现交叉熵损失函数、熵最小化损失函数、一致性损失函数公式的统一.

引入基于余弦度量的类内相似度sp与类间相似度sn,用于描述嵌入向量和类别特征向量之间的相似度. 模型嵌入向量 ${\boldsymbol{f}}_{y_i} $与对应类别向量 ${\boldsymbol{w}}_{y_i} $的内积为基于余弦度量的类内相似度, ${{{\boldsymbol{w}}}_{y_i} \cdot {{\boldsymbol{f}}}_{y_i}} / ({\left| {{{\boldsymbol{w}}}_{y_i}} \right| \cdot \left| {{{\boldsymbol{f}}}_{y_i}} \right|}) = \cos \;\theta_{ y_i }= s_{\rm{p}}$;不同类别嵌入向量和权重的内积为基于余弦度量的类间相似度, ${{{\boldsymbol{w}}}_j \cdot {{\boldsymbol{f}}}_j} / ({\left| {{{\boldsymbol{w}}}_j} \right| \cdot \left| {{{\boldsymbol{f}}}_j} \right|}) = \cos \;\theta _j = s_{\rm{n}}^j$.

将分类交叉熵损失、熵最小化损失与一致性损失统一转换为类内相似度和类间相似度的表达式. 分类交叉熵损失函数的转换过程如下:

$ \begin{split} L^{\prime} _{\text {cla}} & =-\ln \frac{\exp \left(\boldsymbol{w}_{y_i} \cdot \boldsymbol{f}_y\right)}{\exp \left(\boldsymbol{w}_{y_i} \cdot \boldsymbol{f}_{y_i}\right)+\displaystyle\sum\limits_{j=1, j \neq y_i}^c \exp \left(\boldsymbol{w}_j \cdot \boldsymbol{f}_j\right)} =\\ & -\ln \frac{\exp \left(s_{\rm{p}}\right)}{\exp \left(s_{\rm{p}}\right)+\displaystyle\sum\limits_{j=1, j \neq y_i}^c \exp \left(s_{\rm{n}}^j\right)}= \\ & \ln \left[1+\displaystyle\sum\limits_{j=1, j \neq y_i}^c \exp \left(s_{\rm{n}}^j-s_{\rm{p}}\right)\right] . \end{split} $

基于余弦度量的分类交叉熵损失函数 $ L'_{\text{cla}} $的优化目标可以表示为: ${\text{min}}\;(s_{\rm{n}}^j - s_{\rm{p}})$. 熵最小化损失函数的转换过程如下:

$ \begin{split} L^{\prime} _{\text {ent}}& =-\ln \frac{\exp \left(\boldsymbol{w}_{y _i} \cdot \boldsymbol{f}_{y_i}\right)}{\exp \left(\boldsymbol{w}_{y_i} \cdot \boldsymbol{f}_{y_i}\right)+\displaystyle\sum\limits_{j=1, j \neq y_i}^c \exp \left(\boldsymbol{w}_j \cdot \boldsymbol{f}_j\right)}= \\ & -\ln \frac{\exp \left(s_{\rm{p}}^{\prime}\right)}{\exp \left(s_{\rm{p}}^{\prime}\right)+\displaystyle\sum\limits_{j=1, j \neq y_i}^c \exp \left(s_{\rm{n}}^j\right)} =\\ & \ln \left[1+\displaystyle\sum\limits_{j=1, j \neq y_i^{\prime}}^c \exp \left(s_{\rm{n}}^j-s_{\rm{p}}^{\prime}\right)\right] . \end{split} $

基于余弦度量的分类交叉熵损失函数 $L'_{\text{ent}}$的优化目标可以表示为 ${\text{min}}\;(s_{\rm{n}}^j - s'_{\rm{p}})$,其中 $s'_{\rm{p}}$为伪标签样本 $ y_i' $的类内相似度.

1.2. 基于伪标签的一致性损失函数

一致性损失函数表示扰动前、后嵌入向量的输出差值,在计算过程中无须考虑伪标签类别,这与分类交叉熵损失函数和熵最小化损失函数的形式不统一. 提出基于伪标签的一致性损失函数,在已知嵌入向量的伪标签类别及熵最小化损失约束的条件下,嵌入向量将靠近对应伪标签类别的proxy向量. 将扰动后的嵌入向量和对应伪标签类别的proxy向量的距离度量作为一致性损失,保证嵌入向量扰动前、后的输出值差异最小化. 基于伪标签的一致性损失构建示意图如图1所示. 图1(a) 中,有向线段表示一致性损失,不同圆点表示不同无标签数据对应的嵌入向量,五角星表示对应扰动后的嵌入向量. 图1(b) 中,有向线段表示伪标签一致性损失,圆点表示已知伪标签的嵌入向量,五角星表示数据扰动后的嵌入向量. 图1中,六角星表示伪标签所对应的嵌入向量. 基于伪标签的一致性损失可以视为扰动后嵌入向量的熵最小化损失,公式可以表示为

图 1

图 1   一致性损失与熵最小化损失的示意图

Fig.1   Schematic diagram of entropy minimization and consistency loss


$ L_{\text{pse}} = \ln \left[1+\sum\limits_{j = 1,j \ne y_i'}^c {\exp }\; (\hat S_{\rm{n}}^j - {S_{\rm{p}}}')\right]. $

式中: $ S_{\rm{n}}^j $为添加扰动后的嵌入向量与不同类别特征向量的余弦相似度, $ {S_{\rm{p}}}' $为添加扰动后的嵌入向量与伪标签对应类别特征向量的余弦相似度.

基于伪标签的一致性损失函数的优化目标可以表示为 $ \min\; (S_{\rm{n}}^j - {S_{\rm{p}}}') $. 提出的基于伪标签的一致性损失函数在原有的一致性损失函数基础上,考虑了伪标签信息,在建模过程中加强嵌入向量与伪标签嵌入向量proxy的相关性,有利于提高模型对扰动后无标签样本预测的置信度.

$ L_{\rm{ssl}} = \ln\; [1 + \sum\limits_{j = 1,j \ne y_i'\atop {S_{\rm{n}} = \{ S_{\rm{n}}^j , \hat S_{\rm{n}}^j\} \atop S_{\rm{p}} = \{ S_{\rm{p}} , S_{\rm{p}}'\} }}^C {\exp } \;(S_{\rm{n}} - S_{\rm{p}})]. $

综上所述,基于伪标签的一致性损失函数为基于带标签信息的余弦相似度差值,其形式与分类交叉熵损失(式(6))及熵最小损失函数(式(7))相同. 以上损失函数可以统一为基于度量的损失框架,最终框架如式(9)所示. 本文专注于优化度量损失中优化目标的设定,提升了半监督学习效果.

2. 损失函数框架优化

2.1. 特征向量和嵌入向量优化映射方法

目前,基于距离度量的损失函数在对嵌入向量进行优化时,与当前标签不同类别的嵌入向量的优化目标均被设定为与当前类别proxy特征向量相反的区域. 如图2(a) 所示,优化类别A的嵌入向量被定位于对应的proxy特征向量附近,类别B与类别C的样本会被优化至类别A的proxy特征向量相反的附近位置,造成类别B与C的样本嵌入向量的冲突,模型的分类性能下降. 通过改进特征向量和嵌入向量映射的方法来消除冲突,思路如图2(b) 所示. 将不同类别嵌入向量的优化目标定位为正交于当前类别proxy特征向量的位置,既保证了非同类嵌入向量的分离性,又避免了嵌入向量优化过程的冲突问题,以获得具备较强鲁棒性的半监督分类模型.

图 2

图 2   嵌入向量与特征向量优化方法的示意图

Fig.2   Schematic diagram of optimization method of embedd-ing vector and feature vector


设计自适应类内相似度权重 $ \partial_{\rm{p}} $和类间相似度权重 $ \partial _{\rm{n}} $. 在模型优化的过程中,相似度权重越大,更新速率越快,反之亦然,相似度权重为零,模型停止更新. 提出动态 $ \partial _{\rm{p}} $$ \partial_{\rm{n}} $的计算公式:

$ \left. \begin{aligned} \partial _{\rm{p}} = T_{\rm{p}} - S_{\rm{p}}, \\ \partial _{\rm{n}} = S_{\rm{n}} - T_{\rm{n}} . \\ \end{aligned}\right\} $

根据模型优化的要求可知,同类嵌入向量越靠近越好,类间嵌入向量需要正交,即 $ S_{\rm{p}} $的优化目标为 $\theta _j=0 $$ S_{\rm{n}} $的优化目标为 $ \theta _j={\text{π}} /2 $. 设置 $ T_{\rm{p}} = 1 $$ T_{\rm{n}} = 0 $,防止类间相似度优化目标的冲突,保证了模型更新效率.

2.2. 间距系数和缩放因子

损失用于分类时,分类决策面为 $ S_{\rm{n}} = S_{\rm{p}} $,即 $ S_{\rm{n}} - S_{\rm{p}} = 0 $. 在优化 $ \partial _{\rm{n}}S_{\rm{n}} - \partial _{\rm{p}}S_{\rm{p}} $的过程中,添加间距系数可以加强 $ S_{\rm{p}} $的聚类程度. 由于 $ S_{\rm{n}} $$ - S_{\rm{p}} $具有对称特性,只需要为 $ S_{\rm{p}} $添加间距系数gap.

当使用余弦相似度作为度量时,对分类权重与分类向量wf进行归一化处理,余弦相似度数值范围将缩小,梯度范围相应减小,导致模型更新速率减慢. 添加 $\lambda $作为余弦相似度项的缩放因子,以增大梯度.

2.3. Unify Loss函数

在式(9)的基础上,经过上述改进,得到完整的Unify Loss函数,形式如下.

$ L_{\text{uni}} = \ln \left[ 1+ \sum\limits_{j = 1,j \ne {y_i}'}^C \exp\;[ \lambda (\alpha _{\rm{n}}S_{\rm{n}} - \alpha _{\rm{p}}(S_{\rm{p}} - {\rm{gap}}))] \right]. $

本文得到的Unify Loss函数专注于优化 $ {\text{min}}\; (\alpha _{\rm{n}}S_{\rm{n}} - \alpha _{\rm{p}}(S_{\rm{p}} - {\rm{gap}})) $. 该损失将传统半监督学习任务中的3类损失统一为类间相似度与类内相似度的差值形式,取代原半监督学习模型中的分类交叉熵损失、熵最小化损失、一致性损失这3类损失,作为任务中的核心损失函数. 本文的Unify Loss函数曲面图如图3所示,该函数具有全局连续性,在 $ \theta _{y_i} = 0 $$ \theta _j = {\text{π}} /2 $时接近全局最优及损失最低点,符合上述的优化目标设置.

图 3

图 3   Unify Loss函数曲面图

Fig.3   Diagram of Unify Loss function


3. 实验与分析

3.1. 梯度分析

由于模型的优化速率与函数求解过程中的梯度直接相关. 对常用的Triplet Loss、Proxy Loss、Circle Loss函数与Unify Loss函数进行梯度分析,结果如图4所示. 对 $ S_{\rm{n}} $$ S_{\rm{p}} $求偏导作出梯度图,其中x轴为 $ \theta _{y_i} $的分布,y轴为 $ \theta_ j $的分布. 在x-y平面梯度的分布中,提出的Unify Loss函数、梯度全局都有较好的连续性;设置的合理优化目标 $ T_{\rm{n}} $$ T_{\rm{p}} $,使得嵌入向量的梯度在达到最优目标区域处前总能保持较大的值,改善了模型优化过程的收敛性. 其他3类损失,在收敛过程中梯度存在不连续、消失、爆炸等情况,将造成收敛区域过大、收敛过程不可靠、计算成本大等问题,导致模型的泛化性能不佳.

图 4

图 4   各损失梯度的对比图

Fig.4   Contrast diagram of loss gradient


3.2. 实验数据集

实验采用标准图像分类数据集CIFAR-10、 CIFAR-100、SVHN、STL-10和真实的医疗数据集 Pneumonia Chest-ray[21],对提出算法的有效性进行验证. 各个数据集的训练样本及测试样本分布如表1所示. 表中,N为训练集样本总数量,Nt为测试集样本总数量,Nl为标签样本数量,Nu为无标签样本数量. 针对样本不足的问题,以UDA算法为基础进行数据增强,其中无标签数据采用Rand Augment方法进行增强.

表 1   各数据集的训练及测试样本分布

Tab.1  Training and test sample distribution of each dataset

数据集 N Nt
Nl Nu
SVHN 40 70 000 26 000
250 70 000
4 000 70 000
CIFAR-10 40 50 000 5 000
250 50 000
4 000 50 000
CIFAR-100 400 40 000 10 000
1 000 40 000
10 000 40 000
STL-10 1 000 10 000 8 000
5 000 10 000
Pneumonia Chest X-ray 250 4 750 600
500 4 500
1 000 4 000

新窗口打开| 下载CSV


3.3. 实验设计

实验采用MeanTeacher、MixMatch、UDA、Remixmatch、Fixmatch、Flexmatch、虚拟对抗网络(virtual adversarial training, VAT)算法[5-7,22-25],与本文算法进行对比. VAT算法结合半监督思想与对抗网络(generative adversarial network, GAN),运用GAN网络中生成器与判别器2类模块生成器相互博弈学习的方式,提高样本利用率. MeanTeacher与UDA算法模型均采用一致性正则方法,通过一致性损失提高无标签样本的利用率,其中UDA算法使用无监督数据增强策略,分类准确性较MeanTeacher有一定的提高. MixMatch与Fixmatch是整合一致性正则、多视图训练优势的混合方法. 其中MixMatch利用数据增强方法与无监督损失函数,提高无标签样本的利用率;Fixmatch对无标签样本进行增强,使用交叉熵形式的正则化损失计算增强后样本的损失. 这2类方法整合多方优势,均取得较好的结果. Flexmatch算法在Fixmatch的基础上,提出课程伪标签方法(curriculum pseudo labeling, CLP),通过动态调节伪标签置信度阈值,增加模型不同训练阶段获得的伪标签样本数量. 为了对比各算法的性能,各算法均采用小规模网络CNN13[5]和大规模网络ResNet18,用于训练CIFAR-10、CIFAR-100、SVHN、STL-10数据集与Pneumonia Chest X-ray数据集. 不同算法使用相同的网络结构,优化器、学习率及batch size都设置为相同值. 其中优化器使用默认的SGD算法,学习率选用cosine衰减方法,初始学习率设定为0.03,衰减速率 $ \omega = 7{\text{π}} $,共训练1 024×128步. 对于CIFAR-10、CIFAR-100、SVHN及STL-10数据,batch size 设置为64;对于 Pneumonia Chest X-ray数据集,batch size设置为8. 为了保证实验的精确性,对于每个数据集,采用5折叠交叉验证的平均值作为最终的测试结果.

3.4. 参数对比

对Unify Loss函数中的2个超参数 $ \lambda $$ {\rm{gap}} $进行讨论分析. 实验在数据集CIFAR-10上开展,通过分析2类参数在不同取值情况下得到的分类准确度,确定最佳参数.

图5所示为标签数为4 000时参数在各取值下的分类结果. 图中,A为准确率. 如图5(a)所示为 $ \lambda $在不同取值下的分类准确率,在实验过程中,为了确保单一变量,将 $ {\rm{gap}} $参数均设为0. 结果表明,当 $ \lambda $= 256时可以获得最优的分类结果,分类准确率约为94.20%. 如图5(b)所示为不同 $ {\rm{gap}} $取值下的分类准确率,在实验过程中 $ \lambda $设置为256. 实验结果表明,当 $ {\rm{gap}} $= 0.18时可得最优的分类结果,分类准确率为94.25%. 最终将 $ \lambda $$ {\rm{gap}} $参数设定为256与0.18,开展后续实验.

图 5

图 5   参数取值结果的对比图

Fig.5   Comparison diagram of each parameter value


3.5. 仿真结果分析

为了验证Unify Loss的有效性,采用5个公开数据集进行算法的对比分析. 如表2所示为在CIFAR-10、CIFAR-100、SVHN、STL-10标准图像数据集上的仿真结果,训练网络均采用CNN13. 如表3所示为Pneumonia Chest X-ray真实医疗数据集上的仿真结果,训练网络采用ResNet18. 各算法在不同数据集上的测试结果如下.

表 2   各方法在CIFAR-10、CIFAR-100、SVHN、STL-10数据集中的准确率

Tab.2  Accuracy of each method on CIFAR-10, CIFAR-100, SVHN, STL-10 datasets

方法 CIFAR-10 CIFAR-100 SVHN STL-10
Nl = 40 Nl = 250 Nl = 4000 Nl = 400 Nl = 1000 Nl = 10000 Nl = 40 Nl = 250 Nl = 4000 Nl = 1000 Nl = 5000
VAT 27.41 53.97 82.59 5.12 12.89 53.20 35.64 79.78 90.06 68.77 81.26
MeanTeacher 26.54 55.52 83.79 4.94 12.13 55.13 36.45 78.43 93.58 69.01 82.80
MixMatch 41.47 74.08 90.74 9.15 32.39 65.87 47.45 76.02 93.50 79.59 88.41
Remixmatch 42.32 78.71 92.20 12.98 44.56 72.33 52.55 81.92 94.67 84.33 92.89
UDA 44.55 81.66 92.45 10.66 40.72 71.18 47.37 84.31 95.54 82.34 92.74
UDA_unify 49.69 84.22 93.87 18.06 46.12 72.67 53.14 87.32 95.77 87.21 94.34
Fixmatch 50.17 86.28 93.11 12.65 42.95 72.08 56.41 88.75 95.94 89.68 94.58
Fixmatch_unify 55.24 89.94 94.74 18.51 47.79 73.38 65.86 90.09 96.02 93.44 95.25
Flexmatch 52.70 85.08 93.97 27.22 55.12 77.84 62.03 90.42 95.83 92.36 96.06
Unify Loss 61.19 90.33 94.25 35.35 61.75 78.11 67.35 92.47 96.17 93.42 96.51

新窗口打开| 下载CSV


表 3   各方法在 Pneumonia Chest X-ray 数据集中的验证结果

Tab.3  Validation results of each method on Pneumonia Chest X-ray dataset

方法 A/%
Nl/N = 5% Nl/N = 10% Nl/N = 20%
VAT 8.62 25.71 67.12
MeanTeacher 10.72 27.92 69.46
MixMatch 12.24 31.62 73.86
Remixmatch 15.07 33.78 74.02
UDA 11.87 31.89 75.31
UDA_unify 17.66 36.46 76.55
Fixmatch 19.12 37.22 77.54
Fixmatch_unify 23.55 40.91 77.67
Flexmatch 23.38 40.67 78.76
Unify Loss 26.57 42.26 79.24

新窗口打开| 下载CSV


在对CIFAR-10数据集的测试中,当标签数为40时,VAT与MeanTeacher算法的分类准确率分别为27.41%、26.54%,MixMatch、Remixmatch与UDA算法的分类准确率较前2个算法提升了约20%. 5类算法的数据增强方式不同,导致模型的分类准确率差异较大. VAT算法利用标签样本生成对抗样本,在标签数极少的情况下,对抗样本可生成的数量有限,模型分类性能下降. MeanTeacher方法采用一致性约束扩充训练数据,但在标签数严重不足的情况下,为了保证扰动前、后样本的输出一致性,数据扰动范围较小,产生的训练样本有限. MixMatch算法、Remixmatch算法与UDA算法分别采用Weak Augment、CTAugment与RandAugment的数据扩充方式,与传统一致性约束下扩充数据的方法相比,提升了标签数据与无标签数据的有效利用率,获得较好的分类结果. Fixmatch算法综合了UDA与Remixmatch算法中的Strong Augmentation数据增强策略,引入带阈值的交叉熵损失,实现更优的分类效果. 该算法在标签样本数为40时得到了50.17%的分类准确率. Flexmatch算法在Fixmatch的基础上提出CLP方法,给无标签样本设定可动态调整的置信度阈值,以充分利用低置信度样本,方法在标签样本数为40时得到了52.70%的分类准确率. Unify Loss算法保留了Flexmatch的CLP方法,引入本文的损失框架,以取代原算法中的损失函数,方法在标签数为40时得到61.19%的最优分类准确率,与Flexmatch算法相比,分类准确率提高了8.49%. 为了体现本文损失框架的有效性,实验部分在UDA与Fixmatch算法中引入本文损失框架. 改进后的算法与原算法相比,分类效果均有所提升,特别是在标签样本严重匮乏的情况下,分类准确率可以提高5%以上. 在CIFAR-100数据集测试中,图像数量与类别较CIFAR-10大幅增加. 当标签数不足时,VAT模型无法利用足量的标签信息生成对抗样本,导致分类错误率增大. 在MeanTeacher模型中,参数更新过程依赖于标签数量,当标签数量不足以满足模型参数的更新需求时,易导致分类效果降低. 相较于前2种模型对样本标签的依赖性,MixMatch、Remixmatch与UDA算法通过数据增强,提升了在无标签样本上的分类准确性. Fixmatch与Flexmatch方法整合MixMatch、Remixmatch与UDA算法的优势,简化了无标签数据标记方式,以提高模型性能. 当标签样本数为400时,UDA、Fixmatch、Flexmatch算法在引入本文损失后的分类准确率分别提高了7.50%、5.86%与8.13%,当标签样本数量为1 000与10 000时,各改进算法的分类准确率均有所提高.

在SVHN数据集的测试中,当标签数为40时,由于极度缺乏标签信息,导致7种算法对简单数字识别任务的效果不理想. 当标签数为250时,算法的分类检测性能具有明显的提升,均超过75%. 本文的Unify Loss方法分类准确率为92.47%,较VAT算法提升了12.69%,较原Flexmatch方法提升了2.05%. 当标签数由250增加到4 000时,不同算法的提升效果不明显. 这是因为SVHN为门牌号数字,数据变化区域为统一可识别的标准范围. 数字图像的变化程度小,减少了算法对标签的依赖性. 与CIFAT-100相比,分类难度降低,导致标签数超过250以后,不同算法的差异性较小. 相比于VAT和MeanTeacher算法,其余算法在标签数较低时能够有接近50%的分类准确率. 这是因为MixMatch、Remixmatch、UDA等算法运用了特定的数据增强策略,通过在训练过程中合理划分标签和非标签数据,提高样本信息的利用率. 本文的统一损失框架在标签数极度缺乏的情况下,可以有效地提高算法分类性能. 当标签数为40时,与原算法相比,改进后的UDA_unify、Fixmatch_unify、Unify Loss的图像分类准确率分别提高了5.77%、9.45%与5.32%,提升效果明显.

STL-10数据集与SVHN数据集同为10分类数据集,与SVHN数据集相比,STL-10具有更丰富的图像特征,因此模型训练过程中对标签的依赖程度提升. 对比2类数据集的实验结果可知,STL-10数据集需要1 000张带标签数据,才能达到与SVHN标签数为250时相近的分类准确率. 在STL-10数据集测试中,当标签样本数目为1 000时,UDA_unify、Fixmatch_unify、Unify Loss算法相对于原算法的分类准确率分别提高了1.06%、3.76%、4.87%. 实验结果表明,提出的损失应用于特征丰富的图像时具有一定的分类优势.

Pneumonia Chest X-ray为胸腔X射线透视图数据集. 数据集中包含正常与肺炎2类样本,相较于CIFAR-10、CIFAR-100、SVHN与STL-10数据集,Pneumonia Chest X-ray图像的分辨率增大,图像特征丰富,因此分类难度提升. 分析Pneumonia Chest X-ray训练集测试结果可知,当标签样本占比为5%、10%与20%时,本文Unify Loss方法的分类准确率为26.57%、42.26%与79.24%,较原Flexmatch算法的分类准确率分别提高了3.19%、1.59%与0.48%. 在标签样本极少的情况下,本文方法的分类优势明显. 本文方法可以在特定标签难以获取的任务中(如医学图像检测)得到较好的应用,方法具有实用性.

综上所述,Unify Loss方法在5个公开数据集上均获得了最优的分类结果. 仿真结果表明,统一损失后的Unify Loss方法能够有效整合3类损失,避免任务被单一损失主导,提升了模型整体的分类准确率.

4. 结 语

本文提出基于距离度量损失框架的半监督学习方法. 该方法以分类交叉熵损失、熵最小化损失和一致性损失为基础,构建统一损失框架函数(unify loss函数),将其作为半监督任务中的核心损失函数,能够避免多种损失函数的超参数调节. 对比Unify Loss方法与典型半监督学习方法的算法性能,实验结果表明,在不同的图像数据集中,Unify Loss方法只需要少量标签数据就能获得较好的分类结果,具有更好的泛化性和实用性.

下一步将针对实际工程问题展开研究,尤其是对于乳腺癌检测中标签数据缺乏的实际问题,尝试采用本文提出的方法来提高识别准确性.

参考文献

KORNBLITH S, SHLENS J, LE Q V. Do better imagenet models transfer better? [C]// Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. Long Beach: IEEE, 2019: 2661-2671.

[本文引用: 1]

YANG S, LUO P, LOY C C, et al. WIDER FACE: a face detection benchmark [C]// Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. Las Vegas: IEEE, 2016: 5525-5533.

许佳辉, 王敬昌, 陈岭, 等

基于图神经网络的地表水水质预测模型

[J]. 浙江大学学报:工学版, 2021, 55 (4): 601- 607

[本文引用: 1]

XU Jia-hui, WANG Jing-chang, CHEN Ling, et al

Surface water quality prediction model based on graph neural network

[J]. Journal of Zhejiang University: Engineering Science, 2021, 55 (4): 601- 607

[本文引用: 1]

LEE D H. Pseudo-label: the simple and efficient semi-supervised learning method for deep neural networks [C]// ICML 2013 Workshop on Challenges in Representation Learning. Atlanta: PMLR, 2013: 896.

[本文引用: 2]

TARVAINEN A, VALPOLA H

Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results

[J]. Advances in Neural Information Processing Systems, 2017, 30: 1195- 1204

[本文引用: 4]

XIE Q, DAI Z, HOVV E, et al

Unsupervised data augmentation for consistency training

[J]. Advances in Neural Information Processing Systems, 2020, 33 (2): 6256- 6268

[本文引用: 1]

MIYATO T, MAEDA S, KOYAMA M, et al

Virtual adversarial training: a regularization method for supervised and semi-supervised learning

[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018, 41 (8): 1979- 1993

[本文引用: 3]

WANG F, CHENG J, LIU W, et al

Additive margin softmax for face verification

[J]. IEEE Signal Processing Letters, 2018, 25 (7): 926- 930

DOI:10.1109/LSP.2018.2822810      [本文引用: 2]

LAINE S, AILA T. Temporal ensembling for semi-supervised learning [C]// International Conference on Learning Representations. Toulon: [s. n.], 2017: 1-13.

[本文引用: 4]

SAJJADI M, JAVANMARDI M, TASDIZEN T

Regularization with stochastic transformations and perturbations for deep semi-supervised learning

[J]. Advances in Neural Information Processing Systems, 2016, 29 (7): 1163- 1171

[本文引用: 1]

LIU W, WEN Y, YU Z, et al. Large-margin softmax loss for convolutional neural networks [C]// Proceedings of the 33rd International Conference on Machine Learning. New York: PMLR, 2016: 507-516.

[本文引用: 1]

LI Y, GAO F, OU Z, et al. Angular softmax loss for end-to-end speaker verification [C]// 2018 11th International Symposium on Chinese Spoken Language Processing. Taipei: IEEE, 2018: 190-194.

[本文引用: 1]

GRANDVALET Y, BENGIO Y

Semi-supervised learning by entropy minimization

[J]. Advances in Neural Information Processing Systems, 2004, 17: 529- 536

[本文引用: 1]

VERMA V, KAWAGUCHI K, LAMB A, et al

Interpolation consistency training for semi-supervised learning

[J]. Neural Networks, 2022, 145: 90- 106

DOI:10.1016/j.neunet.2021.10.008      [本文引用: 1]

HENDRYCKS D, MU N, CUBUK E D, et al. Augmix: a simple method to improve robustness and uncertainty under data shift [C]// International Conference on Learning Representations. Ethiopia: [s. n.], 2020: 1-15.

[本文引用: 1]

KENDALL A, GAL Y, CIPOLLA R. Multi-task learning using uncertainty to weigh losses for scene geometry and semantics [C]// Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. Salt Lake City: IEEE, 2018: 7482-7491.

[本文引用: 2]

AZIERE N, TODOROVIC S. Ensemble deep manifold similarity learning using hard proxies [C]// Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. Long Beach: IEEE, 2019: 7299-7307.

[本文引用: 1]

QIAN Q, SHANG L, SUN B, et al. Softtriple loss: deep metric learning without triplet sampling [C]// Proceedings of the IEEE/CVF International Conference on Computer Vision. Seoul: IEEE, 2019: 6450-6458.

KIM S, KIM D, CHO M, et al. Proxy anchor loss for deep metric learning [C]// Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. Seattle: IEEE, 2020: 3238-3247.

SUN Y, CHENG C, ZHANG Y, et al. Circle loss: a unified perspective of pair similarity optimization [C]// Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. Seattle: IEEE, 2020: 6398-6407.

[本文引用: 1]

KERMANY D, ZHANG K, GOLDBAUM M

Labeled optical coherence tomography (oct) and chest X-ray images for classification

[J]. Mendeley Data, 2018, 2 (2): 255- 265

[本文引用: 1]

BERTHELOT D, CARLINI N, CUBUK E D, et al. Remixmatch: semi-supervised learning with distribution matching and augmentation anchoring [C]// International Conference on Learning Representations. Addis Ababa: [s. n.], 2020: 1-13.

[本文引用: 1]

BERTHELOT D, CARLINI N, GOODFELLOW I, et al

Mixmatch: a holistic approach to semi-supervised learning

[J]. Advances in Neural Information Processing Systems, 2019, 32: 155- 166

SOHN K, BERTHELOT D, CARLINI N, et al

Fixmatch: simplifying semi-supervised learning with consistency and confidence

[J]. Advances in Neural Information Processing Systems, 2020, 33: 596- 608

ZHANG B, WANG Y, HOU W, et al. FlexMatch: boosting semi-supervised learning with curriculum pseudo labeling [C]// Neural Information Processing Systems. [S. l.]: NIPS, 2021: 1-12.

[本文引用: 1]

/