Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 10:03:46 +08:00)
Fix seeding of new generator for multi GPU (#3459)
* fix new generator seeding
* remaining arbitrary fixed seed
* test
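The bug this fixes: when `prepare_data_loader` rebuilt a dataloader's generator for multi-GPU runs, it seeded it with a hardcoded `42`, so the shuffle order ignored the user's seed and was identical on every run. The fix draws the seed from torch's global RNG, which `accelerate.utils.set_seed` controls. A minimal standalone sketch of the pattern (plain PyTorch, names illustrative):

import torch

# Stand-in for accelerate.utils.set_seed(21): fix the global RNG state.
torch.manual_seed(21)

generator = torch.Generator(
    device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
)
# random_() on an empty int64 tensor consumes global RNG state, so the
# derived seed is reproducible whenever the global seed was set beforehand.
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator.manual_seed(seed)

# The shuffle order now follows the global seed instead of being pinned to 42.
print(torch.randperm(6, generator=generator))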
src/accelerate/data_loader.py
@@ -1181,7 +1181,9 @@ def prepare_data_loader(
         # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
         generator = torch.Generator(
             device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
-        ).manual_seed(42)
+        )
+        seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        generator.manual_seed(seed)
         dataloader.generator = generator
         dataloader.sampler.generator = generator
     # No change if no multiprocess
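In this first hunk the freshly seeded generator is attached to both the dataloader and its sampler, so a single seed determines the whole shuffle. The same wiring in plain PyTorch, as a hedged sketch (no Accelerate involved):

import torch
from torch.utils.data import DataLoader, RandomSampler

dataset = list(range(6))
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))

# One generator drives the sampler (shuffle order) and the dataloader
# (base seeds for workers), mirroring the two assignments in the hunk above.
sampler = RandomSampler(dataset, generator=generator)
loader = DataLoader(dataset, batch_size=1, sampler=sampler, generator=generator)
print([batch.item() for batch in loader])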
@@ -1203,6 +1205,8 @@ def prepare_data_loader(
                 sampler.generator = torch.Generator(
                     device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
                 )
+                seed = int(torch.empty((), dtype=torch.int64).random_().item())
+                sampler.generator.manual_seed(seed)
             synchronized_generator = sampler.generator
         batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
         new_batch_sampler = BatchSamplerShard(
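The second hunk gives the same treatment to the synchronized generator: it was created but never seeded, leaving the shuffle order pinned to the generator's default state rather than the user's seed (the "remaining arbitrary fixed seed" from the commit message). Seeding it from the global RNG keeps it both controllable and identical across ranks, which is what sharding relies on: every process must compute the same permutation before taking its share. A hypothetical, simplified sketch of that invariant (Accelerate's `BatchSamplerShard` shards batches, not raw indices as here):

import torch

def shard(rank, num_processes, seed, n=8):
    # All ranks seed identically, so randperm yields the same permutation
    # everywhere; each rank then keeps only its own slice.
    g = torch.Generator().manual_seed(seed)
    order = torch.randperm(n, generator=g).tolist()
    return order[rank::num_processes]

seed = 1234  # in Accelerate this comes from the global RNG, as above
print(shard(0, 2, seed))  # disjoint from rank 1's slice...
print(shard(1, 2, seed))  # ...together they cover all n indices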
tests/test_data_loader.py
@@ -34,7 +34,7 @@ from accelerate.data_loader import (
 )
 from accelerate.state import GradientState
 from accelerate.test_utils.testing import AccelerateTestCase, require_torchdata_stateful_dataloader
-from accelerate.utils import is_torchdata_stateful_dataloader_available
+from accelerate.utils import is_torchdata_stateful_dataloader_available, set_seed


 if is_torchdata_stateful_dataloader_available():
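On the test side, the only import change is pulling in `set_seed`, which pins all the relevant RNGs in one call:

from accelerate.utils import set_seed

set_seed(21)  # seeds Python's random, NumPy, and torch (plus CUDA when available)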
@@ -422,6 +422,36 @@ class DataLoaderTester(AccelerateTestCase):
         for d in dataloader:
             assert isinstance(d, torch.Tensor)

+    @parameterized.expand([1, 2], name_func=parameterized_custom_name_func)
+    def test_reproducibility(self, num_processes):
+        set_seed(21)
+        dataset = list(range(6))
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)
+        vals_1 = []
+        for val in dataloader:
+            vals_1.append(val)
+
+        # check same order for same seed
+        set_seed(21)
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)
+        vals_2 = []
+        for val in dataloader:
+            vals_2.append(val)
+
+        assert vals_1 == vals_2
+
+        # check different order for different seed
+        set_seed(42)
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+        dataloader = prepare_data_loader(dataloader, num_processes=num_processes)
+        vals_3 = []
+        for val in dataloader:
+            vals_3.append(val)
+
+        assert vals_1 != vals_3
+
     def test_skip_batch_sampler(self):
         batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
         new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
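The new test shuffles a tiny dataset three times: twice under seed 21 (the orders must match) and once under seed 42 (the order must differ), for both one and two processes via `parameterized.expand`. To run it locally, assuming the file is tests/test_data_loader.py as the imports above suggest:

pytest tests/test_data_loader.py -k test_reproducibility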