OHEM 详细解析

date
May 6, 2022
Last edited time
Mar 27, 2023 08:59 AM
status
Published
slug
OHEM详细解析
tags
DL
CV
summary
type
Post
Field
Plat
标题:Training Region-based Object Detectors with Online Hard Example Mining

前言

首先,这篇文章干了一件什么事儿?
提出了一种困难负样本挖掘的方法。困难负样本是指 AI 模型难以区分的负样本。在模型不断训练的过程中,模型通常会对正样本有着比较高的 confidence,但少不了对某些负样本也留有余芥,给了一个不那么接近 0 的 confidence。而困难负例挖掘就是找到这些负例,然后针对性地训练。OHEM 提出是一种线上的困难负例挖掘解决方案。使用了这个 trick 以后,检测模型的准确性有一定提升。

OHEM

OHEM 的基准算法是 Fast R-CNN,目的就是对其进行一丢丢改进就大大提高其性能。其原理可以浓缩到一张图:
notion image
图 1 OHEM 示意图
仔细看图 1 中的 (b) 模块,这是比 Fast R-CNN 多出的一部分。挪去这一部分,就是一个初始的 Fast R-CNN 模型了。对于像 Fast R-CNN 这种 2-stage 的检测模型,都可以抽象成 “推荐” 和“分类回归”两个部分。而 OHEM 的作用恰好是两个部分的中间。
那我就对 2-stage 模型统一讲:在推荐网络 (下文称 “A 部分”,Fast RCNN 用的是 selective search,Faster RCNN 用的是 RPN) 之后,会有很多 ROI 传输到后面网络。这部分的输出便是推荐网络所推荐的 “可能存在目标的位置 “,以 image patch 的形式给到后面的分类和回归网络 (下文称 “B 部分”)。
通常 “A 部分” 会给数以千计的小片 (你可以叫 ROI,也可以叫推荐框,还可以叫 image patch) 传输到 “B 部分”,B 部分会对这些小片进行分类、坐标和尺寸回归以及置信度打分 (confidence, 置信度是指,模型对这个输出有多大把握)。如果在训练阶段,那可以通过这些参数与标注结果进行计算得出损失值
而 OHEM,恰恰利用了这个损失值。
有了上面的基础,下面详细介绍 OHEM 的操作流程:
  1. 正常进行一次 Faster RCNN 的前向传播,获得每个小片单独的损失值;
  1. 对小片们进行非极大值抑制 (NMS),不了解 NMS 点链接去了解,非常简单;
  1. 对 nms 之后剩下的小片按损失值进行排序,然后选用损失值最大的前一部分小片当作输入再进一遍 B 部分,在原文描述如下:
Hard examples are selected by sorting the input RoIs by loss and taking the B/N examples for which the current network performs worst.
到第 3 步可以发现,通过这种方法,可以屏蔽掉 loss 值非常低的小片。loss 值非常高的小片意味着,模型训练很多次还对这些小片有着很高的 loss,那么就认为这是困难负例。所谓的线上挖掘,就是先计算 loss-> 筛选 -> 得到困难负例。
  1. 把困难负例输入到图 1 中 (b) 模块,(b)模块是 (a) 模块的 copy 版,连参数都是一样的。只是 (a) 模块是不可训练的,用于寻找困难负例嘛。(b)模块是用来反向传播的部分,然后把更新的参数共享到 (a) 部分 (a 部分也跟着一起更新)。其实在程序的实现上,(a) 模块可以复用,不需要额外来一个 (b) 模块。作者之所以这么设计的原因有俩:1,ab 部分可以并行操作:a 对下一张图像进行前向,b 对上一张图像进行反向。2,这样看着更高级点,更像顶会能中的文章的样子。
总结
OHEM 可以帮助 2-stage 检测算法提升训练效果,通过对 ROI loss 值进行排序从而筛选出 loss 值非常大的 ROI,这便是所谓的 “困难负例”。
 

© Lazurite 2021 - 2024