torch.utils.data.distributed.DistributedSampler 解析

date
May 26, 2022
Last edited time
Mar 27, 2023 08:52 AM
status
Published
slug
torch.utils.data.distributed.DistributedSampler解析
tags
DL
summary
type
Post
Field
Plat
无卡虚空看代码

分布式训练时,torch.utils.data.distributed.DistributedSampler 做了什么?

试验用到的 code

试验过程

执行 code: case 1, 不使用torch.utils.data.distributed.DistributedSampler, 结果显示,每块卡上(每个进程)每个 epoch 中都迭代了所有的数据。
执行 code: case 2, 使用torch.utils.data.distributed.DistributedSampler, 结果显示,数据被平分到两块卡上,每个 epoch 被分配到每块卡上的数据都一样。
为了解决 case 2 中每块卡上分配的数据相同的问题,执行 code: case 3, 在每个epoch中加入sampler.set_epoch(epoch)
执行 code: case 4, 数据集里有 6 例数据,在两张卡,batch_size=4, drop_last=False时,每张卡上平均分配了 3 例数据;当drop_last=True时,不足 4 例数据的会被丢掉,在数据集只有 6 例数据时,每张卡上分配的 3 例数据都会被丢掉;
执行 code: case 5, 数据集里有 5 例数据,两张卡,batch_size=4, drop_last=False时,每张卡上平均分配了 2.5 例数据, 会向上补齐到 6 例数据,每张卡上三张,补齐的标准是把数据集的第一例数据(在本例 1 中 index=4)用来补齐;如果将sampler.set_epoch(epoch)加入其中,补齐标准不变, 在本例 2 中,第一个 epoch 补齐的是index=4,第二个 epoch 补齐的是index=0
当多进程同时工作时,执行case 6时,有的迭代中,会出现batch_size=1的情况,如果模型中存在BatchNormalize这样的模块时,运行可能报错。
为了避免case 6这种情况,可以引入BatchSampler这样的模块,运行 case 7, 将drop_last=True

© Lazurite 2021 - 2023