torch.utils.data.distributed.DistributedSampler 解析
date
May 26, 2022
Last edited time
Mar 27, 2023 08:52 AM
status
Published
slug
torch.utils.data.distributed.DistributedSampler解析
tags
DL
summary
type
Post
Field
Plat
无卡虚空看代码
分布式训练时,torch.utils.data.distributed.DistributedSampler 做了什么?
试验用到的 code
试验过程
执行 code: case 1, 不使用
torch.utils.data.distributed.DistributedSampler
, 结果显示,每块卡上(每个进程)每个 epoch 中都迭代了所有的数据。执行 code: case 2, 使用
torch.utils.data.distributed.DistributedSampler
, 结果显示,数据被平分到两块卡上,每个 epoch 被分配到每块卡上的数据都一样。为了解决 case 2 中每块卡上分配的数据相同的问题,执行 code: case 3, 在每个
epoch
中加入sampler.set_epoch(epoch)
执行 code: case 4, 数据集里有 6 例数据,在两张卡,
batch_size=4, drop_last=False
时,每张卡上平均分配了 3 例数据;当drop_last=True
时,不足 4 例数据的会被丢掉,在数据集只有 6 例数据时,每张卡上分配的 3 例数据都会被丢掉;执行 code: case 5, 数据集里有 5 例数据,两张卡,
batch_size=4, drop_last=False
时,每张卡上平均分配了 2.5 例数据, 会向上补齐到 6 例数据,每张卡上三张,补齐的标准是把数据集的第一例数据(在本例 1 中 index=4)用来补齐;如果将sampler.set_epoch(epoch)
加入其中,补齐标准不变,
在本例 2 中,第一个 epoch 补齐的是index=4
,第二个 epoch 补齐的是index=0
当多进程同时工作时,执行
case 6
时,有的迭代中,会出现batch_size=1
的情况,如果模型中存在BatchNormalize
这样的模块时,运行可能报错。为了避免
case 6
这种情况,可以引入BatchSampler
这样的模块,运行 case 7, 将drop_last=True