Compare commits

...

3 Commits

Author SHA1 Message Date
8d1479def0 Release: v0.24.1 2023-10-30 10:07:36 -04:00
62fcf16429 Fix batch sampler (#2097)
* Fix batch sampler

* Clean

* Fix tests

* Fix

* Better comment

* Base case
2023-10-30 10:07:23 -04:00
00301b27b7 Release: v0.24.0 2023-10-24 12:59:04 -04:00
4 changed files with 53 additions and 5 deletions

View File

@@ -34,7 +34,7 @@ extras["sagemaker"] = [
 setup(
     name="accelerate",
-    version="0.24.0.dev0",
+    version="0.24.1",
     description="Accelerate",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
-__version__ = "0.24.0.dev0"
+__version__ = "0.24.1"
 
 from .accelerator import Accelerator
 from .big_modeling import (

View File

@@ -833,9 +833,9 @@ def prepare_data_loader(
     synchronized_generator = None
     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
     if sampler_is_batch_sampler:
-        sampler = dataloader.sampler.sampler
+        sampler = getattr(dataloader.sampler, "sampler", None)
     else:
-        sampler = dataloader.batch_sampler.sampler
+        sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler) and num_processes > 1:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
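The change above guards against custom batch samplers that do not wrap an inner sampler. A minimal sketch of the failure mode it addresses (the NoInnerSampler class and the sizes are illustrative, not taken from the repository):

import numpy as np
from torch.utils.data import DataLoader

class NoInnerSampler:
    # A batch sampler with no `.sampler` attribute, similar in spirit to the
    # CustomBatchSampler added in the test script below.
    def __init__(self, dataset_length, batch_size):
        self.batch_size = batch_size
        self.data_index = np.arange(dataset_length)

    def __iter__(self):
        yield from np.array_split(self.data_index, len(self))

    def __len__(self):
        return int(np.ceil(len(self.data_index) / self.batch_size))

dl = DataLoader(range(32), batch_sampler=NoInnerSampler(32, batch_size=8))

# Previously the code read `dl.batch_sampler.sampler` directly, which raises
# AttributeError for this object; with the getattr fallback the lookup returns
# None and the RandomSampler-specific handling is simply skipped.
sampler = getattr(dl.batch_sampler, "sampler", None)
assert sampler is None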

View File

@@ -21,8 +21,9 @@ import time
 from copy import deepcopy
 from pathlib import Path
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from accelerate import Accelerator
 from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
@@ -288,6 +289,52 @@ def central_dl_preparation_check():
     print("Shuffled central dataloader passing.")
 
 
+def custom_sampler_check():
+    state = AcceleratorState()
+
+    class CustomDataset(Dataset):
+        def __init__(self, data):
+            self.data = data
+
+        def __len__(self):
+            return len(self.data)
+
+        def __getitem__(self, index):
+            return self.data[index]
+
+    class CustomBatchSampler:
+        def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
+            self.batch_size = batch_size
+            self.data_index = np.arange(dataset_length)
+            self.shuffle = shuffle
+
+        def __iter__(self):
+            num_batches = len(self)
+            if self.shuffle:
+                index = np.random.permutation(self.data_index)
+            else:
+                index = self.data_index
+            output = np.array_split(index, num_batches)
+            yield from output
+
+        def __len__(self):
+            return math.ceil(len(self.data_index) / self.batch_size)
+
+    dataset = CustomDataset(range(32 * state.num_processes))
+    sampler = CustomBatchSampler(len(dataset), batch_size=8)
+    dl = DataLoader(dataset, batch_sampler=sampler)
+    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)
+    # We just need to ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler`) is indeed the old batch sampler
+    if hasattr(dl.batch_sampler, "batch_sampler"):
+        assert isinstance(
+            dl.batch_sampler.batch_sampler, CustomBatchSampler
+        ), "Custom sampler was changed after calling `prepare_data_loader`"
+    else:
+        assert isinstance(
+            dl.batch_sampler, CustomBatchSampler
+        ), "Custom sampler was changed after calling `prepare_data_loader`"
+
+
 def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
@@ -608,6 +655,7 @@ def main():
     dl_preparation_check()
     if state.distributed_type != DistributedType.TPU:
         central_dl_preparation_check()
+    custom_sampler_check()
 
     # Trainings are not exactly the same in DeepSpeed and CPU mode
     if state.distributed_type == DistributedType.DEEPSPEED:
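
For quick local verification, the new check can also be exercised on its own on a single process. A minimal sketch; the import path is an assumption based on where accelerate ships its test scripts and may differ:

# Hedged sketch: run only the new check outside the full test suite.
from accelerate import Accelerator
from accelerate.test_utils.scripts.test_script import custom_sampler_check  # assumed module path

if __name__ == "__main__":
    Accelerator()            # initializes the AcceleratorState the check reads from
    custom_sampler_check()   # raises AssertionError if the custom batch sampler was replaced
    print("Custom sampler check passed.")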