Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Fix decorators skipping NCCL tests (#158846)"
This reverts commit c2388201fc85b0748173212de5a17514c7a71f21. Reverted https://github.com/pytorch/pytorch/pull/158846 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing some inductor tests ([comment](https://github.com/pytorch/pytorch/pull/158846#issuecomment-3276471387))
@@ -1,7 +1,7 @@
 # Owner(s): ["module: fsdp"]
 import functools
 import os
-import unittest
+import unittest.mock

 import torch.distributed as dist
 from torch._dynamo.test_case import run_tests
@@ -37,9 +37,9 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
-logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
+logger = logging.getLogger("torch.distributed._composable.fsdp")
 logger.setLevel(logging.DEBUG)
-device = '{device_type.type}'
+device = {device_type.type}
 torch.manual_seed(0)
 model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
 for layer in model:
@@ -13,7 +13,6 @@ from functorch import make_fx
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_distributed import exit_if_lt_x_accelerators
 from torch.testing._internal.inductor_utils import HAS_GPU


@@ -25,7 +24,7 @@ from torch.testing._internal.common_distributed import (
     DistributedTestBase,
     MultiThreadedTestCase,
     requires_accelerator_dist_backend,
-    skip_if_no_gpu,
+    TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -480,14 +479,25 @@ elif TEST_XPU:
     BACKEND = dist.Backend.XCCL


+# allows you to check for multiple accelerator irrespective of device type
+# to add new device types to this check simply follow the same format
+# and append an elif with the conditional and appropriate device count function for your new device
+def exit_if_lt_x_accelerators(x):
+    if torch.accelerator.is_available():
+        if torch.accelerator.device_count() < x:
+            sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
+
+
 def with_comms(func=None):
     if func is None:
         return partial(with_comms)

     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if BACKEND in (dist.Backend.NCCL, dist.Backend.XCCL):
-            exit_if_lt_x_accelerators(self.world_size)
+        if (
+            BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
+        ) and torch.accelerator.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

         kwargs["device"] = DEVICE
         self.pg = self.create_pg(device=DEVICE)
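The helper restored above, and the inlined check in with_comms, skip by exiting the worker process with a sentinel code taken from TEST_SKIPS rather than by marking the test as skipped up front. A minimal, self-contained sketch of that pattern follows; it is illustrative only (the names and the exit code 80 are made up, not PyTorch internals):

# Standalone sketch: a parent process translating a worker's sentinel exit code into a skip.
import subprocess
import sys
from collections import namedtuple

TestSkip = namedtuple("TestSkip", ["exit_code", "message"])
SKIPS = {"multi-gpu-2": TestSkip(80, "need at least 2 GPUs")}  # hypothetical table


def run_worker() -> None:
    # The worker decides at run time that it cannot run and exits with the sentinel code.
    worker = subprocess.run([sys.executable, "-c", "import sys; sys.exit(80)"])
    for skip in SKIPS.values():
        if worker.returncode == skip.exit_code:
            print(f"SKIPPED: {skip.message}")  # a real harness would raise unittest.SkipTest here
            return
    print("worker finished normally")


if __name__ == "__main__":
    run_worker()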
@@ -500,9 +510,9 @@ def with_comms(func=None):


 class TestCollectivesWithDistributedBackend(DistributedTestBase):
-    @skip_if_no_gpu
     @with_comms()
     def test_all_gather_into_tensor_coalesced(self, device):
+        exit_if_lt_x_accelerators(self.world_size)
         tensors = [
             torch.ones([4], device=device),
             torch.ones([4], device=device) + 1,
@@ -574,8 +584,9 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
         compiled_allreduce(torch.randn(8, device=device), self.pg)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_no_gpu
     def test_tracing_with_fakepg(self, device=DEVICE):
+        exit_if_lt_x_accelerators(self.world_size)
+
         def allreduce(t, pg):
             return ft_c.all_reduce(t, "sum", pg)

@@ -616,9 +627,9 @@ class TestDistributedBackendCollectivesWithWorldSize4(
     def world_size(self):
         return 4

-    @skip_if_no_gpu
     @with_comms()
     def test_permute_tensor_with_sub_group(self, device):
+        exit_if_lt_x_accelerators(self.world_size)
         mesh_dim_names = ["dp", "tp"]

         mesh_2d = dt.init_device_mesh(
@@ -118,26 +118,14 @@ def requires_ddp_rank(device):
     return device in DDP_RANK_DEVICES


-def exit_if_lt_x_cuda_devs(x):
-    """Exit process unless at least the given number of CUDA devices are available"""
-    if torch.cuda.device_count() < x:
-        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
-
-
-# allows you to check for multiple accelerator irrespective of device type
-# to add new device types to this check simply follow the same format
-# and append an elif with the conditional and appropriate device count function for your new device
-def exit_if_lt_x_accelerators(x):
-    if torch.accelerator.device_count() < x:
-        sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
-
-
 def skip_if_no_gpu(func):
     """Skips if the world size exceeds the number of GPUs, ensuring that if the
     test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""

     @wraps(func)
     def wrapper(*args, **kwargs):
+        if not (TEST_CUDA or TEST_HPU or TEST_XPU):
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
         world_size = int(os.environ["WORLD_SIZE"])
         if TEST_CUDA and torch.cuda.device_count() < world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
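The heart of the revert is in this hunk and the ones around it: #158846 had expressed these guards as unittest.skipUnless decorators evaluated in the main process, while the restored code performs the check inside the spawned test process and calls sys.exit with a TEST_SKIPS code. A minimal, self-contained sketch of the two styles, with illustrative names only (these are not the PyTorch helpers themselves):

# Contrast of decoration-time skipping vs. run-time exit; standard library only,
# and the exit code 80 is an arbitrary placeholder for a TEST_SKIPS-style sentinel.
import sys
import unittest
from functools import wraps

SENTINEL_EXIT_CODE = 80


def skip_at_decoration_time(have_enough_devices: bool):
    # Style removed by the revert: the condition is evaluated once, when the
    # decorator is applied, and unittest reports the skip in the main process.
    return unittest.skipUnless(have_enough_devices, "not enough devices")


def exit_at_run_time(have_enough_devices: bool):
    # Style restored by the revert: the check runs inside each worker process,
    # which exits with a sentinel code that the parent maps back to a skip.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not have_enough_devices:
                sys.exit(SENTINEL_EXIT_CODE)
            return func(*args, **kwargs)

        return wrapper

    return decorator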
@@ -148,9 +136,7 @@ def skip_if_no_gpu(func):

         return func(*args, **kwargs)

-    return unittest.skipUnless(
-        TEST_CUDA or TEST_HPU or TEST_XPU, TEST_SKIPS["no_cuda"].message
-    )(wrapper)
+    return wrapper


 # TODO (kwen2501): what is the purpose of this decorator? Tests with this
@@ -182,16 +168,33 @@ def skip_if_odd_worldsize(func):


 def require_n_gpus_for_nccl_backend(n, backend):
-    return skip_if_lt_x_gpu(n) if backend == "nccl" else unittest.skipIf(False, None)
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend == "nccl" and torch.cuda.device_count() < n:
+                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator


 def import_transformers_or_skip():
-    try:
-        from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
-
-        return unittest.skipIf(False)
-    except ImportError:
-        return unittest.skip(TEST_SKIPS["importerror"].message)
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
+
+                return func(*args, **kwargs)
+            except ImportError:
+                sys.exit(TEST_SKIPS["importerror"].exit_code)
+
+        return wrapper
+
+    return decorator


 def at_least_x_gpu(x):
@@ -205,7 +208,36 @@ def at_least_x_gpu(x):


 def skip_if_lt_x_gpu(x):
-    return unittest.skipUnless(at_least_x_gpu(x), TEST_SKIPS[f"multi-gpu-{x}"].message)
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_HPU and torch.hpu.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_XPU and torch.xpu.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+# This decorator helps avoiding initializing cuda while testing other backends
+def nccl_skip_if_lt_x_gpu(backend, x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend != "nccl":
+                return func(*args, **kwargs)
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+        return wrapper
+
+    return decorator


 def verify_ddp_error_logged(model_DDP, err_substr):
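As a usage note, here is a hedged sketch of how the two decorators restored above are applied, assuming the reverted definitions shown in this hunk; the functions below are illustrative, not existing PyTorch tests:

# Illustrative only: applying the restored decorators to standalone callables.
# In PyTorch they decorate methods of multiprocess test cases, where each test
# runs in a spawned child process and the parent maps the sentinel exit code to
# a skip; in an ordinary process the wrapper would simply sys.exit.
import torch
from torch.testing._internal.common_distributed import (
    nccl_skip_if_lt_x_gpu,
    skip_if_lt_x_gpu,
)


@skip_if_lt_x_gpu(2)
def needs_two_devices() -> int:
    # Reached only if at least 2 CUDA/HPU/XPU devices are visible; otherwise the
    # wrapper exits with the "multi-gpu-2" sentinel before this body runs.
    return torch.cuda.device_count()


@nccl_skip_if_lt_x_gpu("gloo", 2)
def backend_aware() -> str:
    # The device-count guard is enforced only when the backend passed to the
    # decorator is "nccl"; here it is "gloo", so the body always runs.
    return "ran without touching the CUDA runtime"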
@@ -392,7 +424,14 @@ def requires_multicast_support():
 def skip_if_rocm_multiprocess(func):
     """Skips a test for ROCm"""
     func.skip_if_rocm_multiprocess = True
-    return unittest.skipUnless(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_ROCM:
+            return func(*args, **kwargs)
+        sys.exit(TEST_SKIPS["skipIfRocm"].exit_code)
+
+    return wrapper


 def skip_if_win32():
@@ -7,9 +7,8 @@ import torch
 import torch.distributed as dist
 from torch.distributed import rpc
 from torch.testing._internal.common_distributed import (
-    exit_if_lt_x_cuda_devs,
     MultiProcessTestCase,
-    require_n_gpus_for_nccl_backend,
+    TEST_SKIPS,
     tp_transports,
 )

@ -95,10 +94,10 @@ def with_comms(func=None, init_rpc=True, backend="nccl"):
|
|||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
if backend == "nccl":
|
if backend == "nccl" and torch.cuda.device_count() < self.world_size:
|
||||||
exit_if_lt_x_cuda_devs(self.world_size)
|
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||||
self.init_comms(init_rpc=init_rpc, backend=backend)
|
self.init_comms(init_rpc=init_rpc, backend=backend)
|
||||||
func(self, *args, **kwargs)
|
func(self, *args, **kwargs)
|
||||||
self.destroy_comms(destroy_rpc=init_rpc)
|
self.destroy_comms(destroy_rpc=init_rpc)
|
||||||
|
|
||||||
return require_n_gpus_for_nccl_backend(1, backend)(wrapper)
|
return wrapper
|
||||||
|
@@ -3,6 +3,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates

 import itertools
+import sys
 from collections.abc import Iterator, Sequence
 from dataclasses import dataclass
 from functools import partial, wraps
@@ -30,12 +31,12 @@ from torch.distributed.tensor.parallel import (
     SequenceParallel,
 )
 from torch.testing._internal.common_distributed import (
-    exit_if_lt_x_cuda_devs,
     MultiProcContinuousTest,
     MultiProcessTestCase,
     MultiThreadedTestCase,
     run_subtests,
     skip_if_lt_x_gpu,
+    TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
@@ -373,8 +374,8 @@ class DTensorTestBase(MultiProcessTestCase):
         return init_device_mesh(self.device_type, (self.world_size,))

     def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
-        if "nccl" in self.backend:
-            exit_if_lt_x_cuda_devs(self.world_size)
+        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

         if backend is None:
             backend = self.backend
@@ -59,10 +59,10 @@ from torch.testing._internal.common_distributed import (
     captured_output,
     cleanup_temp_dir,
     DistTestCases,
-    exit_if_lt_x_cuda_devs,
     init_multigpu_helper,
     initialize_temp_directories,
     MultiProcessTestCase,
+    nccl_skip_if_lt_x_gpu,
     require_n_gpus_for_nccl_backend,
     requires_nccl_version,
     simple_sparse_reduce_tests,
@@ -609,8 +609,10 @@ class TestDistBackend(MultiProcessTestCase):
         self.rank = rank
         self.file_name = file_name

-        if torch.cuda.is_available():
-            exit_if_lt_x_cuda_devs(int(self.world_size))
+        if torch.cuda.is_available() and torch.cuda.device_count() < int(
+            self.world_size
+        ):
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
         try:
             pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
             timeout = timedelta(seconds=pg_timeout_seconds)
@@ -5342,7 +5344,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @require_n_gpus_for_nccl_backend(2, BACKEND)
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
         def test_accumulate_gradients_no_sync(self):
             """
             Runs _test_accumulate_gradients_no_sync using default inputs
@@ -5353,7 +5355,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @require_n_gpus_for_nccl_backend(2, BACKEND)
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
         def test_accumulate_gradients_no_sync_grad_is_view(self):
             """
             Runs _test_accumulate_gradients_no_sync using default inputs
@@ -5364,7 +5366,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @require_n_gpus_for_nccl_backend(2, BACKEND)
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
         def test_accumulate_gradients_no_sync_allreduce_hook(self):
             """
             Runs multiple iterations on _test_accumulate_gradients_no_sync
@@ -5392,7 +5394,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @require_n_gpus_for_nccl_backend(2, BACKEND)
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
         def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
             """
             Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce
@@ -5426,7 +5428,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @require_n_gpus_for_nccl_backend(2, BACKEND)
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
         def test_get_future(self):
             def mult(fut):
                 return [t * 3 for t in fut.wait()]