Fix seeding of new generator for multi GPU (#3459)

* fix new generator seeding

* fix remaining arbitrary fixed seed

* test
Albert Thomas
2025-03-28 17:48:05 +01:00
committed by GitHub
parent 803b6648b4
commit 3f636d6260
2 changed files with 36 additions and 2 deletions


@@ -1181,7 +1181,9 @@ def prepare_data_loader(
         # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
         generator = torch.Generator(
             device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
-        ).manual_seed(42)
+        )
+        seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        generator.manual_seed(seed)
         dataloader.generator = generator
         dataloader.sampler.generator = generator
     # No change if no multiprocess
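
The removed line seeded this generator with the fixed value 42, so the shuffle order it produced was identical on every run and could not be influenced by the user's seed. The replacement draws the seed from the default (global) torch RNG instead, the same pattern PyTorch's own DataLoader uses for its base seed, so torch.manual_seed / accelerate.utils.set_seed upstream now controls the order, while runs that never set a seed still vary. A minimal standalone sketch of the idea (illustrative only, not the Accelerate code path):

import torch
from torch.utils.data import RandomSampler

# Seeding the global RNG first (what accelerate.utils.set_seed does) determines
# the drawn seed; omit it and each run draws a different seed, hence a different order.
torch.manual_seed(21)
seed = int(torch.empty((), dtype=torch.int64).random_().item())

generator = torch.Generator()
generator.manual_seed(seed)

sampler = RandomSampler(range(6), generator=generator)
print(list(sampler))  # a permutation of 0..5, identical whenever the upstream seed is 21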
@@ -1203,6 +1205,8 @@ def prepare_data_loader(
                     sampler.generator = torch.Generator(
                         device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
                     )
+                    seed = int(torch.empty((), dtype=torch.int64).random_().item())
+                    sampler.generator.manual_seed(seed)
                 synchronized_generator = sampler.generator
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
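
The second hunk applies the same treatment to the generator that becomes synchronized_generator, the one shared so that all processes sample the same permutation; previously it was created here without an explicit seed. Because set_seed seeds the global RNG identically on every rank, each rank now draws the same seed for it, so sampling stays synchronized across processes and still responds to the user's seed. A rough sketch of that invariant, simulating two ranks in one process (sampler_seed_for_rank is a hypothetical helper, not an Accelerate API):

import torch

def sampler_seed_for_rank(global_seed):
    # Hypothetical helper: each rank seeds its own global RNG with the same value
    # (what set_seed does), so the seed drawn for the shared generator agrees across ranks.
    torch.manual_seed(global_seed)
    return int(torch.empty((), dtype=torch.int64).random_().item())

assert sampler_seed_for_rank(21) == sampler_seed_for_rank(21)  # identical on every rank
assert sampler_seed_for_rank(21) != sampler_seed_for_rank(42)  # a different seed reshuffles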


@@ -34,7 +34,7 @@ from accelerate.data_loader import (
 )
 from accelerate.state import GradientState
 from accelerate.test_utils.testing import AccelerateTestCase, require_torchdata_stateful_dataloader
-from accelerate.utils import is_torchdata_stateful_dataloader_available
+from accelerate.utils import is_torchdata_stateful_dataloader_available, set_seed
 
 
 if is_torchdata_stateful_dataloader_available():
@@ -422,6 +422,36 @@ class DataLoaderTester(AccelerateTestCase):
         for d in dataloader:
             assert isinstance(d, torch.Tensor)
 
+    @parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
+    def test_reproducibility(self, num_processes):
+        set_seed(21)
+        dataset = list(range(6))
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)
+        vals_1 = []
+        for val in dataloader:
+            vals_1.append(val)
+
+        # check same order for same seed
+        set_seed(21)
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)
+        vals_2 = []
+        for val in dataloader:
+            vals_2.append(val)
+
+        assert vals_1 == vals_2
+
+        # check different order for different seed
+        set_seed(42)
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)
+        vals_3 = []
+        for val in dataloader:
+            vals_3.append(val)
+
+        assert vals_1 != vals_3
+
     def test_skip_batch_sampler(self):
         batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
         new_batch_sampler = SkipBatchSampler(batch_sampler, 2)