泛览天下

阅读,看尽天下事

论文笔记|Selective Pseudo-Labeling with Reinforcement Learning for Semi-Supervised Domain Adaptation

2022-07-01 19:38:47


    软标签:将分类器预测的每个类的条件概率分配给目标为标记数据Reinforcement LearningMethod方法概述:首先用提出的目标边际损失来训练一个由特征提取器F和分类器C组成的CNN来处理K类分类问题,然后用基于训练好的


《基于强化学习的选择性伪标签的半监督域自适应方法》


Abstract

在半监督自适应(SSDA)中,目标域中只有很少的标记实例

→基于强化学习的选择性伪标签的半监督域自适应方法

传统的伪标签方法很难平衡伪标签数据的正确性和代表性

→deep Q-learning,选择准确和具有代表性的伪标签实例

大边际损失在少数据学习鉴别特征

→提出一种新的目标边际损失用于基础模型训练,以提高其可辨别性

用强化学习来学习适当的策略,以便选择更准确和有代表性的伪标签样本

伪标签+deep Q-learning+新的目标边际损失函数


Introduction

对于领域自适应任务,目的是提高目标域的泛化能力

为了在产生注释成本的情况下增加目标域中的标记实例数量,一个直观的策略是利用当前预测模型产生的目标域样本的伪标签,但是伪标签通常十分嘈杂且包含许多错误的标签,使用错误的标记样本进行训练可能会对原始模型产生负面的影响。


Related Work

Domain Adaptation

UDA:一个标签丰富的源域和一个未标记的目标域。常用方法是添加一个域鉴别器,对样本是从源域还是目标域进行分类,然后应用对抗性学习来最小化源域和目标域之间的特征分布之间的距离。

域对抗神经网络(DANN):提出了标准的域对抗架构,并引入了一个梯度反转层GRL来处理鉴别器产生的域混淆损失。

Pseudo-Labeling

伪标签:根据已标签的数据给出近似的标签

硬标签:为每个目标未标记实例分配一个分类器预测的伪标签,然后将伪标签的目标数据与原始标签向结合,训练一个改进模型。

软标签:将分类器预测的每个类的条件概率分配给目标为标记数据

Reinforcement Learning


Method

方法概述:首先用提出的目标边际损失来训练一个由特征提取器F和分类器C组成的CNN来处理K类分类问题,然后用基于训练好的CNN分类器为目标域中未标记的样本生成伪标签,最后交替使用deep Q-learning来训练智能体,并使用智能体选择伪标签进行CNN训练

输入数据

源域:sufficient labeled dataset D_{s}

目标域:limited labeled instances D_{t}

large set of unlabeled instances D_{u}

Target Margin Loss

最终的半监督损失,α是一个平衡目标边际损失和熵损失的超参数
目标边际损失

对于半监督自适应,域之间的特征分布存在差距,目标域中的标记样本数量比源域中少得多,因此在目标标记数据的损失上增加一个相对的角边缘,可以看作是使决策区域的分离与目标域的特征分布更加对齐。

角边缘损失:通过修改softmax来构建

熵损失

将目标未标记特征聚类到相应的决策区域

Selective Pseudo-Labeling by Reinforcement Learning

输入数据

候选集 D_{c} :由要选择的伪标签样本组成,初始化为从 D_{u} 中随机采样得来

正集 D_{p} :由选择出的伪标签的样本组成,初始化为空集

State

伪标签的代表性能力和准确性与三部分有关: D_{t} D_{p} 中有标签的数据以及 D_{u} 中未标记的数据

D_{c}

代表每个实例,其中F代表特征提取器提取出实例的D维特征向量,C代表分类器C的softmax输出

在实例从 D_{c} 移动到 D_{p} 后,用零值向量替换选中的实例

D_{t}\cup D_{p}

代表每个实例

计算了每个类中的平均向量,然后将它们连接到一个维数为 K×(d+K) 的向量上

D_{u} :由 D_{u} 与第二部分中具有相同操作的实例表示的向量

Action

 D_{c} 中选择一个实例,对每个状态 s_{i} ,agent采取一个动作 a_{i} 来选择 D_{c} 中第 a_{i} 个实例并将其移动到 D_{p}

Rewards

度量函数,衡量伪标签实例是否具有代表性能力和准确性,其中β和λ为超参数, p_{c} 代表分类器预测伪标签类 \hat{y} _{i} 的概率,表示预测的置信度


利用目标标记 D_{t} 的数据,代表 x_{i} D_{t} \cup D_{p} 中伪标签 \hat{y} _{i} 的特征中心之间的余弦距离的softmax输出


代表 D_{t} \cup D_{p} 中第j类特征中心

前两项通过分类器的输出和与目标域中伪类的特征中心的相似性方面反映了伪标签预测的置信度

由于分类器更依赖于源于数据,故添加第二项来具体考虑目标域的分布,这样度量函数就可以更好地评估伪标签样本的准确性

在此添加第三项,代表目标未标记数据熵的减少

H 代表当前状态的熵, H’ 代表下个状态的熵

先计算当前状态下的 H ,然后根据动作 a_{i} 添加一个伪标签样本进行训练,一个训练单元结束后计算 H’ 并计算出 Δe ,为了使熵不受熵损失函数的影响,此处只使用目标边际损失来优化模型。

所选样本越有代表性,熵减小就越大,即 Δe 就越大

最终奖励设定
τ 设定值

Deep Q-learing

Q 的目标值, γ 为一个决定未来累计奖励与当前奖励的重要性的折扣因子
deep Q-learning的迭代更新,Ω代表Q网络的参数

使用一个三层全连接网络作为深度Q网络,每个完全连接的层后面都有一个ReLU激活函数,当调用agent即深度Q网络来选择伪标签实例时,使用一下策略输出如下操作

s_{t} 代表当前状态
模型完整算法,交替训练分类网络和Q网络

Experiments

Datasets

DomainNet:大规模的域自适应基准数据集

Office-31:广泛应用的视觉领域适应数据集

Office-Home:办公室和家庭设置中的另一个域自适应数据集

Baselines

S+T、DANN、ADR、CDAN、ENT、MME、TML_SPL(基于置信度的选择性伪标签方法)

Result


将边际参数扩展到源域
验证目标域中标记样本数量从1到20的情况(验证随着目标域中标记样本数量增加算法的性能)
DomainNet数据集
Office-Home数据集
Office-31数据集