Mirror of https://github.com/huggingface/accelerate.git, synced 2025-11-13 20:34:58 +08:00
Compare commits (3 commits): make-versi... → v0.24.1
| Author | SHA1 | Date |
|---|---|---|
|  | 8d1479def0 |  |
|  | 62fcf16429 |  |
|  | 00301b27b7 |  |
setup.py (2 changed lines)
@@ -34,7 +34,7 @@ extras["sagemaker"] = [
 
 setup(
     name="accelerate",
-    version="0.24.0.dev0",
+    version="0.24.1",
     description="Accelerate",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
@@ -1,4 +1,4 @@
-__version__ = "0.24.0.dev0"
+__version__ = "0.24.1"
 
 from .accelerator import Accelerator
 from .big_modeling import (
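
As a quick sanity check (not part of the diff), the bump is visible at runtime once the tagged release is installed; the pip pin below is an assumed install path, not something the patch prescribes:

```python
# Assumes the patch release was installed, e.g. `pip install accelerate==0.24.1`.
import accelerate

# The runtime version now matches the bumped __version__ string.
assert accelerate.__version__ == "0.24.1"
```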
@@ -833,9 +833,9 @@ def prepare_data_loader(
     synchronized_generator = None
     sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
     if sampler_is_batch_sampler:
-        sampler = dataloader.sampler.sampler
+        sampler = getattr(dataloader.sampler, "sampler", None)
     else:
-        sampler = dataloader.batch_sampler.sampler
+        sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler) and num_processes > 1:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
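
As a rough sketch (not from the diff) of the failure mode the `getattr` fallback guards against: when a `DataLoader` is built with a user-defined `batch_sampler`, that object usually exposes no `.sampler` attribute, so the old `dataloader.batch_sampler.sampler` lookup raised `AttributeError`. The `IndicesBatchSampler` below is a made-up stand-in for such a sampler:

```python
from torch.utils.data import DataLoader


class IndicesBatchSampler:
    """Hypothetical user-defined batch sampler with no `.sampler` attribute."""

    def __init__(self, length, batch_size):
        self.length, self.batch_size = length, batch_size

    def __iter__(self):
        # Yield fixed-size lists of dataset indices.
        for start in range(0, self.length, self.batch_size):
            yield list(range(start, min(start + self.batch_size, self.length)))

    def __len__(self):
        return (self.length + self.batch_size - 1) // self.batch_size


dl = DataLoader(list(range(32)), batch_sampler=IndicesBatchSampler(32, 8))

# Before the patch: dl.batch_sampler.sampler -> AttributeError.
# After the patch: the lookup degrades to None instead of raising.
sampler = getattr(dl.batch_sampler, "sampler", None)
assert sampler is None
```

Returning `None` instead of raising keeps the seeding logic opt-in: `isinstance(None, RandomSampler)` is simply `False`, so custom samplers skip that branch entirely.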
@@ -21,8 +21,9 @@ import time
 from copy import deepcopy
 from pathlib import Path
 
+import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from accelerate import Accelerator
 from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
@@ -288,6 +289,52 @@ def central_dl_preparation_check():
     print("Shuffled central dataloader passing.")
 
 
+def custom_sampler_check():
+    state = AcceleratorState()
+
+    class CustomDataset(Dataset):
+        def __init__(self, data):
+            self.data = data
+
+        def __len__(self):
+            return len(self.data)
+
+        def __getitem__(self, index):
+            return self.data[index]
+
+    class CustomBatchSampler:
+        def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True):
+            self.batch_size = batch_size
+            self.data_index = np.arange(dataset_length)
+            self.shuffle = shuffle
+
+        def __iter__(self):
+            num_batches = len(self)
+            if self.shuffle:
+                index = np.random.permutation(self.data_index)
+            else:
+                index = self.data_index
+            output = np.array_split(index, num_batches)
+            yield from output
+
+        def __len__(self):
+            return math.ceil(len(self.data_index) / self.batch_size)
+
+    dataset = CustomDataset(range(32 * state.num_processes))
+    sampler = CustomBatchSampler(len(dataset), batch_size=8)
+    dl = DataLoader(dataset, batch_sampler=sampler)
+    dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index)
+    # We just need to ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler`) is indeed the old batch sampler
+    if hasattr(dl.batch_sampler, "batch_sampler"):
+        assert isinstance(
+            dl.batch_sampler.batch_sampler, CustomBatchSampler
+        ), "Custom sampler was changed after calling `prepare_data_loader`"
+    else:
+        assert isinstance(
+            dl.batch_sampler, CustomBatchSampler
+        ), "Custom sampler was changed after calling `prepare_data_loader`"
+
+
 def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
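
For the user-facing path this check exercises, a minimal sketch (assuming the usual `Accelerator.prepare` entry point rather than calling `prepare_data_loader` directly; `ChunkBatchSampler` and the list dataset are placeholders, not from the repo):

```python
import numpy as np
from torch.utils.data import DataLoader

from accelerate import Accelerator


class ChunkBatchSampler:
    """Hypothetical user-defined batch sampler yielding fixed-size index chunks."""

    def __init__(self, length, batch_size):
        self.chunks = np.array_split(np.arange(length), max(1, length // batch_size))

    def __iter__(self):
        yield from (chunk.tolist() for chunk in self.chunks)

    def __len__(self):
        return len(self.chunks)


accelerator = Accelerator()
dataset = list(range(64))  # placeholder dataset
dataloader = DataLoader(dataset, batch_sampler=ChunkBatchSampler(len(dataset), batch_size=8))

# With the getattr fallback above, preparing this dataloader no longer raises AttributeError;
# in multi-process runs each process receives its own share of the custom batches.
dataloader = accelerator.prepare(dataloader)
for batch in dataloader:
    pass
```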
@@ -608,6 +655,7 @@ def main():
     dl_preparation_check()
     if state.distributed_type != DistributedType.TPU:
         central_dl_preparation_check()
+    custom_sampler_check()
 
     # Trainings are not exactly the same in DeepSpeed and CPU mode
     if state.distributed_type == DistributedType.DEEPSPEED: