Compare commits

...

2 Commits

7 changed files with 140 additions and 18 deletions

View File

@@ -22,7 +22,7 @@ extras["quality"] = [
"ruff ~= 0.2.1",
]
extras["docs"] = []
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized"]
extras["test_prod"] = ["pytest>=7.2.0,<=8.0.0", "pytest-xdist", "pytest-subtests", "parameterized", "expecttest"]
extras["test_dev"] = [
"datasets",
"diffusers",

View File

@@ -101,6 +101,11 @@ class ThreadLocalSharedDict(threading.local):
def __set__(self, obj, value):
self._storage = value
def _get_shared_dict_type():
# Prefer global shared dictionary, except when using TPU or `backend == threaded`
if is_torch_xla_available() or (torch.distributed.is_initialized() and torch.distributed.get_backend() == "threaded"):
return ThreadLocalSharedDict
return dict
# Prefer global shared dictionary, except when using TPU.
SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
@@ -162,6 +167,9 @@ class PartialState:
]
def __init__(self, cpu: bool = False, **kwargs):
# This is needed when we are launching tests and have the `threaded` backend
if _get_shared_dict_type() != self._shared_state.__class__:
PartialState._shared_state = _get_shared_dict_type()()
self.__dict__ = self._shared_state
if not self.initialized:
self._cpu = cpu
@@ -185,7 +193,7 @@ class PartialState:
self.backend = backend
self.distributed_type = distributed_type
use_deepspeed = False
if not cpu and self.backend != "xla":
if not cpu and self.backend != "xla" and not torch.distributed.is_initialized():
if int(os.environ.get("LOCAL_RANK", -1)) != -1:
# Deal with spawning deepspeed
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
@@ -274,9 +282,13 @@ class PartialState:
else:
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
self.local_process_index = (
int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
)
# Setting `local_process_index` requires some care
if dist_information is not None:
self.local_process_index = dist_information.local_rank
elif backend == "threaded":
self.local_process_index = self.process_index
else:
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
self.set_device()
# Now we can change to deepspeed
if use_deepspeed:
@@ -710,6 +722,10 @@ class PartialState:
) -> tuple[str, DistributedType]:
"Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
distributed_type = None
if torch.distributed.is_initialized():
backend = torch.distributed.get_backend()
if backend == "threaded":
distributed_type = DistributedType.MULTI_GPU
if sagemaker_dp:
import smdistributed.dataparallel.torch.torch_smddp # noqa
@@ -718,7 +734,7 @@ class PartialState:
elif is_torch_xla_available():
backend = "xla"
distributed_type = DistributedType.XLA
elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
elif not cpu and int(os.environ.get("LOCAL_RANK", -1)) != -1:
if is_mlu_available():
backend = "cncl"
distributed_type = DistributedType.MULTI_MLU
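
For context on the _get_shared_dict_type and "threaded"-backend hunks above, a minimal standalone sketch (illustrative, not part of the diff): under PyTorch's MultiThreadedTestCase every rank runs as a thread of a single process, so per-rank singleton state has to live behind threading.local rather than an ordinary module-level dict.

# Minimal sketch, not from the diff: shows why a plain dict cannot hold
# per-rank state once ranks are threads of the same process.
import threading

shared_plain = {}                 # one object shared by the whole process
shared_local = threading.local()  # a separate namespace per thread (per rank)

def worker(rank):
    shared_plain["rank"] = rank   # every thread overwrites the same key
    shared_local.rank = rank      # each thread keeps its own value

threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# shared_plain["rank"] now holds whichever thread wrote last, while each thread
# only ever saw its own shared_local.rank; this is the isolation that
# ThreadLocalSharedDict gives PartialState._shared_state under "threaded".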

View File

@@ -807,10 +807,10 @@ def main():
if state.distributed_type == DistributedType.DEEPSPEED:
return
if state.local_process_index == 0:
print("\n**Training integration test**")
training_check(use_seedable_sampler=False)
training_check(use_seedable_sampler=True)
# if state.local_process_index == 0:
# print("\n**Training integration test**")
# training_check(use_seedable_sampler=False)
# training_check(use_seedable_sampler=True)
if state.local_process_index == 0:
print("\n**Breakpoint trigger test**")

View File

@@ -30,7 +30,7 @@ import torch
import accelerate
from ..state import AcceleratorState, PartialState
from ..state import PartialState
from ..utils import (
gather,
is_bnb_available,
@@ -427,14 +427,14 @@ class TempDirTestCase(unittest.TestCase):
class AccelerateTestCase(unittest.TestCase):
"""
A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes
the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between
the `PartialState` class should inherit from this to avoid silent failures due to state being shared between
tests.
"""
def tearDown(self):
super().tearDown()
# Reset the state of the AcceleratorState singleton.
AcceleratorState._reset_state()
# Reset the state of the PartialState singleton.
PartialState._reset_state()
PartialState._reset_state()
@@ -472,7 +472,7 @@ class MockingTestCase(unittest.TestCase):
def are_the_same_tensors(tensor):
state = AcceleratorState()
state = PartialState()
tensor = tensor[None].clone().to(state.device)
tensors = gather(tensor).cpu()
tensor = tensor[0].cpu()
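
As a usage note for the AccelerateTestCase hunk above (a hedged sketch, not part of the diff; the test class and its body are made up for illustration): tests that touch the singleton state inherit from AccelerateTestCase so that tearDown resets it between tests.

# Hedged usage sketch: AccelerateTestCase.tearDown() resets the singleton,
# so state created in one test does not leak into the next.
from accelerate import PartialState
from accelerate.test_utils.testing import AccelerateTestCase

class ExampleStateTest(AccelerateTestCase):
    def test_cpu_state(self):
        state = PartialState(cpu=True)
        assert state.device.type == "cpu"
        # no manual cleanup needed; the base class resets the state afterwards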

View File

@@ -18,7 +18,7 @@ from typing import List, Optional, Union
import numpy as np
import torch
from ..state import AcceleratorState
from ..state import PartialState
from .constants import CUDA_DISTRIBUTED_TYPES
from .dataclasses import DistributedType, RNGType
from .imports import is_mlu_available, is_npu_available, is_torch_xla_available, is_xpu_available
@@ -41,7 +41,7 @@ def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False
Whether to use deterministic algorithms where available. Can slow down training.
"""
if device_specific:
seed += AcceleratorState().process_index
seed += PartialState().process_index
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@@ -84,7 +84,7 @@ def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional
rng_state = generator.get_state()
# Broadcast the rng state from device 0 to other devices
state = AcceleratorState()
state = PartialState()
if state.distributed_type == DistributedType.XLA:
rng_state = rng_state.to(xm.xla_device())
xm.collective_broadcast([rng_state])
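
For the seeding hunk above, a brief usage sketch (illustrative, not part of the diff): with device_specific=True the seed is offset by the calling process's index, which the switch to PartialState makes available even before a full AcceleratorState exists.

# Illustrative sketch, not from the diff: rank 0 effectively seeds with 42,
# rank 1 with 43, and so on, so each process draws different random numbers.
from accelerate import PartialState
from accelerate.utils import set_seed

PartialState()                      # makes process_index available
set_seed(42, device_specific=True)  # seed + process_index on each rank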

tests/baseline.py Normal file
View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from accelerate import PartialState, Accelerator
from accelerate.test_utils.testing import assert_exception
from accelerate.utils.dataclasses import DistributedType
from accelerate.utils.operations import (
DistributedOperationException,
broadcast,
copy_tensor_to_devices,
gather,
gather_object,
pad_across_processes,
reduce,
)
def create_tensor(state):
return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device)
def test_gather(state):
tensor = create_tensor(state)
gathered_tensor = gather(tensor)
assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))
def test_gather_object(state):
# Gathering objects is not supported in TorchXLA.
if state.distributed_type == DistributedType.XLA:
return
obj = [state.process_index]
gathered_obj = gather_object(obj)
assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}"
assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
def main():
accelerator = Accelerator()
state = accelerator.state
if state.local_process_index == 0:
print("**Initialization**")
state.wait_for_everyone()
if state.distributed_type == DistributedType.MULTI_GPU:
num_processes_per_node = torch.cuda.device_count()
else:
num_processes_per_node = state.num_processes
# We only run this test on non-multinode
if state.process_index == 0:
print("\n**Test gather operation**")
test_gather(state)
if state.process_index == 0:
print("\n**Test gather_object operation**")
test_gather_object(state)
if __name__ == "__main__":
main()
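
A quick worked check of the gather assertion above (local sketch, not part of the diff; no distributed launch needed): with two processes, rank 0 builds [1., 2.] and rank 1 builds [3., 4.], so the gathered tensor is [1., 2., 3., 4.], i.e. range(1, num_processes**2 + 1).

# Worked example, not from the diff: reproduces what each rank contributes
# to test_gather when num_processes == 2.
import torch

num_processes = 2
for process_index in range(num_processes):
    local = torch.arange(num_processes) + 1.0 + num_processes * process_index
    print(process_index, local.tolist())
# 0 [1.0, 2.0]
# 1 [3.0, 4.0]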

View File

@@ -0,0 +1,30 @@
import os
import torch
import inspect
import unittest
from torch.testing._internal.common_distributed import (
MultiThreadedTestCase,
)
from torch.testing._internal.common_utils import run_tests
from accelerate import Accelerator, PartialState
from accelerate.test_utils import device_count
class TrainingTester(MultiThreadedTestCase):
@property
def world_size(self):
return device_count
def setUp(self):
super().setUp()
self._spawn_threads()
# Verify we are running in multiproc
def test_distributed_spawning(self):
state = PartialState()
assert state.local_process_index == torch.distributed.get_rank()
assert state.num_processes == torch.distributed.get_world_size()
if __name__ == "__main__":
run_tests()
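
A closing note on the new threaded test module (hedged, not stated explicitly in the diff): because it ends with run_tests(), the file can be executed directly with python, and MultiThreadedTestCase then spawns world_size threads that each register as a rank of the "threaded" process group. That is the configuration the PartialState changes above account for: the shared state moves into a thread-local dict, the already-initialized process group is reused rather than re-initialized, and local_process_index falls back to the rank since no per-thread LOCAL_RANK environment variable exists.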