Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-13 15:14:35 +08:00)

Compare commits: v1.0.1...fork-teste (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | d9a7459981 | |
| | 70d31ba566 | |
setup.py (2 changed lines)
@@ -22,7 +22,7 @@ extras["quality"] = [
     "ruff ~= 0.2.1",
 ]
 extras["docs"] = []
-extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
+extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized", "expecttest"]
 extras["test_dev"] = [
     "datasets",
     "diffusers",
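The only setup.py change is the new `expecttest` entry. This is presumably needed because `torch.testing._internal.common_utils`, which the thread-based test harness used below builds on, imports `expecttest` when it loads. A minimal import check (a sketch, assuming the `test_prod` extra is installed):

```python
# Sketch: the torch internal test helpers used by the new threaded tests
# pull in `expecttest` at import time, hence the extras["test_prod"] addition.
import expecttest  # noqa: F401
from torch.testing._internal.common_distributed import MultiThreadedTestCase  # noqa: F401

print("expecttest and MultiThreadedTestCase import cleanly")
```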
@@ -101,6 +101,11 @@ class ThreadLocalSharedDict(threading.local):
     def __set__(self, obj, value):
         self._storage = value
 
+def _get_shared_dict_type():
+    # Prefer global shared dictionary, except when using TPU or `backend == threaded`
+    if is_torch_xla_available() or (torch.distributed.is_initialized() and torch.distributed.get_backend() == "threaded"):
+        return ThreadLocalSharedDict
+    return dict
 
 # Prefer global shared dictionary, except when using TPU.
 SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
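The next several hunks patch Accelerate's `PartialState` shared-state machinery. Under the `threaded` backend every rank is a Python thread inside a single process, so a module-level `dict` would be visible to all ranks at once; a `threading.local` subclass keeps each rank's view separate. A small illustration of that difference (not from the diff, plain `threading` only):

```python
import threading


class ThreadLocalDict(threading.local):
    # Each thread sees its own `data`, mimicking ThreadLocalSharedDict above.
    def __init__(self):
        self.data = {}


shared_plain = {}              # one dict shared by every thread ("rank")
shared_local = ThreadLocalDict()


def worker(rank):
    shared_plain["process_index"] = rank       # last writer wins for everyone
    shared_local.data["process_index"] = rank  # isolated per thread
    assert shared_local.data["process_index"] == rank


threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(shared_plain)       # e.g. {'process_index': 3} -- clobbered by whichever thread wrote last
print(shared_local.data)  # {} in the main thread: it never wrote to its own copy
```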
@@ -162,6 +167,9 @@ class PartialState:
     ]
 
     def __init__(self, cpu: bool = False, **kwargs):
+        # This is needed when we are launching tests and have the `threaded` backend
+        if _get_shared_dict_type() != self._shared_state.__class__:
+            PartialState._shared_state = _get_shared_dict_type()()
         self.__dict__ = self._shared_state
         if not self.initialized:
             self._cpu = cpu
@@ -185,7 +193,7 @@ class PartialState:
             self.backend = backend
             self.distributed_type = distributed_type
             use_deepspeed = False
-            if not cpu and self.backend != "xla":
+            if not cpu and self.backend != "xla" and not torch.distributed.is_initialized():
                 if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                     # Deal with spawning deepspeed
                     if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
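The extra `not torch.distributed.is_initialized()` guard lets `PartialState` adopt a process group that a test harness created beforehand instead of launching one from environment variables. A single-process sketch of the values the patched code can read from such a pre-initialized group (gloo stands in for the `threaded` backend used by the new tests; the port is arbitrary):

```python
import torch.distributed as dist

# Simulate a harness that creates the process group before PartialState runs.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29501",
    rank=0,
    world_size=1,
)

# With the guard above, PartialState no longer tries to re-initialize from
# LOCAL_RANK / ACCELERATE_USE_DEEPSPEED; it can rely on the existing group:
print(dist.is_initialized())   # True -> the launch branch is skipped
print(dist.get_backend())      # "gloo" here; "threaded" under the thread harness
print(dist.get_world_size())   # 1
print(dist.get_rank())         # 0

dist.destroy_process_group()
```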
@@ -274,9 +282,13 @@ class PartialState:
             else:
                 self.num_processes = torch.distributed.get_world_size()
                 self.process_index = torch.distributed.get_rank()
-                self.local_process_index = (
-                    int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
-                )
+                # Setting `local_process_index` requires some care
+                if dist_information is not None:
+                    self.local_process_index = dist_information.local_rank
+                elif backend == "threaded":
+                    self.local_process_index = self.process_index
+                else:
+                    self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
             self.set_device()
             # Now we can change to deepseed
             if use_deepspeed:
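The new `backend == "threaded"` branch matters because thread-based ranks are not started by a launcher, so `LOCAL_RANK` is normally absent and the old expression would quietly yield `-1`. A tiny check (assuming no launcher has exported `LOCAL_RANK` in the current shell):

```python
import os

# Without torchrun / accelerate launch there is no LOCAL_RANK in the environment,
# so the removed one-liner would have left local_process_index at -1 for
# thread-based ranks; the new branch falls back to the global rank instead.
print(int(os.environ.get("LOCAL_RANK", -1)))  # -1 unless a launcher exported it
```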
@@ -710,6 +722,10 @@ class PartialState:
     ) -> tuple[str, DistributedType]:
         "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
         distributed_type = None
+        if torch.distributed.is_initialized():
+            backend = torch.distributed.get_backend()
+            if backend == "threaded":
+                distributed_type = DistributedType.MULTI_GPU
         if sagemaker_dp:
             import smdistributed.dataparallel.torch.torch_smddp  # noqa
 
@@ -718,7 +734,7 @@ class PartialState:
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
+        elif not cpu and int(os.environ.get("LOCAL_RANK", -1)) != -1:
             if is_mlu_available():
                 backend = "cncl"
                 distributed_type = DistributedType.MULTI_MLU
@@ -807,10 +807,10 @@ def main():
     if state.distributed_type == DistributedType.DEEPSPEED:
         return
 
-    if state.local_process_index == 0:
-        print("\n**Training integration test**")
-    training_check(use_seedable_sampler=False)
-    training_check(use_seedable_sampler=True)
+    # if state.local_process_index == 0:
+    #     print("\n**Training integration test**")
+    # training_check(use_seedable_sampler=False)
+    # training_check(use_seedable_sampler=True)
 
     if state.local_process_index == 0:
         print("\n**Breakpoint trigger test**")
@@ -30,7 +30,7 @@ import torch
 
 import accelerate
 
-from ..state import AcceleratorState, PartialState
+from ..state import PartialState, PartialState
 from ..utils import (
     gather,
     is_bnb_available,
@@ -427,14 +427,14 @@ class TempDirTestCase(unittest.TestCase):
 class AccelerateTestCase(unittest.TestCase):
     """
     A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes
-    the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between
+    the `PartialState` class should inherit from this to avoid silent failures due to state being shared between
     tests.
     """
 
     def tearDown(self):
         super().tearDown()
-        # Reset the state of the AcceleratorState singleton.
-        AcceleratorState._reset_state()
+        # Reset the state of the PartialState singleton.
+        PartialState._reset_state()
         PartialState._reset_state()
 
 
@@ -472,7 +472,7 @@ class MockingTestCase(unittest.TestCase):
 
 
 def are_the_same_tensors(tensor):
-    state = AcceleratorState()
+    state = PartialState()
     tensor = tensor[None].clone().to(state.device)
     tensors = gather(tensor).cpu()
     tensor = tensor[0].cpu()
@@ -18,7 +18,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch
 
-from ..state import AcceleratorState
+from ..state import PartialState
 from .constants import CUDA_DISTRIBUTED_TYPES
 from .dataclasses import DistributedType, RNGType
 from .imports import is_mlu_available, is_npu_available, is_torch_xla_available, is_xpu_available
@@ -41,7 +41,7 @@ def set_seed(seed: int, device_specific: bool = False, deterministic: bool = Fal
             Whether to use deterministic algorithms where available. Can slow down training.
     """
     if device_specific:
-        seed += AcceleratorState().process_index
+        seed += PartialState().process_index
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
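`set_seed(..., device_specific=True)` offsets the seed by the calling rank so each worker gets a distinct but reproducible stream; only the state object it queries changed here. A quick single-process sketch (process index 0, so the offset is zero):

```python
import torch
from accelerate.utils import set_seed

# With device_specific=True the effective seed is seed + process_index,
# now read from PartialState rather than AcceleratorState.
set_seed(42, device_specific=True)
print(torch.rand(2))  # deterministic for a given rank

set_seed(42, device_specific=True)
print(torch.rand(2))  # same values again on the same rank
```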
@@ -84,7 +84,7 @@ def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optiona
         rng_state = generator.get_state()
 
     # Broadcast the rng state from device 0 to other devices
-    state = AcceleratorState()
+    state = PartialState()
     if state.distributed_type == DistributedType.XLA:
         rng_state = rng_state.to(xm.xla_device())
         xm.collective_broadcast([rng_state])
tests/baseline.py (new file, 76 lines)
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from accelerate import PartialState, Accelerator
from accelerate.test_utils.testing import assert_exception
from accelerate.utils.dataclasses import DistributedType
from accelerate.utils.operations import (
    DistributedOperationException,
    broadcast,
    copy_tensor_to_devices,
    gather,
    gather_object,
    pad_across_processes,
    reduce,
)


def create_tensor(state):
    return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device)


def test_gather(state):
    tensor = create_tensor(state)
    gathered_tensor = gather(tensor)
    assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))


def test_gather_object(state):
    # Gather objects in TorchXLA is not supported.
    if state.distributed_type == DistributedType.XLA:
        return
    obj = [state.process_index]
    gathered_obj = gather_object(obj)
    assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}"
    assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"


def main():
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Initialization**")
    state.wait_for_everyone()

    if state.distributed_type == DistributedType.MULTI_GPU:
        num_processes_per_node = torch.cuda.device_count()
    else:
        num_processes_per_node = state.num_processes

    # We only run this test on non-multinode
    if state.process_index == 0:
        print("\n**Test gather operation**")
    test_gather(state)
    if state.process_index == 0:
        print("\n**Test gather_object operation**")
    test_gather_object(state)


if __name__ == "__main__":
    main()
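In `create_tensor`, rank `p` of `n` produces the values `n*p + 1 … n*p + n`, so gathering across ranks yields exactly `1 … n²`, which is what `test_gather` asserts. A single-process sketch of the two-rank arithmetic (`create_tensor_for` is an illustrative helper, not part of the file; the script itself would presumably be run under `accelerate launch` or `torchrun`):

```python
import torch


def create_tensor_for(num_processes, process_index):
    # Same arithmetic as tests/baseline.py::create_tensor, minus the device move.
    return torch.arange(num_processes) + 1.0 + num_processes * process_index


num_processes = 2
per_rank = [create_tensor_for(num_processes, rank) for rank in range(num_processes)]
print(per_rank)  # [tensor([1., 2.]), tensor([3., 4.])]

gathered = torch.cat(per_rank)  # what gather() would return with 2 real ranks
assert gathered.tolist() == list(range(1, num_processes**2 + 1))
print(gathered)  # tensor([1., 2., 3., 4.])
```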
tests/test_multigpu_new.py (new file, 30 lines)
import os
import torch
import inspect
import unittest

from torch.testing._internal.common_distributed import (
    MultiThreadedTestCase,
)
from torch.testing._internal.common_utils import run_tests

from accelerate import Accelerator, PartialState
from accelerate.test_utils import device_count


class TrainingTester(MultiThreadedTestCase):
    @property
    def world_size(self):
        return device_count

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    # Verify we are running in multiproc
    def test_distributed_spawning(self):
        state = PartialState()
        assert state.local_process_index == torch.distributed.get_rank()
        assert state.num_processes == torch.distributed.get_world_size()


if __name__ == "__main__":
    run_tests()
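`MultiThreadedTestCase` runs each test body once per simulated rank, in `world_size` threads over an in-process group that reports its backend as `threaded`, which is exactly the path the `PartialState` changes above cater to. A hedged sketch of an additional check one could write against that harness (`ThreadedBackendTester` is illustrative and not part of this diff; it relies on PyTorch-internal test helpers):

```python
import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import run_tests


class ThreadedBackendTester(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4  # fixed here; the diff's test derives it from accelerate's device_count

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_harness_provides_threaded_group(self):
        # Each test body runs once per simulated rank, in its own thread,
        # over an in-process group whose backend reports "threaded" --
        # the value the patched _prepare_backend now looks for.
        assert dist.is_initialized()
        assert dist.get_backend() == "threaded"
        assert dist.get_world_size() == self.world_size
        assert 0 <= dist.get_rank() < self.world_size


if __name__ == "__main__":
    run_tests()
```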