Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-18 00:14:36 +08:00)
Compare commits
24 Commits
| SHA1 |
|---|
| 00301b27b7 |
| eb8c535c17 |
| b7686ccb44 |
| f3229872bc |
| 7843286f2e |
| 11e2e99cfc |
| 07e745f1c4 |
| c7c99a30ea |
| 8f45a2eae8 |
| 9fd64b7ea9 |
| 5be16ad90b |
| dab62832de |
| caa9f9bcbb |
| 943efedb88 |
| 50acb0c2ec |
| e6d96e5f70 |
| 1dfb6e9304 |
| 4bef6bc511 |
| 73640d0463 |
| 7a1159143e |
| cbb0b82fa2 |
| 5ae6111180 |
| 230a5f541b |
| 956114ac92 |
.github/workflows/build_documentation.yml (vendored, 1 change)

```diff
@@ -14,5 +14,4 @@ jobs:
       commit_sha: ${{ github.sha }}
       package: accelerate
     secrets:
-      token: ${{ secrets.HUGGINGFACE_PUSH }}
+      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
```
```diff
@@ -51,9 +51,9 @@ jobs:
       run: |
         source activate accelerate
         git config --global --add safe.directory '*'
-        git checkout main && git pull
+        git checkout main && git pull && git fetch --tags
         if [[ ${{ matrix.transformers-version }} = pypi ]]; then
-          git checkout $(git describe --tags `git rev-list --tags --max-count=1`)
+          git checkout $(git tag --sort=taggerdate | tail -1)
         fi
         pip install .[torch,deepspeed-testing]
```
```diff
@@ -52,7 +52,7 @@ will attempt to fill all the space in your GPU(s), then loading them to the CPU,
 
 <Tip>
 
-For more details on desigining your own device map, see this section of the [concept guide](../concept_guide/big_model_inference#desigining-a-device-map)
+For more details on desigining your own device map, see this section of the [concept guide](../concept_guide/big_model_inference#designing-a-device-map)
 
 </Tip>
```
````diff
@@ -96,6 +96,8 @@ all-gather while executing in the forward pass. only use with Static graphs.
 Useful in cases such as parameter-efficient fine-tuning.
 Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019)
 
+`CPU RAM Efficient Model loading`: If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for 🤗 Transformers models. This should be set to False if you experience errors when loading the pretrained 🤗 Transformers model via the `from_pretrained` method. When using this, `Sync Module States` needs to be True, else all the processes except the main process would have random empty weights, leading to unexpected behaviour during training.
+
 `Sync Module States`: If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0
 ```
````
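Taken together, these two options are coupled: CPU-RAM-efficient loading only works when module states are synced from rank 0. A minimal sketch of how the pair surfaces as launcher environment variables (variable names taken from the `prepare_multi_gpu_env` hunk later in this diff; setting them by hand like this is purely illustrative, `accelerate launch` normally exports them for you):

```python
import os

# What `accelerate launch --use_fsdp` exports, per the launcher hunk below.
os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
# Must stay "true" whenever CPU-RAM-efficient loading is on; otherwise every
# process except the main one keeps random empty weights.
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
```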
```diff
@@ -32,6 +32,27 @@ Currently we support searching for models that can be used in `timm` and `transformers`
 
 </Tip>
 
+## Gradio Demos
+
+Below are a few gradio demos related to what was described above. The first is the official Hugging Face memory estimation space, utilizing Accelerate directly:
+
+<div class="block dark:hidden">
+  <iframe
+    src="https://hf-accelerate-model-memory-usage.hf.space?__theme=light"
+    width="850"
+    height="1600"
+  ></iframe>
+</div>
+<div class="hidden dark:block">
+  <iframe
+    src="https://hf-accelerate-model-memory-usage.hf.space?__theme=dark"
+    width="850"
+    height="1600"
+  ></iframe>
+</div>
+
+A community member has taken the idea and expanded it further, allowing you to filter models directly and see if you can run a particular LLM given GPU constraints and LoRA configurations. To play with it, see [here](https://huggingface.co/spaces/Vokturz/can-it-run-llm) for more details.
+
 ## The Command
 
 When using `accelerate estimate-memory`, you need to pass in the name of the model you want to use, potentially the framework

@@ -113,9 +134,4 @@ This calculator will tell you how much memory is needed to purely load the model
 This calculation is accurate within a few % of the actual value, so it is a very good view of just how much memory it will take. For instance loading `bert-base-cased` actually takes `413.68 MB` when loaded on CUDA in full precision, and the calculator estimates `413.18 MB`.
 
 When performing inference you can expect to add up to an additional 20% as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/). We'll be conducting research into finding a more accurate estimate to these values, and will update
 this calculator once done.
-
-## Live Gradio Demo
-
-Lastly, we invite you to try the [live Gradio demo](https://huggingface.co/spaces/hf-accelerate/model-memory-usage) of this utility,
-which includes an option to post a discussion thread on a models repository with this data. Doing so will help provide access to these numbers in the community faster and help users know what you've learned!
```
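For intuition about where numbers like `413 MB` come from (a sketch of the arithmetic, not the estimator's actual implementation): the dominant term is simply parameter count times bytes per element for the chosen dtype.

```python
import torch
import torch.nn as nn

def estimate_load_memory_mb(model: nn.Module) -> float:
    """Rough load-time footprint: bytes of all parameters and buffers."""
    total = sum(p.numel() * p.element_size() for p in model.parameters())
    total += sum(b.numel() * b.element_size() for b in model.buffers())
    return total / 1024**2

# Toy stand-in model; `bert-base-cased` (~108M fp32 parameters) lands near 413 MB.
toy = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
print(f"{estimate_load_memory_mb(toy):.2f} MB")
```

The extra ~20% inference headroom mentioned above would be applied on top of this number.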
setup.py (6 changes)

```diff
@@ -19,7 +19,9 @@ extras = {}
 extras["quality"] = ["black ~= 23.1", "ruff >= 0.0.241", "hf-doc-builder >= 0.3.0", "urllib3 < 2.0.0"]
 extras["docs"] = []
 extras["test_prod"] = ["pytest", "pytest-xdist", "pytest-subtests", "parameterized"]
-extras["test_dev"] = ["datasets", "evaluate", "transformers", "scipy", "scikit-learn", "deepspeed", "tqdm", "bitsandbytes", "timm"]
+extras["test_dev"] = [
+    "datasets", "evaluate", "transformers", "scipy", "scikit-learn", "deepspeed", "tqdm", "bitsandbytes", "timm"
+]
 extras["testing"] = extras["test_prod"] + extras["test_dev"]
 extras["rich"] = ["rich"]

@@ -32,7 +34,7 @@ extras["sagemaker"] = [
 
 setup(
     name="accelerate",
-    version="0.24.0.dev0",
+    version="0.24.0",
     description="Accelerate",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```
```diff
@@ -1,4 +1,4 @@
-__version__ = "0.24.0.dev0"
+__version__ = "0.24.0"
 
 from .accelerator import Accelerator
 from .big_modeling import (

@@ -63,6 +63,7 @@ from .utils import (
     ProjectConfiguration,
     RNGType,
     TorchDynamoPlugin,
+    check_os_kernel,
     compare_versions,
     convert_model,
     convert_outputs_to_fp32,
```
````diff
@@ -264,6 +265,7 @@ class Accelerator:
         kwargs_handlers: list[KwargsHandler] | None = None,
         dynamo_backend: DynamoBackend | str | None = None,
     ):
+        self.trackers = []
         if project_config is not None:
             self.project_configuration = project_config
         else:

@@ -469,6 +471,8 @@ class Accelerator:
         # Set a flag tensor for early stopping and other breakpoints
         self.flag_tensor = None
 
+        check_os_kernel()
+
     @property
     def use_distributed(self):
         """

@@ -2399,7 +2403,6 @@ class Accelerator:
         ...     )
         ```
         """
-        self.trackers = []
         for tracker in self.log_with:
             if issubclass(type(tracker), GeneralTracker):
                 # Custom trackers are already initialized

@@ -2441,7 +2444,7 @@ class Accelerator:
         >>> tensorboard_tracker = accelerator.get_tracker("tensorboard")
         ```
         """
-        if len(getattr(self, "trackers", [])) > 0:
+        if len(self.trackers) > 0:
             for tracker in self.trackers:
                 if tracker.name == name:
                     return tracker.tracker if unwrap else tracker
````
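With `self.trackers` initialized up front in `__init__`, `get_tracker` no longer needs the defensive `getattr`. A hedged usage sketch (the tracker backend and project name are illustrative):

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="runs")
accelerator.init_trackers("my_project")  # illustrative project name

tracker = accelerator.get_tracker("tensorboard")  # safe: trackers is always a list now
accelerator.log({"loss": 0.42}, step=1)
accelerator.end_training()
```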
````diff
@@ -2508,6 +2511,10 @@ class Accelerator:
             f (`str` or `os.PathLike`): Where to save the content of `obj`.
             safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors`
 
+        Note:
+            If `save_on_each_node` was passed in as a `ProjectConfiguration`, will save the object once per node,
+            rather than only once on the main node.
+
         Example:
 
         ```python

@@ -2518,7 +2525,12 @@ class Accelerator:
         >>> accelerator.save(arr, "array.pkl")
         ```
         """
-        save(obj, f, safe_serialization=safe_serialization)
+        save(
+            obj,
+            f,
+            save_on_each_node=self.project_configuration.save_on_each_node,
+            safe_serialization=safe_serialization,
+        )
 
     def save_model(
         self,
````
```diff
@@ -2787,16 +2799,26 @@ class Accelerator:
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers
 
+        # Save the samplers of the dataloaders
+        dataloaders = self._dataloaders
+
         # Call model loading hooks that might have been registered with
         # accelerator.register_model_state_hook
         for hook in self._save_model_state_pre_hook.values():
             hook(self._models, weights, output_dir)
 
-        save_location = save_accelerator_state(
-            output_dir, weights, optimizers, schedulers, self.state.process_index, self.scaler
-        )
+        save_location = save_accelerator_state(
+            output_dir,
+            weights,
+            optimizers,
+            schedulers,
+            dataloaders,
+            self.state.process_index,
+            self.scaler,
+            save_on_each_node=self.project_configuration.save_on_each_node,
+        )
         for i, obj in enumerate(self._custom_objects):
-            save_custom_state(obj, output_dir, i)
+            save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
         self.project_configuration.iteration += 1
         return save_location
```
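From the user side, the new plumbing is opt-in via `ProjectConfiguration`. A minimal sketch (paths are illustrative):

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# save_on_each_node=True writes checkpoints once per node rather than only on
# the global main process -- useful when nodes do not share a filesystem.
config = ProjectConfiguration(project_dir="ckpts", save_on_each_node=True)
accelerator = Accelerator(project_config=config)

# ... prepare model/optimizer/dataloaders and train ...
accelerator.save_state("ckpts/step_100")  # sampler states are now saved too
```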
```diff
@@ -2920,6 +2942,8 @@ class Accelerator:
         elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
             schedulers = self._schedulers
 
+        dataloaders = self._dataloaders
+
         # Call model loading hooks that might have been registered with
         # accelerator.register_model_state_hook
         for hook in self._load_model_state_pre_hook.values():

@@ -2940,6 +2964,7 @@ class Accelerator:
             models,
             optimizers,
             schedulers,
+            dataloaders,
             self.state.process_index,
             self.scaler,
             map_location,
```
```diff
@@ -25,6 +25,7 @@ from .utils import (
     MODEL_NAME,
     OPTIMIZER_NAME,
     RNG_STATE_NAME,
+    SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
     get_pretty_name,

@@ -49,8 +50,10 @@ def save_accelerator_state(
     model_states: List[dict],
     optimizers: list,
     schedulers: list,
+    dataloaders: list,
     process_index: int,
     scaler: GradScaler = None,
+    save_on_each_node: bool = False,
 ):
     """
     Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

@@ -64,31 +67,49 @@ def save_accelerator_state(
             A list of optimizer instances
         schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
             A list of learning rate schedulers
+        dataloaders (`List[torch.utils.data.DataLoader]`):
+            A list of dataloader instances to save their sampler states
         process_index (`int`):
             The current process index in the Accelerator state
         scaler (`torch.cuda.amp.GradScaler`, *optional*):
             An optional gradient scaler instance to save
+        save_on_each_node (`bool`, *optional*):
+            Whether to save on every node, or only the main node.
     """
     # Model states
     for i, state in enumerate(model_states):
         weights_name = f"{MODEL_NAME}.bin" if i == 0 else f"{MODEL_NAME}_{i}.bin"
         output_model_file = os.path.join(output_dir, weights_name)
-        save(state, output_model_file)
+        save(state, output_model_file, save_on_each_node=save_on_each_node)
         logger.info(f"Model weights saved in {output_model_file}")
     # Optimizer states
     for i, opt in enumerate(optimizers):
         state = opt.state_dict()
         optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
         output_optimizer_file = os.path.join(output_dir, optimizer_name)
-        save(state, output_optimizer_file)
+        save(state, output_optimizer_file, save_on_each_node=save_on_each_node)
         logger.info(f"Optimizer state saved in {output_optimizer_file}")
     # Scheduler states
     for i, scheduler in enumerate(schedulers):
         state = scheduler.state_dict()
         scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
         output_scheduler_file = os.path.join(output_dir, scheduler_name)
-        save(state, output_scheduler_file)
+        save(state, output_scheduler_file, save_on_each_node=save_on_each_node)
         logger.info(f"Scheduler state saved in {output_scheduler_file}")
+    # DataLoader states
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        output_sampler_file = os.path.join(output_dir, sampler_name)
+        # Only save if we have our custom sampler
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
+            sampler = dataloader.sampler.sampler
+
+            if isinstance(sampler, SeedableRandomSampler):
+                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node)
+        logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
 
     # GradScaler state
     if scaler is not None:
         state = scaler.state_dict()

@@ -118,6 +139,7 @@ def load_accelerator_state(
     models,
     optimizers,
     schedulers,
+    dataloaders,
     process_index,
     scaler=None,
     map_location=None,

@@ -174,6 +196,19 @@ def load_accelerator_state(
         scheduler.load_state_dict(torch.load(input_scheduler_file))
     logger.info("All scheduler states loaded successfully")
 
+    for i, dataloader in enumerate(dataloaders):
+        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
+        input_sampler_file = os.path.join(input_dir, sampler_name)
+        # Only load if we have our custom sampler
+        from .data_loader import IterableDatasetShard, SeedableRandomSampler
+
+        if isinstance(dataloader.dataset, IterableDatasetShard):
+            sampler = dataloader.sampler.sampler
+
+            if isinstance(sampler, SeedableRandomSampler):
+                dataloader.sampler.sampler = torch.load(input_sampler_file)
+    logger.info("All dataloader sampler states loaded successfully")
+
     # GradScaler state
     if scaler is not None:
         input_scaler_file = os.path.join(input_dir, SCALER_NAME)

@@ -197,14 +232,14 @@ def load_accelerator_state(
         logger.info("Could not load random states")
 
 
-def save_custom_state(obj, path, index: int = 0):
+def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
     """
     Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
     """
     # Should this be the right way to get a qual_name type value from `obj`?
     save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
     logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
-    torch.save(obj.state_dict(), save_location)
+    save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
 
 
 def load_custom_state(obj, path, index: int = 0):
```
```diff
@@ -386,12 +386,21 @@ def get_cluster_input():
             default=False,
             error_message="Please enter yes or no.",
         )
-        fsdp_config["fsdp_sync_module_states"] = _ask_field(
-            "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
+        fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field(
+            "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ",
             _convert_yes_no_to_bool,
             default=True,
             error_message="Please enter yes or no.",
         )
+        if fsdp_config["fsdp_cpu_ram_efficient_loading"]:
+            fsdp_config["fsdp_sync_module_states"] = True
+        else:
+            fsdp_config["fsdp_sync_module_states"] = _ask_field(
+                "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ",
+                _convert_yes_no_to_bool,
+                default=True,
+                error_message="Please enter yes or no.",
+            )
 
     megatron_lm_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
```
```diff
@@ -524,6 +524,14 @@ def launch_command_parser(subparsers=None):
         help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
         " (useful only when `use_fsdp` flag is passed).",
     )
+    fsdp_args.add_argument(
+        "--fsdp_cpu_ram_efficient_loading",
+        default="true",
+        type=str,
+        help="If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. "
+        "Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to be True. "
+        "(useful only when `use_fsdp` flag is passed).",
+    )
     fsdp_args.add_argument(
         "--fsdp_sync_module_states",
         default="true",
```
```diff
@@ -17,7 +17,7 @@ from contextlib import suppress
 from typing import Callable, List, Optional, Union
 
 import torch
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
 
 from .logging import get_logger
 from .state import AcceleratorState, DistributedType, GradientState, is_tpu_available

@@ -64,6 +64,41 @@ for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
     _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
 
 
+class SeedableRandomSampler(RandomSampler):
+    """
+    Same as a random sampler, except that in `__iter__` a seed can be used.
+
+    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
+    and be fully reproducible on multiple iterations.
+
+    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
+    (stored in `self.epoch`).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.epoch = 0
+
+    def __iter__(self):
+        g = torch.Generator()
+        if self.generator is not None:
+            seed = self.epoch + self.generator.initial_seed()
+        else:
+            seed = self.epoch
+        g.manual_seed(seed)
+        n = len(self.data_source)
+        # Taken 1:1 from torch.utils.data.sampler.RandomSampler.__iter__
+        if self.replacement:
+            for _ in range(self.num_samples // 32):
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
+        else:
+            yield from torch.randperm(n, generator=g).tolist()
+
+    def set_epoch(self, epoch: int):
+        "Sets the current iteration of the sampler."
+        self.epoch = epoch
+
+
 class BatchSamplerShard(BatchSampler):
     """
     Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
```
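To see why deriving the seed as `epoch + initial_seed` matters: every process that constructs this sampler with an identically seeded generator produces the same permutation within an epoch, while still reshuffling between epochs. A quick single-process sketch using the class above:

```python
import torch
from accelerate.data_loader import SeedableRandomSampler

data = list(range(8))
sampler = SeedableRandomSampler(data, generator=torch.Generator().manual_seed(42))

sampler.set_epoch(0)
first = list(sampler)          # deterministic for seed 42 + epoch 0
sampler.set_epoch(0)
assert list(sampler) == first  # same epoch -> same order on every process
sampler.set_epoch(1)
print(list(sampler) != first)  # new epoch -> fresh shuffle (almost surely True)
```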
```diff
@@ -271,7 +306,25 @@ class IterableDatasetShard(IterableDataset):
         self.process_index = process_index
         self.split_batches = split_batches
 
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+        if hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
+    def __len__(self):
+        # We will just raise the downstream error if the underlying dataset is not sized
+        if self.drop_last:
+            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
+        else:
+            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
+
     def __iter__(self):
+        if (
+            not hasattr(self.dataset, "set_epoch")
+            and hasattr(self.dataset, "generator")
+            and isinstance(self.dataset.generator, torch.Generator)
+        ):
+            self.dataset.generator.manual_seed(self.epoch)
         real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
         process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
         process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)

@@ -324,8 +377,9 @@ class DataLoaderStateMixin:
         "Prepares the gradient state for the current dataloader"
         self.reset()
         with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
+            if not self._drop_last:
+                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+                self.remainder = length % self.total_batch_size
         self.gradient_state._add_dataloader(self)
 
     def end(self):

@@ -352,7 +406,7 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
             - `"generator"`: an optional `torch.Generator`
         synchronized_generator (`torch.Generator`, *optional*):
             A random number generator to keep synchronized across processes.
-        split_batches (`int`, *optional*, defaults to 0):
+        skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning.
         kwargs:
             All other keyword arguments to pass to the regular `DataLoader` initialization.

@@ -366,18 +420,31 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
         - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
     """
 
-    def __init__(self, dataset, device=None, rng_types=None, synchronized_generator=None, skip_batches=0, **kwargs):
+    def __init__(
+        self,
+        dataset,
+        device=None,
+        rng_types=None,
+        synchronized_generator=None,
+        skip_batches=0,
+        _drop_last: bool = False,
+        **kwargs,
+    ):
         super().__init__(dataset, **kwargs)
         self.device = device
         self.rng_types = rng_types
         self.synchronized_generator = synchronized_generator
         self.skip_batches = skip_batches
         self.gradient_state = GradientState()
+        self._drop_last = _drop_last
+        self.iteration = 0
 
     def __iter__(self):
         if self.rng_types is not None:
             synchronize_rng_states(self.rng_types, self.synchronized_generator)
         self.begin()
 
+        self.set_epoch(self.iteration)
         dataloader_iter = super().__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:

@@ -401,8 +468,21 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
                 if batch_index >= self.skip_batches:
                     yield current_batch
                 break
 
+        self.iteration += 1
         self.end()
 
+    def set_epoch(self, epoch: int):
+        # In case it is manually passed in, the user can set it to what they like
+        if self.iteration != epoch:
+            self.iteration = epoch
+        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
+        # We support if a custom `Dataset` implementation has `set_epoch`
+        # or in general HF datasets `Datasets`
+        elif hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     @property
     def total_batch_size(self):
         batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
```
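Prepared dataloaders now track their own iteration count, so reshuffling per epoch happens automatically; `set_epoch` remains useful when resuming mid-training and the loop needs to realign the epoch by hand. A hedged, self-contained sketch:

```python
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()
dataset = torch.arange(16)
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4, shuffle=True))

start_epoch, num_epochs = 0, 3  # illustrative; start_epoch may come from a checkpoint
for epoch in range(start_epoch, num_epochs):
    # Prepared dataloaders advance `iteration` themselves; calling `set_epoch`
    # explicitly realigns shuffling after restoring from a checkpoint.
    if hasattr(dataloader, "set_epoch"):
        dataloader.set_epoch(epoch)
    for batch in dataloader:
        pass  # training step goes here
```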
```diff
@@ -506,6 +586,7 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
         self.skip_batches = skip_batches
 
         self.slice_fn = slice_tensors if slice_fn is None else slice_fn
+        self.iteration = 0
 
     def _fetch_batches(self, iterator):
         batches, batch = None, None

@@ -546,6 +627,7 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
 
     def __iter__(self):
         self.begin()
+        self.set_epoch(self.iteration)
         main_iterator = None
         if is_torch_version(">=", "2.0.1"):
             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts

@@ -615,8 +697,18 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
             if batch_index >= self.skip_batches:
                 yield batch
                 batch_index += 1
+        self.iteration += 1
         self.end()
 
+    def set_epoch(self, epoch: int):
+        # In case it is manually passed in, the user can set it to what they like
+        if self.iteration != epoch:
+            self.iteration = epoch
+        if hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
+        elif hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
     def __len__(self):
         whole_length = super().__len__()
         if self.split_batches:

@@ -739,6 +831,23 @@ def prepare_data_loader(
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
     sampler_is_batch_sampler = False
     synchronized_generator = None
+    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+    if sampler_is_batch_sampler:
+        sampler = dataloader.sampler.sampler
+    else:
+        sampler = dataloader.batch_sampler.sampler
+    if isinstance(sampler, RandomSampler) and num_processes > 1:
+        # When iterating through the dataloader during distributed processes
+        # we want to ensure that on each process we are iterating through the same
+        # samples in the same order if a seed is set. This requires a tweak
+        # to the `torch.utils.data.RandomSampler` class (if used).
+        sampler = SeedableRandomSampler(
+            data_source=sampler.data_source,
+            replacement=sampler.replacement,
+            num_samples=sampler._num_samples,
+            generator=getattr(sampler, "generator", torch.Generator()),
+        )
 
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
         if isinstance(new_dataset, IterableDataset):

@@ -753,17 +862,6 @@ def prepare_data_loader(
                 split_batches=split_batches,
             )
         else:
             # New batch sampler for the current process.
-            sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-            if sampler_is_batch_sampler:
-                sampler = dataloader.sampler.sampler
-            else:
-                sampler = dataloader.batch_sampler.sampler
             if hasattr(sampler, "generator"):
                 if sampler.generator is None:
                     sampler.generator = torch.Generator()
                 synchronized_generator = sampler.generator
 
             batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
                 batch_sampler,

@@ -797,7 +895,11 @@ def prepare_data_loader(
     kwargs["batch_size"] = (
         dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
     )
-
+    if isinstance(sampler, SeedableRandomSampler):
+        if sampler_is_batch_sampler:
+            dataloader.sampler.sampler = sampler
+        else:
+            dataloader.batch_sampler.sampler = sampler
     if dispatch_batches:
         kwargs.pop("generator")
         dataloader = DataLoaderDispatcher(

@@ -815,6 +917,7 @@ def prepare_data_loader(
             sampler=new_batch_sampler,
             batch_size=dataloader.batch_size,
             rng_types=rng_types,
+            _drop_last=dataloader.drop_last,
             synchronized_generator=synchronized_generator,
             **kwargs,
         )

@@ -825,6 +928,7 @@ def prepare_data_loader(
             batch_sampler=new_batch_sampler,
             rng_types=rng_types,
             synchronized_generator=synchronized_generator,
+            _drop_last=dataloader.drop_last,
             **kwargs,
         )
```
````diff
@@ -85,9 +85,11 @@ def get_logger(name: str, log_level: str = None):
 
     ```python
     >>> from accelerate.logging import get_logger
+    >>> from accelerate import Accelerator
 
     >>> logger = get_logger(__name__)
 
+    >>> accelerator = Accelerator()
     >>> logger.info("My log", main_process_only=False)
     >>> logger.debug("My log", main_process_only=True)
 
@@ -95,9 +97,6 @@ def get_logger(name: str, log_level: str = None):
     >>> logger.info("My log")
     >>> logger.debug("My second log")
 
-    >>> from accelerate import Accelerator
-
-    >>> accelerator = Accelerator()
     >>> array = ["a", "b", "c", "d"]
     >>> letter_at_rank = array[accelerator.process_index]
     >>> logger.info(letter_at_rank, in_order=True)
````
```diff
@@ -219,6 +219,25 @@ def test_gather_for_metrics_with_iterable_dataset():
     logger.removeHandler(list_handler)
 
 
+def test_gather_for_metrics_drop_last():
+    accelerator = Accelerator()
+    per_device_batch_size = 5
+    num_items = (10 * accelerator.num_processes) + 1
+    dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True)
+    dataloader = accelerator.prepare(dataloader)
+
+    iterator = iter(dataloader)
+    next(iterator)  # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0')
+    batch = next(iterator)
+    gathered_items = accelerator.gather_for_metrics(batch)
+
+    # Should return a full set of complete batches from each GPU
+    num_expected_items = per_device_batch_size * accelerator.num_processes
+    assert gathered_items.size(0) == (
+        num_expected_items
+    ), f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}"
+
+
 def main():
     accelerator = Accelerator(split_batches=False, dispatch_batches=False)
     if accelerator.is_local_main_process:

@@ -255,6 +274,10 @@ def main():
     accelerator = Accelerator()
     test_torch_metrics(accelerator, 512)
     accelerator.state._reset_state()
+    if accelerator.is_local_main_process:
+        print("**Test that `drop_last` is taken into account**")
+    test_gather_for_metrics_drop_last()
+    accelerator.state._reset_state()
 
 
 def _mp_fn(index):
```
```diff
@@ -25,7 +25,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from accelerate.data_loader import prepare_data_loader
+from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
 from accelerate.state import AcceleratorState
 from accelerate.test_utils import RegressionDataset, are_the_same_tensors
 from accelerate.utils import (

@@ -292,7 +292,17 @@ def mock_training(length, batch_size, generator):
     set_seed(42)
     generator.manual_seed(42)
     train_set = RegressionDataset(length=length, seed=42)
-    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
+    if AcceleratorState().num_processes > 1:
+        # The SeedableRandomSampler is needed during distributed setups
+        # for full reproducibility across processes with the `DataLoader`
+        sampler = SeedableRandomSampler(
+            generator=generator,
+            data_source=train_set,
+            num_samples=len(train_set),
+        )
+        train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
+    else:
+        train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
     model = RegressionModel()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
     for epoch in range(3):
```
```diff
@@ -4,6 +4,7 @@ from .constants import (
     RNG_STATE_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
+    SAMPLER_NAME,
     SCALER_NAME,
     SCHEDULER_NAME,
     TORCH_DISTRIBUTED_OPERATION_TYPES,

@@ -164,6 +165,7 @@ from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
 from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
 from .memory import find_executable_batch_size, release_memory
 from .other import (
+    check_os_kernel,
     clear_environment,
     convert_bytes,
     extract_model_from_parallel,
```
```diff
@@ -20,6 +20,7 @@ MODEL_NAME = "pytorch_model"
 RNG_STATE_NAME = "random_states"
 OPTIMIZER_NAME = "optimizer"
 SCHEDULER_NAME = "scheduler"
+SAMPLER_NAME = "sampler"
 WEIGHTS_NAME = "pytorch_model.bin"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 SAFE_WEIGHTS_NAME = "model.safetensors"
```
```diff
@@ -420,6 +420,16 @@ class ProjectConfiguration:
         metadata={"help": "The current save iteration."},
     )
 
+    save_on_each_node: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+                " only on the main one"
+            )
+        },
+    )
+
     def set_directories(self, project_dir: str = None):
         "Sets `self.project_dir` and `self.logging_dir` to the appropriate values."
         self.project_dir = project_dir

@@ -732,7 +742,7 @@ class DeepSpeedPlugin:
             or ds_config["train_micro_batch_size_per_gpu"] == "auto"
         ):
             ds_config["train_micro_batch_size_per_gpu"] = 1
-        if ds_config["train_batch_size"] == "auto":
+        if ds_config.get("train_batch_size", None) == "auto":
             del ds_config["train_batch_size"]
 
         if compare_versions("transformers", "<", "4.33"):
```
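The one-line `.get()` change above guards against a `KeyError`: a user's DeepSpeed config may omit `train_batch_size` entirely. A quick illustration of the failure mode:

```python
ds_config = {"train_micro_batch_size_per_gpu": "auto"}  # no "train_batch_size" key

# The old form raised KeyError when the key was absent:
#     if ds_config["train_batch_size"] == "auto": ...
# The .get() form degrades gracefully:
if ds_config.get("train_batch_size", None) == "auto":
    del ds_config["train_batch_size"]
```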
```diff
@@ -174,6 +174,9 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
 
     if args.use_fsdp:
         current_env["ACCELERATE_USE_FSDP"] = "true"
+        if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states:
+            raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`")
+
         current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy)
         current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower()
         current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params)

@@ -187,6 +190,7 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
     current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
     current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower()
     current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
+    current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
     current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
 
     if args.use_megatron_lm:
```
```diff
@@ -250,7 +250,7 @@ def set_module_tensor_to_device(
     Args:
         module (`torch.nn.Module`):
             The module in which the tensor we want to move lives.
-        param_name (`str`):
+        tensor_name (`str`):
             The full name of the parameter/buffer.
         device (`int`, `str` or `torch.device`):
             The device on which to set the tensor.

@@ -1458,6 +1458,7 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
         DistributedType.MULTI_GPU,
         DistributedType.MULTI_NPU,
         DistributedType.MULTI_XPU,
+        DistributedType.FSDP,
     ]:
         return torch.autocast(device_type=state.device.type, dtype=torch.bfloat16, **autocast_kwargs)
     else:
```
```diff
@@ -25,7 +25,7 @@ import torch
 
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_torch_distributed_available, is_tpu_available
+from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available
 
 
 if is_tpu_available(check_device=False):

@@ -280,6 +280,12 @@ def _tpu_gather(tensor):
 
 
 def _gpu_gather(tensor):
+    state = PartialState()
+    if is_torch_version(">=", "1.13"):
+        gather_op = torch.distributed.all_gather_into_tensor
+    else:
+        gather_op = torch.distributed._all_gather_base
+
     def _gpu_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]

@@ -287,9 +293,26 @@ def _gpu_gather(tensor):
         # Can only gather contiguous tensors
         if not tensor.is_contiguous():
             tensor = tensor.contiguous()
-        output_tensors = [torch.empty_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output_tensors, tensor)
-        return torch.cat(output_tensors, dim=0)
+
+        if state.backend is not None and state.backend != "gloo":
+            # We use `empty` as `all_gather_into_tensor` slightly
+            # differs from `all_gather` for better efficiency,
+            # and we rely on the number of items in the tensor
+            # rather than its direct shape
+            output_tensors = torch.empty(
+                state.num_processes * tensor.numel(),
+                dtype=tensor.dtype,
+                device=state.device,
+            )
+            gather_op(output_tensors, tensor)
+            return output_tensors.view(-1, *tensor.size()[1:])
+        else:
+            # a backend of `None` is always CPU
+            # also gloo does not support `all_gather_into_tensor`,
+            # which will result in a larger memory overhead for the op
+            output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
+            torch.distributed.all_gather(output_tensors, tensor)
+            return torch.cat(output_tensors, dim=0)
 
     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
```
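The switch to `all_gather_into_tensor` (stable since PyTorch 1.13, previously exposed as `_all_gather_base`) gathers into one preallocated flat buffer instead of a Python list of per-rank tensors, avoiding the extra allocations and the final concatenation. A standalone sketch of the two styles (assumes `torch.distributed.init_process_group` has already been called; not Accelerate's internal code):

```python
import torch
import torch.distributed as dist

def gather_list_style(tensor: torch.Tensor) -> torch.Tensor:
    # Old path: one buffer per rank, then a cat -- extra allocation + copy.
    outputs = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(outputs, tensor)
    return torch.cat(outputs, dim=0)

def gather_flat_style(tensor: torch.Tensor) -> torch.Tensor:
    # New path: a single flat buffer of world_size * numel elements, reshaped after.
    out = torch.empty(dist.get_world_size() * tensor.numel(),
                      dtype=tensor.dtype, device=tensor.device)
    dist.all_gather_into_tensor(out, tensor.contiguous())
    return out.view(-1, *tensor.size()[1:])
```

The gloo backend does not support `all_gather_into_tensor`, which is why the code above falls back to the list-style gather on CPU.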
```diff
@@ -13,13 +13,18 @@
 # limitations under the License.
 
 import os
+import platform
+import re
+import socket
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
 
 import torch
+from packaging.version import Version
 
 from ..commands.config.default import write_basic_config  # noqa: F401
+from ..logging import get_logger
 from ..state import PartialState
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType

@@ -28,6 +33,9 @@ from .transformer_engine import convert_model
 from .versions import is_torch_version
 
 
+logger = get_logger(__name__)
+
+
 if is_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm

@@ -109,22 +117,27 @@ def wait_for_everyone():
     PartialState().wait_for_everyone()
 
 
-def save(obj, f, safe_serialization=False):
+def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
     """
     Save the data to disk. Use in place of `torch.save()`.
 
     Args:
-        obj: The data to save
-        f: The file (or file-like object) to use to save the data
-        safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors`
+        obj:
+            The data to save
+        f:
+            The file (or file-like object) to use to save the data
+        save_on_each_node (`bool`, *optional*, defaults to `False`):
+            Whether to only save on the global main process
+        safe_serialization (`bool`, *optional*, defaults to `False`):
+            Whether to save `obj` using `safetensors`
     """
+    save_func = torch.save if not safe_serialization else partial(safe_save_file, metadata={"format": "pt"})
     if PartialState().distributed_type == DistributedType.TPU:
         xm.save(obj, f)
-    elif PartialState().local_process_index == 0:
-        if safe_serialization:
-            safe_save_file(obj, f, metadata={"format": "pt"})
-        else:
-            torch.save(obj, f)
+    elif PartialState().is_main_process and not save_on_each_node:
+        save_func(obj, f)
+    elif PartialState().is_local_main_process and save_on_each_node:
+        save_func(obj, f)
 
 
 @contextmanager

@@ -246,3 +259,21 @@ def convert_bytes(size):
         size /= 1024.0
 
     return f"{round(size, 2)} PB"
+
+
+def check_os_kernel():
+    """Warns if the kernel version is below the recommended minimum on Linux."""
+    # see issue #1929
+    info = platform.uname()
+    system = info.system
+    if system != "Linux":
+        return
+
+    _, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release)
+    min_version = "5.5.0"
+    if Version(version) < Version(min_version):
+        msg = (
+            f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can "
+            "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
+        )
+        logger.warning(msg, main_process_only=True)
```
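The rewritten `save` helper above resolves the writer like this: with `save_on_each_node=False` only the global main process writes; with `True`, the local main process of every node writes. A tiny decision-table sketch of that branch logic (pure Python, no distributed setup required):

```python
def who_writes(is_main: bool, is_local_main: bool, save_on_each_node: bool) -> bool:
    # Mirrors the two elif branches in `save` (TPU path aside).
    if is_main and not save_on_each_node:
        return True
    if is_local_main and save_on_each_node:
        return True
    return False

# Global main process always writes; with save_on_each_node=True, each node's
# local main process writes as well.
assert who_writes(True, True, False) is True     # rank 0, default
assert who_writes(False, True, False) is False   # node 1's rank 0, default
assert who_writes(False, True, True) is True     # node 1's rank 0, per-node
assert who_writes(False, False, True) is False   # non-main worker never writes
```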
```diff
@@ -55,8 +55,6 @@ from accelerate.utils.other import patch_environment
 
-set_seed(42)
-
 T5_SMALL = "t5-small"
 T5_TINY = "patrickvonplaten/t5-tiny-random"
 GPT2_TINY = "sshleifer/tiny-gpt2"
 
 ZERO2 = "zero2"
```
```diff
@@ -15,13 +15,17 @@
 import os
 import pickle
 import unittest
+import warnings
 from collections import UserDict, namedtuple
+from unittest.mock import Mock, patch
 
 import torch
 
+from accelerate.state import PartialState
 from accelerate.test_utils.testing import require_cuda, require_torch_min_version
 from accelerate.test_utils.training import RegressionModel
 from accelerate.utils import (
+    check_os_kernel,
     convert_outputs_to_fp32,
     extract_model_from_parallel,
     find_device,

@@ -36,6 +40,10 @@ ExampleNamedTuple = namedtuple("ExampleNamedTuple", "a b c")
 
 
 class UtilsTester(unittest.TestCase):
+    def setUp(self):
+        # logging requires initialized state
+        PartialState()
+
     def test_send_to_device(self):
         tensor = torch.randn(5, 2)
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@@ -173,3 +181,27 @@ class UtilsTester(unittest.TestCase):
         self.assertEqual(find_device([1, "a", torch.tensor([1, 2, 3])]), torch.device("cpu"))
         self.assertEqual(find_device({"a": 1, "b": torch.tensor([1, 2, 3])}), torch.device("cpu"))
         self.assertIsNone(find_device([1, "a"]))
+
+    def test_check_os_kernel_no_warning_when_release_gt_min(self):
+        # min version is 5.5
+        with patch("platform.uname", return_value=Mock(release="5.15.0-35-generic", system="Linux")):
+            with warnings.catch_warnings(record=True) as w:
+                check_os_kernel()
+            self.assertEqual(len(w), 0)
+
+    def test_check_os_kernel_no_warning_when_not_linux(self):
+        # system must be Linux
+        with patch("platform.uname", return_value=Mock(release="5.4.0-35-generic", system="Darwin")):
+            with warnings.catch_warnings(record=True) as w:
+                check_os_kernel()
+            self.assertEqual(len(w), 0)
+
+    def test_check_os_kernel_warning_when_release_lt_min(self):
+        # min version is 5.5
+        with patch("platform.uname", return_value=Mock(release="5.4.0-35-generic", system="Linux")):
+            with self.assertLogs() as ctx:
+                check_os_kernel()
+            self.assertEqual(len(ctx.records), 1)
+            self.assertEqual(ctx.records[0].levelname, "WARNING")
+            self.assertIn("5.4.0", ctx.records[0].msg)
+            self.assertIn("5.5.0", ctx.records[0].msg)
```
```diff
@@ -17,6 +17,7 @@ https://github.com/allenai/allennlp.
 """
 import os
 from datetime import datetime as dt
+from datetime import timezone
 
 from github import Github

@@ -36,7 +37,7 @@ def main():
     for issue in open_issues:
         comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
         last_comment = comments[0] if len(comments) > 0 else None
-        current_time = dt.utcnow()
+        current_time = dt.now(timezone.utc)
         days_since_updated = (current_time - issue.updated_at).days
         days_since_creation = (current_time - issue.created_at).days
         if (
```
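The switch away from `dt.utcnow()` matters because it returns a timezone-naive datetime, while PyGithub timestamps like `issue.updated_at` are timezone-aware (in recent PyGithub versions); subtracting one from the other raises `TypeError`. A standalone illustration:

```python
from datetime import datetime, timezone

aware = datetime(2023, 10, 1, tzinfo=timezone.utc)  # e.g. issue.updated_at

naive_now = datetime.utcnow()            # naive: tzinfo is None (deprecated in 3.12)
aware_now = datetime.now(timezone.utc)   # aware

print((aware_now - aware).days)  # fine
try:
    (naive_now - aware).days
except TypeError as e:
    print(e)  # can't subtract offset-naive and offset-aware datetimes
```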