Pix2seq: A Language Modeling Framework for Object Detection 论文阅读
date
Oct 28, 2022
Last edited time
Mar 27, 2023 08:44 AM
status
Published
slug
Pix2seq论文阅读
tags
DL
CV
summary
转
type
Post
Field
Plat
Motivation
作者提出 Pix2seq 的动机也是从减少归纳偏置的角度出发的,具体来说,现有的目标检测算法都太复杂了,有许多针对检测任务的人为设计,比如说
- Bounding boxes, region proposals, RoI Pooling
- Tailored loss functions like box regression
- Specific performance metrics like IoU over bounding boxes
这些特殊的设计导致了模型相对来说比较复杂,这就导致了
- 训练和维护成本比较高
- 泛化性不太好
- 例如增加不同尺寸、形状的目标需要设计新的 bounding box
- 很难将模型延展到其他任务上
所以作者受到 Vision Transformer 成功的启发,思考能不能去掉这些,作者用 image captioning 的方式来解决目标检测问题,也就是输入一张图片,模型就输出一个序列,包含图片中有哪些物体,和每个物体相应的坐标:
整体框架
如上图所示,Pix2seq 框架参照的是 image captioning 的方式:
- 对输入图片做常规的图像增强:e.g., with random scaling and crops
- 将目标构造成一个序列,包括每个目标的 bounding box 坐标以及目标类别,序列的最后是一个 EOS token:[y1_min, x1_min, y1_max, x1_max, CLS1, y2_min, x2_min, y2_max, x2_max, CLS2, …, EOS]
- 模型的 Architecture 是: Image Encoder + Language Decoder
- Objective / Loss Function: Softmax cross-entropy loss
构造目标序列
具体来说,每个目标都是可以用 5 个 token 来表示的序列:[y_min, x_min, y_max, x_max, CLS]。那 y_min, x_min, y_max, x_max 这四个点如何取值呢? 作者用的是 quantization 的方法,也就是将图片高和宽分成 bins 等份,根据对应的坐标取值。当然, bins 的取值决定了目标检测的分辨率,取值越大分的块就越多,相应的识别分辨率就越高,越能够检测到小物体。
如下图所示,对于 512 x 512 的图片,即使是每个像素一个 bin,也才 512 个 bin,NLP 模型的词汇量几乎都是上万的,所以理论上 quantization 这个方法在输出维度上是可行的。
由于对象的顺序对于检测任务本身并不重要,因此对于一张图片里有多个目标的情况,我们使用随机排序策略。
我们还探索了其他确定性排序策略,下图中展示了三种排序方法:
- 随机排序
- area: 根据目标大小排序
- dist2ori: 根据目标 bounding box 左上角坐标离原点的距离排序
下图是各种排序方式的 Precision 和 Recall 实验结果:
Architecture, Objective and Inference
论文中这部分只有简单介绍,几笔带过,因为基本就是采用了 image captioning 的一套流程,不需要过多讨论。
其中 Architecture 就是 Image Encoder (CNN, Vision Transformer, etc.) + Language Decoder (Transformer)。
损失函数采用的是语言模型常用最大似然损失函数,模型推理时根据上一个预测的 token,结合输入图片的特征向量,预测下一个 token,直到预测出 EOS token。当然,语言模型在推理时对于 token 的预测有很多采样方法,论文里提到了 Nucleus Sampling 这个方法最好,这里不展开讨论。
其中, 为目标 token, 是图像以及上一步为止预测出来的 token。
目标检测的 Recall 问题
作者在目标检测实验中还遇到了 Recall 比较低的情况,也就是模型没有检测出所有的目标就停止了(输出 [EOS] token)。作者提出可能有两方面的原因:
- 数据标注问题:标注者在标注数据的时候难以避免地漏掉了一些物体,因为是监督学习,模型可能会学习标注者的行为,也会经常漏掉一些容易漏掉的物体
- 存在较难预测的数据,例如一些小的物体或者不清晰的物体
最直观的解决方式当然是尽量延迟模型对于 [EOS] token 的预测,迫使模型多预测一些物体,即使是信心比较低的物体。然而这么做可能会导致模型容易输出重复的目标,Recall 问题解决了,却造成了 Precision 问题。
所以本文提出了 Sequence Augmentation 方法,在延迟模型对于 [EOS] token 的预测的同时,缓和 Precision 降低的问题。
具体来说,就是对训练图片增加噪声,有两种方式:
- 对原有的 bounding box 添加噪声(e.g., random scaling or shifting bounding boxes)来生成新的 bounding box
- 生成完全随机的 bounding boxes
如下图所示,红色的框是标注的 bounding box,白色的框都是生成的伪标签。
当然,这么干的话监督训练时就需要对伪标签进行特殊处理,首先,下图是没有伪标签的预测序列和目标序列,二者是完美对齐的:
做了 Sequence Augmentation 后,预测序列多出了黄色的部分,相应的目标序列也需要特殊处理。其中伪目标的坐标不参与损失计算(下图中 n/a 部分),伪目标的类别标记为新类别 [noise],参与损失计算。通过这种方式,一方面鼓励了模型尽量多输出目标,另一方面赋予了模型一定程度上学会将冗余输出目标标记为 [noise] 类别的能力,在提升 recall 的同时避免降低 precision。当然,在模型做推理的时候,只要对 [noise] 类别做一些特殊处理就行了。
实验结果
作者的实验是在 MS-COCO 2017 上做的,主要有两个部分。
一个是 train from scratch,另一个是在更大的数据集 Objects365 上做预训练,再到 COCO 上迁移学习。理论上来说 Pix2seq 没有针对检测任务的归纳偏置,迁移学习的应该会比过去的检测模型表现更好。
上图是 train from scratch 的结果,可以看出来 Pix2seq 与其他模型的表现是相当的。作者特别提出 Pix2seq 在检测小目标上更加出色,应该是 Sequence Augmentation 的功劳。
后话
先怀疑一下,Pix2seq 用的是生成式的方式来做目标检测,每个目标的坐标都需要一个一个预测出来,预测速度估计会是个问题。
另外,这篇论文给人的启发是深度学习越来走向了大一统的方向,Transfromer 既可以解决 NLP 问题,也可以解决 CV 问题,Pix2seq 表明语言模型也可以用来解决 CV 的目标检测问题,另外 ViG 这篇论文用图网络解决 CV 问题同样取得了 SOTA 的结果。既然网络结构如此通用,做多模态学习就会很方便。
从反面来看,监督学习领域也确实遇到了瓶颈,网络模型结构已经不是关键所在,正如吴恩达所说,在下一个突破到来之前,监督学习领域的关键是数据的质量。