pytorch 混合精度训练
date
May 27, 2022
Last edited time
Mar 27, 2023 08:52 AM
status
Published
slug
pytorch混合精度训练
tags
DL
summary
type
Post
Field
Plat
简介
自动混合精度训练(auto Mixed Precision,amp)是深度学习比较流行的一个训练技巧,它可以大幅度降低训练的成本并提高训练的速度,因此在竞赛中受到了较多的关注。此前,比较流行的混合精度训练工具是由 NVIDIA 开发的 A PyTorch Extension(Apex),它能够以非常简单的 API 支持自动混合精度训练,不过,PyTorch 从 1.6 版本开始已经内置了 amp 模块,本文简单介绍其使用。
自动混合精度(AMP)
首先来聊聊自动混合精度的由来。下图是常见的浮点数表示形式,它表示单精度浮点数,在编程语言中的体现是 float 型,显然从图中不难看出它需要 4 个 byte 也就是 32bit 来进行存储。深度学习的模型数据均采用 float32 进行表示,这就带来了两个问题:模型 size 大,对显存要求高;32 位计算慢,导致模型训练和推理速度慢。
那么半精度是什么呢,顾名思义,它只用 16 位即 2byte 来进行表示,较小的存储占用以及较快的运算速度可以缓解上面 32 位浮点数的两个主要问题,因此半精度会带来下面的一些优势:
- 显存占用更少,模型只有 32 位的一半存储占用,这也可以使用更大的 batch size 以适应一些对大批尺寸有需求的结构,如 Batch Normalization;
- 计算速度快,float16 的计算吞吐量可以达到 float32 的 2-8 倍左右,且随着 NVIDIA 张量核心的普及,使用半精度计算已经比较成熟,它会是未来深度学习计算的一个重要趋势。
那么,半精度有没有什么问题呢?其实也是有着很致命的问题的,主要是移除错误和舍入误差两个方面,具体可以参考这篇文章,作者解析的很好,我这里就简单复述一下。
溢出错误
FP16 的数值表示范围比 FP32 的表示范围小很多,因此在计算过程中很容易出现上溢出(overflow)和下溢出(underflow)问题,溢出后会出现梯度 nan 问题,导致模型无法正确更新,严重影响网络的收敛。而且,深度模型训练,由于激活函数的梯度往往比权重的梯度要小,更容易出现的是下溢出问题。
舍入误差
舍入误差(Rounding Error)指的是当梯度过小,小于当前区间内的最小间隔时,该次梯度更新可能会失败。上面说的知乎文章的作者用来一张很形象的图进行解释,具体如下,意思是说在 2 − 3 2^{-3} 2−3 到 2 − 2 2^{-2} 2−2 之间, 2 − 3 2^{-3} 2−3 每次变大都会至少加上 2 − 13 2^{-13} 2−13,显然,梯度还在这个间隔内,因此更新是失败的。
那么这两个问题是如何解决的呢,思路来自于 NVIDIA 和百度合作的论文,我这里简述一下方法:混合精度训练和损失缩放。前者的思路是在内存中使用 FP16 做储存和乘法运算以加速计算,用 FP32 做累加运算以避免舍入误差,这样就缓解了舍入误差的问题;后者则是针对梯度值太小从而下溢出的问题,它的思想是:反向传播前,将损失变化手动增大 2 k 2^k 2k 倍,因此反向传播时得到的中间变量(激活函数梯度)则不会溢出;反向传播后,将权重梯度缩小 2 k 2^k 2k 倍,恢复正常值。
研究人员通过引入 FP32 进行混合精度训练以及通过损失缩放来解决 FP16 的不足,从而实现了一套混合精度训练的范式,NVIDIA 以此为基础设计了 Apex 包,不过 Apex 的使用本文就不涉及了,下一节主要关注如何使用 torch.cuda.amp 实现自动混合精度训练,不过这里还需要补充的一点就是目前混合精度训练支持的 N 卡只有包含 Tensor Core 的卡,如 2080Ti、Titan、Tesla 等。
PyTorch 自动混合精度
PyTorch 对混合精度的支持始于 1.6 版本,位于
torch.cuda.amp
模块下,主要是torch.cuda.amp.autocast
和torch.cuda.amp.GradScale
两个模块,autocast 针对选定的代码块自动选取适合的计算精度,以便在保持模型准确率的情况下最大化改善训练效率;GradScaler 通过梯度缩放,以最大程度避免使用 FP16 进行运算时的梯度下溢。官方给的使用这两个模块进行自动精度训练的示例代码链接给出,我对其示例解析如下,这就是一般的训练框架。下面我以简单的 MNIST 任务做测试,使用的显卡为 RTX 3090,代码如下。该代码段中只包含核心的训练模块,模型的定义和数据集的加载熟悉 PyTorch 的应该不难自行补充。
我这里采用的是一个很小的模型,又是一个很简单的任务,因此模型都是很快收敛,因此精度上没有什么明显的区别,不过如果是训练大型模型的话,有人已经用实验证明,内置 amp 和 apex 库都会有精度下降,不过 amp 效果更好一些,下降较少。上面的 loss 变化图也是非常类似的。
再来看存储方面,显存缩减在这个任务中的表现不是特别明显,因为这个任务的参数量不多,前后向过程中的 FP16 存储节省不明显,而因为引入了一些拷贝之类的,反而使得显存略有上升,实际的任务中,这种开销肯定远小于 FP32 的开销的。
最后,不妨看一下使用混合精度最关心的速度问题,实际上混合精度确实会带来一些速度上的优势,一些官方的大模型如 BERT 等训练速度提高了 2-3 倍,这对于工业界的需求来说,启发还是比较多的。
总结
混合精度计算是未来深度学习发展的重要方向,很受工业界的关注,PyTorch 从 1.6 版本开始默认支持 amp,虽然现在还不是特别完善,但以后一定会越来越好,因此熟悉自动混合精度的用法还是有必要的。