Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-14 14:14:32 +08:00)

Compare commits: fork-teste...v0.29.2 (5 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 39e0a8ef59 |  |
|  | 759a9336ce |  |
|  | 210778370e |  |
|  | 12eed81eb8 |  |
|  | ec88c8f54a |  |
setup.py (2 changed lines)

```diff
@@ -47,7 +47,7 @@ extras["sagemaker"] = [
 
 setup(
     name="accelerate",
-    version="0.29.0.dev",
+    version="0.29.2",
     description="Accelerate",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
```
src/accelerate/__init__.py

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.29.0.dev0"
+__version__ = "0.29.2"
 
 from .accelerator import Accelerator
 from .big_modeling import (
```
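Together these first two hunks are just the release bump from the dev version to 0.29.2. A quick post-install sanity check (a minimal sketch using only the public `__version__` attribute):

```python
import accelerate

# The installed package should now report the patched release.
print(accelerate.__version__)  # expected: "0.29.2"
```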
src/accelerate/state.py

```diff
@@ -183,18 +183,7 @@ class PartialState:
         self.backend = backend
         self.distributed_type = distributed_type
         use_deepspeed = False
-        if not cpu:
-            # Deal with XLA
-            if is_torch_xla_available():
-                self.device = xm.xla_device()
-                xm.set_replication(self.device, xm.get_xla_supported_devices())
-                self.num_processes = xm.xrt_world_size()
-                self.process_index = xm.get_ordinal()
-                if is_torch_xla_available(check_is_tpu=True):
-                    self.local_process_index = xm.get_local_ordinal()
-                else:
-                    self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
-                self.distributed_type = DistributedType.XLA
+        if not cpu and self.backend != "xla":
             if int(os.environ.get("LOCAL_RANK", -1)) != -1:
                 # Deal with spawning deepspeed
                 if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
```
```diff
@@ -204,7 +193,7 @@
                     )
                     from deepspeed import comm as dist
 
-                    if is_xpu_available and is_ccl_available():
+                    if is_xpu_available() and is_ccl_available():
                         os.environ["CCL_PROCESS_LAUNCHER"] = "none"
                         os.environ["CCL_LOCAL_SIZE"] = os.environ.get("LOCAL_WORLD_SIZE", "1")
                         os.environ["CCL_LOCAL_RANK"] = os.environ.get("LOCAL_RANK", "0")
```
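The one-character fix above matters: without parentheses, `is_xpu_available` is a function object, which is always truthy, so the branch ran on any machine where `is_ccl_available()` held, regardless of actual XPU support. A minimal standalone sketch of the failure mode:

```python
def is_xpu_available() -> bool:
    return False  # pretend no XPU is present


# Bug: a bare function name is an object, and objects are truthy,
# so the availability check is silently skipped.
if is_xpu_available:
    print("taken even though is_xpu_available() is False")

# Fix: call the function so its boolean result is what gets tested.
if is_xpu_available():
    print("never reached")
```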
```diff
@@ -270,6 +259,16 @@ class PartialState:
             self.num_processes = 1
             self.process_index = 0
             self.local_process_index = 0
+        elif self.backend == "xla":
+            # XLA needs device setting first for `set_replication`
+            self.set_device()
+            xm.set_replication(self.device, xm.get_xla_supported_devices())
+            self.num_processes = xm.xrt_world_size()
+            self.process_index = xm.get_ordinal()
+            if is_torch_xla_available(check_is_tpu=True):
+                self.local_process_index = xm.get_local_ordinal()
+            else:
+                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
         else:
             self.num_processes = torch.distributed.get_world_size()
             self.process_index = torch.distributed.get_rank()
```
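This is where the XLA setup removed from the top of `__init__` (the `@@ -183,18 +183,7 @@` hunk) ends up: it now runs only once the backend has resolved to `"xla"`, and `set_device()` is called first because `xm.set_replication` needs a concrete device. The `xm` helpers come from the optional `torch_xla` dependency, which accelerate imports behind an availability check, roughly like this (a sketch of the pattern, not the module's exact code):

```python
from accelerate.utils import is_torch_xla_available

# torch_xla is optional; import its helper module only when it is installed.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
```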
```diff
@@ -284,16 +283,17 @@ class PartialState:
             # Set CPU affinity if enabled
             if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
                 set_numa_affinity(self.local_process_index)
-        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
 
-        # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
-        if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
-            if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
-                raise NotImplementedError(
-                    "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
-                    'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
-                    "will do this automatically."
-                )
+            # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
+            if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
+                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
+                    raise NotImplementedError(
+                        "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
+                        'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
+                        "will do this automatically."
+                    )
+        # Important: This should be the *only* code outside of `self.initialized!`
+        self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
 
     def __repr__(self) -> str:
         return (
```
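Functionally the RTX 4000 guard is unchanged; the point of the move is that `self.fork_launched` is now set unconditionally, as the new comment says. For anyone hitting the `NotImplementedError`, the workaround the message asks for looks like this (a sketch; the variables must be set before NCCL initializes, and `accelerate launch` does it automatically):

```python
import os

# Disable P2P and InfiniBand transports before torch.distributed / NCCL starts.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

from accelerate import Accelerator

accelerator = Accelerator()
```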
```diff
@@ -715,6 +715,9 @@ class PartialState:
 
             backend = "smddp"
             distributed_type = DistributedType.MULTI_GPU
+        elif is_torch_xla_available():
+            backend = "xla"
+            distributed_type = DistributedType.XLA
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1:
             if not cpu:
                 if is_mlu_available():
```
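This new branch makes the backend resolver pick XLA ahead of the generic `LOCAL_RANK` path, so TPU environments map to `DistributedType.XLA` again. The resolved values can be inspected through the public state attributes (a minimal sketch; outside any launcher this prints the single-process defaults):

```python
from accelerate import PartialState

state = PartialState()
# `backend` is the low-level communication backend ("xla", "nccl", ..., or None)
# and `distributed_type` the mode accelerate derived from the environment.
print(state.backend, state.distributed_type)
```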
```diff
@@ -758,17 +761,20 @@
         """
         if self.device is not None:
             return
-        if self.num_processes == 1:
+        if self.distributed_type == DistributedType.NO:
             self.device = torch.device("cpu") if self._cpu else self.default_device
             return
         device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
-        if device not in ("cpu", "gpu", "mlu", "npu", "xpu"):
+        if device not in ("cpu", "gpu", "mlu", "npu", "xpu", "xla"):
             raise ValueError(
                 f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
             )
-        if device == "gpu":
-            device = "cuda"
-        self.device = torch.device(device, self.local_process_index)
+        if device == "xla":
+            self.device = xm.xla_device()
+        else:
+            if device == "gpu":
+                device = "cuda"
+            self.device = torch.device(device, self.local_process_index)
         if self.device is not None:
             if device == "xpu":
                 torch.xpu.set_device(self.device)
```
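The `device` string is derived from the `DistributedType` enum name, so the membership check must list every name `set_device` can handle; adding `"xla"` keeps TPU runs from tripping the `ValueError`. A worked example of the derivation (standalone sketch):

```python
from accelerate.utils import DistributedType

# str() of the member is "DistributedType.MULTI_GPU" or "MULTI_GPU" depending
# on the Python version; split(".")[-1] copes with both.
device = str(DistributedType.MULTI_GPU).split(".")[-1].replace("MULTI_", "").lower()
print(device)  # "gpu", which set_device then maps to "cuda"

device = str(DistributedType.XLA).split(".")[-1].replace("MULTI_", "").lower()
print(device)  # "xla", newly accepted by the membership check
```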
tests/test_accelerator.py (file name inferred from the surrounding tests)

```diff
@@ -22,7 +22,6 @@ from copy import deepcopy
 from pathlib import Path
 
 import numpy as np
-import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset
 
```
```diff
@@ -711,6 +710,8 @@ def test_trigger():
 
 
 def test_reinstantiated_state():
+    import pytest
+
     AcceleratorState._reset_state()
     simple_model = torch.nn.Linear(1, 1)
     # First define an accelerator
```
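The diff cuts off at the top of `test_reinstantiated_state`; moving `import pytest` inside the function keeps the module importable when the file runs outside a pytest session. A sketch of what a test like this plausibly checks, assuming only the names visible in the hunk plus `Accelerator` and a `pytest.raises` assertion (the truncated body may differ):

```python
import pytest
import torch

from accelerate import Accelerator
from accelerate.state import AcceleratorState


def test_reinstantiated_state():
    AcceleratorState._reset_state()
    simple_model = torch.nn.Linear(1, 1)
    # First define an accelerator, then wipe the shared state out from under it.
    accelerator = Accelerator()
    AcceleratorState._reset_state()
    # Using the stale accelerator should now fail loudly instead of silently.
    with pytest.raises(AttributeError):
        accelerator.prepare(simple_model)
```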