Uncertainty Modeling for Out-of-Distribution Generalization 论文阅读
date
Oct 19, 2022
Last edited time
Mar 27, 2023 08:45 AM
status
Published
slug
Uncertainty_Modeling_for_Out-of-Distribution_Generalization论文阅读
tags
DL
summary
转载
type
Post
Field
Plat
IntroductionModeling Domain Shifts with UncertaintyUncertainty EstimationDistribution of Feature Statistics and ImplementationExperimentsMulti-domain classificationSemantic segmentationInstance retrievalRobustness towards corruptions代码实现
Introduction
深度神经网络在CV领域取得了很大的成功,但是这依赖于一个假设,那就是训练和测试数据是独立同分布的。但在很多现实场景的应用中,这种假设是不成立的。现实中的测试数据和训练数据的不同分布会给网络性能带来不可逆转的下降。
针对这个问题,域泛化成为了一个热门的研究方向,它研究的问题是从若干个具有不同数据分布的数据集中学习一个泛化能力强的模型,以便在未知测试集上取得较好的效果。域泛化与域适应最大的不同是:域适应在训练过程中,源域与目标域数据均能访问;而域泛化问题中,我们只能访问若干个用于训练的源域数据,测试数据是不能访问的。毫无疑问,域泛化比域适应更有挑战性和适用性,毕竟我们都喜欢“一次训练,到处应用”的足够泛化的模型。
有之前的工作(Arbitrary style transfer in real-time with adaptive instance normalization 和 On feature normalization and data augmentation)指出,网络学习到的特征统计量(feature statistics)包含了域的特征(domain characteristic)。这里的统计量指的是特征的均值和标准差,域的特征指与任务目标无关、与域更相关的信息,比如说画作的风格。所以,来自不同域的数据,提取到的特征的统计量应该是不同的。直观上来说,未知的目标域或者说测试域会使特征统计量往不同方向偏移,而这种域偏移是不确定的。所以这篇文章的核心idea就是:将特征统计量建模成一个不确定的分布,在分布中特征统计量的不同采样可以提升模型在不同目标域的泛化性。
Modeling Domain Shifts with Uncertainty
这篇文章提出的方法叫做modeling Domain Shifts with Uncertainty (DSU),考虑目标域的不确定性,假设特征统计量都服从多元高斯分布,
高斯分布的中心是原始的特征统计量,高斯分布的标准差代表域偏移的不确定性范围。所以通过对特征统计量分布的随机采样,模型可以适应不同的域偏移,得到泛化性的提升。
Uncertainty Estimation
是网络中间层的一个特征图,我们用 和 代表在一个minibatch内,通道间特征的均值和标准差:
对于特征的统计量,我们可以估计它们的不确定性即方差:
这里的 ,代表不确定性的估计,它们的数值大小反映了每个minibatch特征通道变化的大小和潜在的域偏移。
Distribution of Feature Statistics and Implementation
当我们估计了每个通道的不确定性之后,我们可以用VAE中的重参数化技巧采样出统计量:
这种随机采样得到的特征统计量可以模拟出各种各样的域偏移。我们对特征图进行如下变化,使特征图服从经过不确定性变换后的特征统计量:
相当于先把特征图变成一个标准高斯分布,然后再把分布的均值变为 方差变为 。这种变换可以插入到网络的任意位置,可以设置一个参数 来控制应用DSU的概率。此外,DSU只在训练时发挥作用,在测试时不起作用。
Experiments
为了方法的有效性,作者在图像分类、语义分割、实例检索和 robustness to corruptions 等任务上都做了实验。并且包含不同种类的域偏移,比如风格偏移、真实与合成图像、场景变换、像素corruption。
Multi-domain classification
这个实验使用的数据集是PACS数据集,包含了画作、卡通、照片和素描四种风格的图像。使用的backbone是ResNet18。实验遵循leave-one-domain-out protocal,也就是说在三种风格上训练,在剩下一种风格上测试。
Semantic segmentation
语义分割的场景是街景分割。模型是在GTA5游戏数据集上训练并在CityScapes真实数据集上测试。使用的backbone是DeepLab-v2。
Instance retrieval
这个实验的具体任务是行人再识别(将同一个行人在两个摄像头下关联起来)。实验在DukeMTMC和Market1501两个数据集上进行,backbone是ResNet50。
Robustness towards corruptions
实验在ImageNet-C上进行,包含15种不同的像素级corruption。APR是另一种在ImageNet-C上SOTA的方法。mCE代表mean corrpution error。