Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-14 14:14:32 +08:00)

Compare commits: v1.1.0...fork-teste (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | d9a7459981 |  |
|  | 70d31ba566 |  |
setup.py (2 changes)
@@ -22,7 +22,7 @@ extras["quality"] = [
     "ruff ~= 0.2.1",
 ]
 extras["docs"] = []
-extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
+extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized", "expecttest"]
 extras["test_dev"] = [
     "datasets",
     "diffusers",
@@ -101,6 +101,11 @@ class ThreadLocalSharedDict(threading.local):
     def __set__(self, obj, value):
         self._storage = value

+def _get_shared_dict_type():
+    # Prefer global shared dictionary, except when using TPU or `backend == threaded`
+    if is_torch_xla_available() or (torch.distributed.is_initialized() and torch.distributed.get_backend() == "threaded"):
+        return ThreadLocalSharedDict
+    return dict

 # Prefer global shared dictionary, except when using TPU.
 SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
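Why this helper exists: `PartialState` keeps its singleton state in a class-level dict, and under torch's in-process "threaded" backend every simulated rank is a thread in the same interpreter, so a plain module-level dict would be shared (and clobbered) across ranks, while a `threading.local`-backed dict gives each rank its own storage. A minimal standalone sketch of that distinction (illustration only, not accelerate's implementation):

import threading

class ThreadLocalDict(threading.local):
    # threading.local subclasses re-run __init__ on first access in each thread,
    # so every thread sees its own fresh `data` dict
    def __init__(self):
        self.data = {}

shared = {}                      # one dict for every thread: ranks overwrite each other
per_thread = ThreadLocalDict()   # one dict per thread: each simulated rank stays isolated

def worker(rank):
    shared["rank"] = rank            # last writer wins, visible to all threads
    per_thread.data["rank"] = rank   # visible only to this thread

threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()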
@@ -162,6 +167,9 @@ class PartialState:
     ]

     def __init__(self, cpu: bool = False, **kwargs):
+        # This is needed when we are launching tests and have the `threaded` backend
+        if _get_shared_dict_type() != self._shared_state.__class__:
+            PartialState._shared_state = _get_shared_dict_type()()
         self.__dict__ = self._shared_state
         if not self.initialized:
             self._cpu = cpu
@@ -185,7 +193,7 @@ class PartialState:
             self.backend = backend
             self.distributed_type = distributed_type
             use_deepspeed = False
-            if not cpu and self.backend != "xla":
+            if not cpu and self.backend != "xla" and not torch.distributed.is_initialized():
                 if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                     # Deal with spawning deepspeed
                     if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
@@ -274,9 +282,13 @@ class PartialState:
             else:
                 self.num_processes = torch.distributed.get_world_size()
                 self.process_index = torch.distributed.get_rank()
-                self.local_process_index = (
-                    int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
-                )
+                # Setting `local_process_index` requires some care
+                if dist_information is not None:
+                    self.local_process_index = dist_information.local_rank
+                elif backend == "threaded":
+                    self.local_process_index = self.process_index
+                else:
+                    self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
             self.set_device()
             # Now we can change to deepseed
             if use_deepspeed:
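Note on the new fallback order: the threaded backend sets no `LOCAL_RANK` environment variable and provides no `dist_information`, so the old one-liner would have left `local_process_index` at -1 for every thread-rank; on a single node the global `process_index` is the natural substitute. A hypothetical helper restating the branch order (names are illustrative, not from the diff):

import os

def resolve_local_rank(dist_information, backend, process_index):
    # launcher-provided info wins (e.g. torchrun via accelerate launch)
    if dist_information is not None:
        return dist_information.local_rank
    # the in-process "threaded" backend has no env vars; on one node global rank == local rank
    if backend == "threaded":
        return process_index
    # classic env-var launchers (torchrun, deepspeed, ...)
    return int(os.environ.get("LOCAL_RANK", -1))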
@@ -710,6 +722,10 @@ class PartialState:
     ) -> tuple[str, DistributedType]:
         "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
         distributed_type = None
+        if torch.distributed.is_initialized():
+            backend = torch.distributed.get_backend()
+            if backend == "threaded":
+                distributed_type = DistributedType.MULTI_GPU
         if sagemaker_dp:
             import smdistributed.dataparallel.torch.torch_smddp  # noqa

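For context: `MultiThreadedTestCase` initializes its own "threaded" process group before the test body runs, so `PartialState` has to detect an existing group and reuse its backend rather than initializing a second one; the threaded group is then treated like `MULTI_GPU`. A minimal sketch of that detection pattern (standalone, plain `torch.distributed`, not accelerate's code):

import torch.distributed as dist

def detect_existing_backend(default_backend: str = "nccl"):
    # Reuse the backend of an already-initialized process group (e.g. the
    # "threaded" group created by a test harness) instead of creating a new one.
    if dist.is_initialized():
        return dist.get_backend(), True   # (backend, already_initialized)
    return default_backend, False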
@@ -718,7 +734,7 @@ class PartialState:
         elif is_torch_xla_available():
             backend = "xla"
             distributed_type = DistributedType.XLA
-        elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
+        elif not cpu and int(os.environ.get("LOCAL_RANK", -1)) != -1:
             if is_mlu_available():
                 backend = "cncl"
                 distributed_type = DistributedType.MULTI_MLU
@@ -807,10 +807,10 @@ def main():
     if state.distributed_type == DistributedType.DEEPSPEED:
         return

-    if state.local_process_index == 0:
-        print("\n**Training integration test**")
-    training_check(use_seedable_sampler=False)
-    training_check(use_seedable_sampler=True)
+    # if state.local_process_index == 0:
+    #     print("\n**Training integration test**")
+    # training_check(use_seedable_sampler=False)
+    # training_check(use_seedable_sampler=True)

     if state.local_process_index == 0:
         print("\n**Breakpoint trigger test**")
@@ -30,7 +30,7 @@ import torch

 import accelerate

-from ..state import AcceleratorState, PartialState
+from ..state import PartialState, PartialState
 from ..utils import (
     gather,
     is_bnb_available,
@@ -427,14 +427,14 @@ class TempDirTestCase(unittest.TestCase):
 class AccelerateTestCase(unittest.TestCase):
     """
     A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes
-    the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between
+    the `PartialState` class should inherit from this to avoid silent failures due to state being shared between
     tests.
     """

     def tearDown(self):
         super().tearDown()
-        # Reset the state of the AcceleratorState singleton.
-        AcceleratorState._reset_state()
+        # Reset the state of the PartialState singleton.
+        PartialState._reset_state()
         PartialState._reset_state()


@@ -472,7 +472,7 @@ class MockingTestCase(unittest.TestCase):


 def are_the_same_tensors(tensor):
-    state = AcceleratorState()
+    state = PartialState()
     tensor = tensor[None].clone().to(state.device)
     tensors = gather(tensor).cpu()
     tensor = tensor[0].cpu()
@@ -18,7 +18,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch

-from ..state import AcceleratorState
+from ..state import PartialState
 from .constants import CUDA_DISTRIBUTED_TYPES
 from .dataclasses import DistributedType, RNGType
 from .imports import is_mlu_available, is_npu_available, is_torch_xla_available, is_xpu_available
@@ -41,7 +41,7 @@ def set_seed(seed: int, device_specific: bool = False, deterministic: bool = Fal
             Whether to use deterministic algorithms where available. Can slow down training.
     """
     if device_specific:
-        seed += AcceleratorState().process_index
+        seed += PartialState().process_index
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
@@ -84,7 +84,7 @@ def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optiona
         rng_state = generator.get_state()

     # Broadcast the rng state from device 0 to other devices
-    state = AcceleratorState()
+    state = PartialState()
     if state.distributed_type == DistributedType.XLA:
         rng_state = rng_state.to(xm.xla_device())
         xm.collective_broadcast([rng_state])
tests/baseline.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from accelerate import PartialState, Accelerator
+from accelerate.test_utils.testing import assert_exception
+from accelerate.utils.dataclasses import DistributedType
+from accelerate.utils.operations import (
+    DistributedOperationException,
+    broadcast,
+    copy_tensor_to_devices,
+    gather,
+    gather_object,
+    pad_across_processes,
+    reduce,
+)
+
+
+def create_tensor(state):
+    return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device)
+
+
+def test_gather(state):
+    tensor = create_tensor(state)
+    gathered_tensor = gather(tensor)
+    assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))
+
+
+def test_gather_object(state):
+    # Gather objects in TorchXLA is not supported.
+    if state.distributed_type == DistributedType.XLA:
+        return
+    obj = [state.process_index]
+    gathered_obj = gather_object(obj)
+    assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}"
+    assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
+
+
+def main():
+    accelerator = Accelerator()
+    state = accelerator.state
+    if state.local_process_index == 0:
+        print("**Initialization**")
+    state.wait_for_everyone()
+
+    if state.distributed_type == DistributedType.MULTI_GPU:
+        num_processes_per_node = torch.cuda.device_count()
+    else:
+        num_processes_per_node = state.num_processes
+
+    # We only run this test on non-multinode
+    if state.process_index == 0:
+        print("\n**Test gather operation**")
+    test_gather(state)
+    if state.process_index == 0:
+        print("\n**Test gather_object operation**")
+    test_gather_object(state)
+
+
+
+if __name__ == "__main__":
+    main()
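A worked case makes the `test_gather` assertion concrete: with 2 processes, rank 0 builds tensor([1., 2.]) and rank 1 builds tensor([3., 4.]), so the gathered result is [1, 2, 3, 4], i.e. range(1, num_processes**2 + 1). A quick single-process check of that arithmetic (plain torch, no distributed setup needed):

import torch

num_processes = 2
expected = list(range(1, num_processes**2 + 1))  # [1, 2, 3, 4]

# what each rank contributes in create_tensor(state)
chunks = [
    torch.arange(num_processes) + 1.0 + num_processes * rank
    for rank in range(num_processes)
]
assert torch.cat(chunks).tolist() == expected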
tests/test_multigpu_new.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import os
+import torch
+import inspect
+import unittest
+
+from torch.testing._internal.common_distributed import (
+    MultiThreadedTestCase,
+)
+from torch.testing._internal.common_utils import run_tests
+
+from accelerate import Accelerator, PartialState
+from accelerate.test_utils import device_count
+
+class TrainingTester(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return device_count
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    # Verify we are running in multiproc
+    def test_distributed_spawning(self):
+        state = PartialState()
+        assert state.local_process_index == torch.distributed.get_rank()
+        assert state.num_processes == torch.distributed.get_world_size()
+
+if __name__ == "__main__":
+    run_tests()
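For context, `MultiThreadedTestCase` (a private PyTorch testing utility) runs each test once per simulated rank by spawning `world_size` threads that share one interpreter and communicate over a "threaded" process group, which is exactly why the thread-local shared-dict changes above are needed. A rough sketch of such a test, assuming the internal API keeps the interface used in this diff:

import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import run_tests


class ExampleThreadedTest(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2  # two simulated ranks, two threads

    def setUp(self):
        super().setUp()
        self._spawn_threads()  # each test body then runs once per rank

    def test_ranks_see_the_group(self):
        # inside the test body the "threaded" process group is already initialized
        assert dist.is_initialized()
        assert dist.get_world_size() == self.world_size


if __name__ == "__main__":
    run_tests()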