add device generalisation support for distributed tests (#152471)
### MOTIVATION

Generalize the distributed test cases so that they can run on non-CUDA devices.

### CHANGES

- test/distributed/optim/test_zero_redundancy_optimizer.py
- test/distributed/test_c10d_logger.py
- test/distributed/test_compute_comm_reordering.py

  Replaced hard-coded device names with `get_devtype` from `torch.testing._internal.common_fsdp`. `DistributedTestBase` is used instead of `MultiProcessTestCase` to make use of its helper functions.

- torch/testing/_internal/common_distributed.py

  Extended the common utility functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152471
Approved by: https://github.com/d4l3k
Committed by: PyTorch MergeBot
Parent: 0aed855b2b
Commit: e1f28fe17b
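The pattern applied across the three test files can be summarized with a minimal sketch (illustrative only, not part of this commit): the device type is resolved once via `get_devtype`, and `DistributedTestBase.create_pg` picks a matching backend, so the test body contains no hard-coded `"cuda"` or NCCL references. The test class and tensor values below are invented for illustration.

```python
import torch
import torch.distributed as dist

from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

# e.g. "cuda", "hpu", or "xpu", depending on the available accelerator
device_type = str(get_devtype())


class DeviceAgnosticAllReduceTest(DistributedTestBase):
    @property
    def world_size(self):
        return 2

    def test_all_reduce(self):
        # create_pg selects nccl/xccl/hccl for accelerators and gloo for CPU,
        # so no backend name appears in the test itself.
        self.create_pg(device_type)
        t = torch.ones(2, 2, device=device_type) * (self.rank + 1)
        dist.all_reduce(t)
        self.assertEqual(t[0, 0].item(), 3.0)  # ranks contribute 1 + 2


if __name__ == "__main__":
    run_tests()
```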
@ -6,7 +6,6 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, cast
|
||||
@ -30,12 +29,23 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import AdamW, SGD
|
||||
from torch.testing._internal import common_distributed
|
||||
from torch.testing._internal.common_distributed import (
|
||||
DistributedTestBase,
|
||||
logger,
|
||||
requires_accelerator_dist_backend,
|
||||
requires_ddp_rank,
|
||||
requires_gloo,
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_no_gpu,
|
||||
skip_if_rocm_multiprocess,
|
||||
skip_if_win32,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skipIfHpu,
|
||||
)
|
||||
|
||||
|
||||
@ -47,63 +57,24 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
|
||||
|
||||
# Use GLOO on GPU when running CUDA + Windows
|
||||
def _get_backend_for_tests():
|
||||
return (
|
||||
dist.Backend.NCCL
|
||||
if not IS_WINDOWS and torch.cuda.is_available()
|
||||
# Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
|
||||
# no GPUs are available.
|
||||
else dist.Backend.GLOO
|
||||
)
|
||||
device_type = str(get_devtype())
|
||||
|
||||
|
||||
BACKEND = _get_backend_for_tests()
|
||||
|
||||
|
||||
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["WORLD_SIZE"] = str(self.world_size)
|
||||
self._spawn_processes()
|
||||
|
||||
class TestZeroRedundancyOptimizer(DistributedTestBase):
|
||||
@property
|
||||
def device(self):
|
||||
return (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
return device_type
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 1
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
torch.distributed.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def dist_init(self, rank, world_size=-1, backend=BACKEND):
|
||||
if world_size < 1:
|
||||
world_size = self.world_size
|
||||
store = dist.FileStore(self.file_name, world_size)
|
||||
return dist.init_process_group(
|
||||
backend=backend,
|
||||
store=store,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
|
||||
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
def test_state_dict(self):
|
||||
"""Check that ZeroRedundancyOptimizer exposes the expected state dict
|
||||
interface, irrespective of the sharding."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR1 = 0.1
|
||||
LR2 = 0.01
|
||||
MOMENTUM = 0.9
|
||||
@ -171,7 +142,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
def test_lr_scheduler(self):
|
||||
"""Check that a normal PyTorch ``lr_scheduler`` is usable with
|
||||
ZeroRedundancyOptimizer."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
NUM_ITERS = 5
|
||||
LR = 0.01
|
||||
x = torch.tensor([1.0], device=self.device, requires_grad=True)
|
||||
@ -193,7 +164,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
|
||||
def test_step_with_kwargs(self):
|
||||
"""Check that the ``step(**kwargs)`` interface is properly exposed."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.1
|
||||
|
||||
class SGDWithStepKWArg(torch.optim.SGD):
|
||||
@ -217,7 +188,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
"""Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
|
||||
extra keys to ``param_groups`` exposes those keys through ZeRO's own
|
||||
``param_groups``."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.1
|
||||
|
||||
class SGDWithNewKey(torch.optim.SGD):
|
||||
@ -236,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
def test_step_without_closure(self):
|
||||
"""Check that the ``step()`` method (without closure) is handled as
|
||||
expected."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.1
|
||||
|
||||
class SGDWithoutClosure(torch.optim.SGD):
|
||||
@ -255,7 +226,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
|
||||
def test_zero_grad(self):
|
||||
"""Check that the ``zero_grad`` method is properly handled."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.01
|
||||
x = torch.rand(1)
|
||||
m = torch.nn.Linear(1, 1)
|
||||
@ -271,7 +242,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
def test_constructor(self):
|
||||
"""Check the robustness of the ZeroRedundancyOptimizer constructor by
|
||||
passing different values for the ``params`` argument."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.01
|
||||
m = torch.nn.Sequential(
|
||||
torch.nn.Linear(5, 10),
|
||||
@ -336,7 +307,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
NOTE: This test should be removed once support for sparse parameters
|
||||
and varying parameter types is added.
|
||||
"""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.01
|
||||
inputs = [
|
||||
[torch.sparse_coo_tensor(size=(2, 3))],
|
||||
@ -353,25 +324,16 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
||||
|
||||
|
||||
class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
@property
|
||||
def device(self):
|
||||
return (
|
||||
torch.device(self.rank)
|
||||
if torch.cuda.is_available()
|
||||
else torch.device("cpu")
|
||||
)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return min(4, max(2, torch.cuda.device_count()))
|
||||
return min(4, max(2, torch.get_device_module(self.device).device_count()))
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
return (
|
||||
nullcontext()
|
||||
if not torch.cuda.is_available()
|
||||
else torch.cuda.device(self.rank)
|
||||
)
|
||||
if requires_ddp_rank(self.device):
|
||||
return torch.get_device_module(self.device).device(self.rank)
|
||||
else:
|
||||
return nullcontext()
|
||||
|
||||
def _check_same_model_params(
|
||||
self,
|
||||
@ -396,12 +358,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message,
|
||||
)
|
||||
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@common_distributed.skip_if_rocm_multiprocess
|
||||
@skip_if_no_gpu
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_step(self):
|
||||
"""Check that ZeroRedundancyOptimizer properly exposes the ``step()``
|
||||
interface."""
|
||||
self.dist_init(self.rank, world_size=self.world_size)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.01
|
||||
|
||||
with self.context:
|
||||
@ -436,13 +398,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
self.assertEqual(m.weight, m_zero.weight)
|
||||
self.assertEqual(m.bias, m_zero.bias)
|
||||
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@common_distributed.skip_if_rocm_multiprocess
|
||||
@skip_if_no_gpu
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_step_with_closure(self):
|
||||
"""Check that ZeroRedundancyOptimizer properly exposes the
|
||||
``step(closure)`` interface."""
|
||||
self.dist_init(self.rank, world_size=self.world_size)
|
||||
|
||||
self.create_pg(self.device)
|
||||
with self.context:
|
||||
for bucket_view in [False, True]:
|
||||
x_val = self.rank + 1
|
||||
@ -487,11 +448,11 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
self.assertEqual(m.weight, torch.tensor([[1.1]]))
|
||||
self.assertEqual(m.bias, torch.tensor([2.1]))
|
||||
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@skip_if_no_gpu
|
||||
def test_lr_scheduler(self):
|
||||
"""Check that a normal PyTorch ``lr_scheduler`` is usable with
|
||||
ZeroRedundancyOptimizer."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
x = torch.tensor([1.0], device=self.device, requires_grad=True)
|
||||
x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
|
||||
o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
|
||||
@ -519,7 +480,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
``ZeroRedundancyOptimizer._partition_parameters()`` in
|
||||
zero_redundancy_optimizer.py.
|
||||
"""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.01
|
||||
sizes = [9, 7, 5, 3]
|
||||
params = []
|
||||
@ -541,7 +502,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
``ZeroRedundancyOptimizer._partition_parameters()`` in
|
||||
zero_redundancy_optimizer.py.
|
||||
"""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 0.01
|
||||
|
||||
# Test with all parameters trainable to begin with
|
||||
@ -589,14 +550,14 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
all_trainable()
|
||||
some_trainable()
|
||||
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@skip_if_no_gpu
|
||||
def test_multiple_param_groups(self):
|
||||
"""
|
||||
Check parity between constructing ZeRO with multiple parameter groups
|
||||
upfront versus adding parameter groups to ZeRO after construction
|
||||
versus a non-sharded optimizer.
|
||||
"""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
BATCH_SIZE, NUM_ITERS = 8, 3
|
||||
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
|
||||
WD, LR = 0.01, 0.01
|
||||
@ -656,12 +617,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
torch.testing.assert_close(layer1.bias, layer2.bias)
|
||||
torch.testing.assert_close(layer1.bias, layer3.bias)
|
||||
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@common_distributed.skip_if_rocm_multiprocess
|
||||
@skip_if_no_gpu
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_collect_shards(self):
|
||||
"""Check the state consolidation mechanism and the state dict exposed
|
||||
by ZeroRedundancyOptimizer."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
LR = 1e-3
|
||||
MOMENTUM = 0.99
|
||||
BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
|
||||
@ -719,27 +680,25 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
# trivial
|
||||
MIN_WORLD_SIZE = 4
|
||||
if self.world_size < MIN_WORLD_SIZE:
|
||||
common_distributed.logger.info(
|
||||
logger.info(
|
||||
"Skipping `test_nondefault_process_group()` since world size "
|
||||
"of %s is less than %s",
|
||||
self.world_size,
|
||||
MIN_WORLD_SIZE,
|
||||
)
|
||||
return
|
||||
BACKEND = dist.Backend.GLOO
|
||||
self.dist_init(self.rank, self.world_size, BACKEND)
|
||||
# Use GPU if enough are available, or fall back to CPU otherwise, which
|
||||
# is fine since Gloo backend supports both
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
|
||||
device = torch.device(self.rank)
|
||||
else:
|
||||
# Use GPU if enough are available, or fall back to CPU otherwise
|
||||
if torch.get_device_module(self.device).device_count() < self.world_size:
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device(self.device)
|
||||
self.create_pg(device.type)
|
||||
# Create a new process group consisting of the even ranks to exercise
|
||||
# the case where the global and local ranks do not necessarily match
|
||||
subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0]
|
||||
process_group = dist.new_group(
|
||||
ranks=subgroup_ranks,
|
||||
backend=BACKEND,
|
||||
backend=self.backend(device.type),
|
||||
)
|
||||
# Ranks not participating in the new process group are no longer needed
|
||||
if self.rank not in subgroup_ranks:
|
||||
@ -811,7 +770,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
check(optimizer)
|
||||
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@skip_if_no_gpu
|
||||
@parametrize(
|
||||
"optimizer_class_str",
|
||||
["Adam", "AdamW", "SGD"],
|
||||
@ -828,7 +787,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
):
|
||||
"""When combined with DDP, check that a local optimizer gives the same
|
||||
results as wrapping that optimizer with ZeroRedundancyOptimizer."""
|
||||
self.dist_init(self.rank)
|
||||
self.create_pg(self.device)
|
||||
BATCHES = 20
|
||||
BATCH_SIZE = 64
|
||||
LR = 1e-3
|
||||
@ -867,7 +826,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
sharded_ddp_model = DDP(
|
||||
module=model,
|
||||
device_ids=[self.rank],
|
||||
device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
|
||||
broadcast_buffers=True,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
@ -879,7 +838,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
ddp_model = DDP(
|
||||
local_model,
|
||||
device_ids=[self.rank],
|
||||
device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
|
||||
broadcast_buffers=True,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
@ -892,7 +851,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
|
||||
def check_step():
|
||||
input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))
|
||||
input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)).to(self.device)
|
||||
|
||||
def closure_ddp(input_tensor=input_tensor):
|
||||
ddp_optimizer.zero_grad()
|
||||
@ -970,13 +929,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
NUM_EPOCHS = 2
|
||||
LR = 0.01
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
if "cpu" not in device:
|
||||
torch.get_device_module(device).manual_seed(0)
|
||||
|
||||
rank = self.rank
|
||||
world_size = self.world_size
|
||||
is_gpu = device.type == "cuda"
|
||||
backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
|
||||
self.dist_init(rank, world_size, backend)
|
||||
self.create_pg(device)
|
||||
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(2, 3),
|
||||
@ -988,7 +946,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
# DDP ensures correct gradients in data parallel training, so DDP with
|
||||
# local optimizers on uneven inputs should be equivalent to ZeRO on
|
||||
# uneven inputs with gradients being manually set
|
||||
ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
|
||||
ddp_model = (
|
||||
DDP(model, device_ids=[rank]) if requires_ddp_rank(device) else DDP(model)
|
||||
)
|
||||
local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
|
||||
zero_model = copy.deepcopy(model)
|
||||
zero_model.to(device)
|
||||
@ -1111,27 +1071,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
iter += 1
|
||||
|
||||
@common_distributed.requires_nccl()
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_no_gpu
|
||||
def test_zero_join_gpu(self):
|
||||
"""Check that the ZeRO join hook allows training with uneven inputs
|
||||
on GPU."""
|
||||
self._test_zero_join(self.device)
|
||||
|
||||
@common_distributed.requires_gloo()
|
||||
@requires_gloo()
|
||||
def test_zero_join_cpu(self):
|
||||
"""Check that the ZeRO join hook allows training with uneven inputs
|
||||
on CPU."""
|
||||
self._test_zero_join(torch.device("cpu"))
|
||||
self._test_zero_join("cpu")
|
||||
|
||||
def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
|
||||
def _test_zero_model_parallel(self, parameters_as_bucket_view: bool, device: str):
|
||||
# Use two processes each with two GPUs
|
||||
assert self.rank < 2
|
||||
NUM_EPOCHS = 2
|
||||
NUM_INPUTS = 4
|
||||
LR = 0.01
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
if "cpu" not in device:
|
||||
torch.get_device_module(device).manual_seed(0)
|
||||
|
||||
class ModelParallelModel(torch.nn.Module):
|
||||
def __init__(self, dev0, dev1):
|
||||
@ -1221,7 +1182,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
atol=1e-04,
|
||||
), "Models differ after a step"
|
||||
|
||||
@common_distributed.skip_if_lt_x_gpu(4)
|
||||
@skipIfHpu
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize(
|
||||
"parameters_as_bucket_view",
|
||||
[False, True],
|
||||
@ -1234,8 +1196,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
layers are assigned to different devices."""
|
||||
if self.rank >= 2:
|
||||
return
|
||||
self.dist_init(self.rank, world_size=2)
|
||||
self._test_zero_model_parallel(parameters_as_bucket_view)
|
||||
self.create_pg(self.device, world_size=2)
|
||||
self._test_zero_model_parallel(parameters_as_bucket_view, self.device)
|
||||
|
||||
def _test_ddp_zero_overlap(
|
||||
self,
|
||||
@ -1250,12 +1212,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
SGD_WEIGHT_DECAY = 0.001
|
||||
NUM_INPUTS = 5
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
if "cpu" not in device:
|
||||
torch.get_device_module(device).manual_seed(0)
|
||||
|
||||
rank = self.rank
|
||||
is_gpu = device.type == "cuda"
|
||||
if is_gpu:
|
||||
torch.cuda.set_device(device)
|
||||
models_to_test = [
|
||||
(
|
||||
torch.nn.Sequential(
|
||||
@ -1273,11 +1233,16 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
)
|
||||
for model, inputs in models_to_test:
|
||||
# Enable determinism in cudnn operators
|
||||
with torch.backends.cudnn.flags(
|
||||
enabled=True, deterministic=True, benchmark=False
|
||||
):
|
||||
device_ids = [rank] if is_gpu else None
|
||||
# Select deterministic context based on device
|
||||
det_ctx = (
|
||||
torch.backends.cudnn.flags(
|
||||
enabled=True, deterministic=True, benchmark=False
|
||||
)
|
||||
if "cuda" in device
|
||||
else torch.use_deterministic_algorithms(True)
|
||||
)
|
||||
with det_ctx:
|
||||
device_ids = [rank] if requires_ddp_rank(device) else None
|
||||
# Set up the DDP model overlapping with ZeRO
|
||||
ddp_model_overlap = DDP(
|
||||
copy.deepcopy(model).to(device),
|
||||
@ -1374,10 +1339,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
|
||||
# NOTE: The test is skipped if using Windows since functional optimizers
|
||||
# are not currently supported.
|
||||
@common_distributed.skip_if_win32()
|
||||
@common_distributed.requires_nccl()
|
||||
@common_distributed.skip_if_no_gpu
|
||||
@common_distributed.skip_if_rocm_multiprocess
|
||||
@skip_if_win32()
|
||||
@requires_accelerator_dist_backend()
|
||||
@skip_if_no_gpu
|
||||
@skip_if_rocm_multiprocess
|
||||
@parametrize(
|
||||
"use_gpu",
|
||||
[True],
|
||||
@ -1413,9 +1378,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO
|
||||
and DDP arguments achieves parity with DDP using a local optimizer.
|
||||
"""
|
||||
device = torch.device(self.rank) if use_gpu else torch.device("cpu")
|
||||
backend = _get_backend_for_tests()
|
||||
self.dist_init(self.rank, self.world_size, backend)
|
||||
self.create_pg(self.device)
|
||||
hook_constructor = (
|
||||
hook_with_zero_step
|
||||
if not use_interleaved_hook
|
||||
@ -1423,7 +1386,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
)
|
||||
|
||||
self._test_ddp_zero_overlap(
|
||||
device,
|
||||
self.device if use_gpu else "cpu",
|
||||
hook_constructor,
|
||||
gradient_as_bucket_view,
|
||||
static_graph,
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from functools import partial, wraps
|
||||
@ -16,10 +15,13 @@ if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
|
||||
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
|
||||
|
||||
device_type = str(get_devtype())
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
@ -27,8 +29,7 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
BACKEND = dist.Backend.NCCL
|
||||
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
|
||||
WORLD_SIZE = min(4, max(2, torch.get_device_module(device_type).device_count()))
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
@ -39,30 +40,16 @@ def with_comms(func=None):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
|
||||
if torch.get_device_module(device_type).device_count() < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
self.dist_init()
|
||||
self.create_pg(device_type)
|
||||
func(self)
|
||||
self.destroy_comms()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class C10dErrorLoggerTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["WORLD_SIZE"] = str(self.world_size)
|
||||
os.environ["BACKEND"] = BACKEND
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return (
|
||||
torch.device(self.rank)
|
||||
if BACKEND == dist.Backend.NCCL
|
||||
else torch.device("cpu")
|
||||
)
|
||||
|
||||
class C10dErrorLoggerTest(DistributedTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return WORLD_SIZE
|
||||
@ -76,18 +63,6 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def dist_init(self):
|
||||
dist.init_process_group(
|
||||
backend=BACKEND,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
init_method=f"file://{self.file_name}",
|
||||
)
|
||||
|
||||
# set device for nccl pg for collectives
|
||||
if BACKEND == "nccl":
|
||||
torch.cuda.set_device(self.rank)
|
||||
|
||||
def test_get_or_create_logger(self):
|
||||
self.assertIsNotNone(_c10d_logger)
|
||||
self.assertEqual(1, len(_c10d_logger.handlers))
|
||||
@ -117,7 +92,11 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
|
||||
re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
|
||||
)
|
||||
|
||||
self.assertEqual(len(error_msg_dict), 9)
|
||||
# NCCL adds additional nccl_version data to the error_msg_dict
|
||||
if self.backend(device_type) == dist.Backend.NCCL:
|
||||
self.assertEqual(len(error_msg_dict), 9)
|
||||
else:
|
||||
self.assertEqual(len(error_msg_dict), 8)
|
||||
|
||||
self.assertIn("pg_name", error_msg_dict.keys())
|
||||
self.assertEqual("None", error_msg_dict["pg_name"])
|
||||
@ -126,13 +105,14 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
|
||||
self.assertEqual("broadcast", error_msg_dict["func_name"])
|
||||
|
||||
self.assertIn("backend", error_msg_dict.keys())
|
||||
self.assertEqual("nccl", error_msg_dict["backend"])
|
||||
self.assertEqual(self.backend(device_type), error_msg_dict["backend"])
|
||||
|
||||
self.assertIn("nccl_version", error_msg_dict.keys())
|
||||
nccl_ver = torch.cuda.nccl.version()
|
||||
self.assertEqual(
|
||||
".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
|
||||
)
|
||||
if self.backend(device_type) == dist.Backend.NCCL:
|
||||
self.assertIn("nccl_version", error_msg_dict.keys())
|
||||
nccl_ver = torch.cuda.nccl.version()
|
||||
self.assertEqual(
|
||||
".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
|
||||
)
|
||||
|
||||
# In this test case, group_size = world_size, since we don't have multiple processes on one node.
|
||||
self.assertIn("group_size", error_msg_dict.keys())
|
||||
|
@ -26,12 +26,16 @@ from torch.testing._internal.common_distributed import (
|
||||
_dynamo_dist_per_rank_init,
|
||||
at_least_x_gpu,
|
||||
DynamoDistributedMultiProcTestCase,
|
||||
requires_nccl,
|
||||
requires_accelerator_dist_backend,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
||||
device_type = str(get_devtype())
|
||||
|
||||
|
||||
def get_snode_runtime_for_reorder_compute_test(snode):
|
||||
# NOTE: custom cost model to show that the compute reordering algorithm is working
|
||||
# Collective kernels
|
||||
@ -74,7 +78,7 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes):
|
||||
return new_snode_order
|
||||
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend()
|
||||
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
"""
|
||||
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
|
||||
@ -113,9 +117,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return torch.matmul(ar, b)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs)
|
||||
# Verify that the wait_tensor is sinked below the 1st matmul but
|
||||
@ -154,9 +161,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return torch.matmul(d, e)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs)
|
||||
# Verify that the all_reduce_ has been raised above the 2nd matmul
|
||||
@ -202,9 +212,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return torch.mm(e, g)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
# Things to verify:
|
||||
@ -255,9 +268,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return (e,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
# NOTE: after scheduling the first all_reduce:
|
||||
@ -312,9 +328,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return (e,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
# NOTE: after scheduling the first all_reduce:
|
||||
@ -362,9 +381,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return (mm,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
# Expectations:
|
||||
@ -387,9 +409,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
ranks = pg_info["ranks"]
|
||||
group_size = pg_info["group_size"]
|
||||
|
||||
g1 = torch.ones(10, 10, device="cuda")
|
||||
g2 = torch.ones(11, 11, device="cuda")
|
||||
g3 = torch.ones(12, 12, device="cuda")
|
||||
g1 = torch.ones(10, 10, device=device_type)
|
||||
g2 = torch.ones(11, 11, device=device_type)
|
||||
g3 = torch.ones(12, 12, device=device_type)
|
||||
|
||||
def assert_pass(graph):
|
||||
# all_reduces need to remain in order!
|
||||
@ -429,7 +451,9 @@ graph():
|
||||
grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
|
||||
return grad3, grad2, grad1
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, self.backend(device_type), fake_pg=True
|
||||
):
|
||||
fn(g1, g2, g3)
|
||||
|
||||
def test_nccl_heuristics(self):
|
||||
|
@ -17,7 +17,9 @@ from torch.distributed.tensor.parallel import (
|
||||
)
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.common_distributed import HAS_ACCELERATOR
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfHpu, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
@ -26,13 +28,16 @@ if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
device_type = get_devtype().type
|
||||
|
||||
|
||||
class TestFakePG(TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dist.destroy_process_group()
|
||||
try:
|
||||
dist.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
def test_all_reduce(self):
|
||||
store = FakeStore()
|
||||
@ -62,20 +67,21 @@ class TestFakePG(TestCase):
|
||||
dist.reduce_scatter(output_tensor, to_reduce_scatter)
|
||||
self.assertEqual(tuple(output_tensor.shape), (3, 3))
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "No CUDA")
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
def test_construct_fsdp(self):
|
||||
store = FakeStore()
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
|
||||
FSDP(nn.Linear(2, 3, device="cuda"))
|
||||
FSDP(nn.Linear(2, 3, device=device_type))
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "No CUDA")
|
||||
@skipIfHpu
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
def test_fsdp_fake_e2e(self):
|
||||
store = dist.HashStore()
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
|
||||
my_module = nn.Sequential(
|
||||
nn.Linear(2, 3, device="cuda"),
|
||||
nn.Linear(2, 3, device=device_type),
|
||||
nn.ReLU(),
|
||||
nn.Linear(3, 2, device="cuda"),
|
||||
nn.Linear(3, 2, device=device_type),
|
||||
)
|
||||
sharded_module = FSDP(my_module, use_orig_params=True)
|
||||
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
|
||||
@ -85,7 +91,8 @@ class TestFakePG(TestCase):
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "No CUDA")
|
||||
@skipIfHpu
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
def test_fake_pg_tracing(self):
|
||||
store = dist.HashStore()
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
|
||||
@ -95,7 +102,7 @@ class TestFakePG(TestCase):
|
||||
def allgather_fn(tensor):
|
||||
return funcol.all_gather_tensor(tensor, 0, default_pg)
|
||||
|
||||
gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
|
||||
gm = make_fx(allgather_fn)(torch.randn(2, 2, device=device_type))
|
||||
FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))
|
||||
|
||||
def test_broadcast(self):
|
||||
@ -165,7 +172,8 @@ class TestFakePG(TestCase):
|
||||
dist.recv(output, 1)
|
||||
self.assertEqual(tuple(output.shape), (3, 3))
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP")
|
||||
@skipIfHpu
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
def test_fsdp_tp_fake_e2e(self):
|
||||
world_size = 4
|
||||
tp_size = 2
|
||||
@ -175,9 +183,11 @@ class TestFakePG(TestCase):
|
||||
backend="fake", rank=0, world_size=world_size, store=store
|
||||
)
|
||||
|
||||
device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))
|
||||
device_mesh = DeviceMesh(
|
||||
device_type, torch.arange(0, world_size).view(-1, tp_size)
|
||||
)
|
||||
device_mesh = init_device_mesh(
|
||||
"cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
|
||||
device_type, (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
|
||||
)
|
||||
|
||||
sequence_parallelize_plan = {
|
||||
@ -190,7 +200,7 @@ class TestFakePG(TestCase):
|
||||
}
|
||||
for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
|
||||
my_module = parallelize_module(
|
||||
MLPModule(device="cuda"),
|
||||
MLPModule(device=device_type),
|
||||
device_mesh["tp"],
|
||||
parallel_plan,
|
||||
)
|
||||
@ -203,7 +213,7 @@ class TestFakePG(TestCase):
|
||||
for i in range(10):
|
||||
dp_rank = dist.get_rank()
|
||||
torch.manual_seed(i + dp_rank)
|
||||
input = torch.randn(20, 10).cuda(dist.get_rank())
|
||||
input = torch.randn(20, 10, device=f"{device_type}:{dp_rank}")
|
||||
x = sharded_module(input)
|
||||
loss = x.sum()
|
||||
loss.backward()
|
||||
|
@ -39,6 +39,7 @@ from torch.testing._internal.common_utils import (
|
||||
retry_on_connect_failures,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_CUDA,
|
||||
TEST_HPU,
|
||||
TEST_WITH_ROCM,
|
||||
TEST_WITH_TSAN,
|
||||
@@ -55,6 +56,10 @@ from torch.testing._internal.distributed.multi_threaded_pg import (
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
DDP_RANK_DEVICES = ["cuda", "xpu"]
HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU


class TestSkip(NamedTuple):
    exit_code: int
@@ -109,21 +114,25 @@ class DistTestCases:
    backend_feature["xpu"] = {"xccl"}


def requires_ddp_rank(device):
    return device in DDP_RANK_DEVICES


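For illustration (not part of this diff), the helper above is what lets the rewritten tests decide whether `DistributedDataParallel` should receive explicit `device_ids`. The wrapper function below is a hypothetical sketch and assumes a process group has already been created, e.g. via `DistributedTestBase.create_pg`.

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import requires_ddp_rank


def wrap_model(model: torch.nn.Module, device: str, rank: int) -> DDP:
    # "cuda" and "xpu" need the per-rank device id; other devices pass None,
    # mirroring the `device_ids=[self.rank] if requires_ddp_rank(...) else None`
    # call sites in the test diffs above.
    device_ids = [rank] if requires_ddp_rank(device) else None
    return DDP(model.to(device), device_ids=device_ids)
```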
def skip_if_no_gpu(func):
|
||||
"""Skips if the world size exceeds the number of GPUs, ensuring that if the
|
||||
test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
if not (TEST_CUDA or TEST_HPU or TEST_XPU):
|
||||
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
if torch.cuda.device_count() < world_size:
|
||||
if TEST_CUDA and torch.cuda.device_count() < world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
|
||||
if TEST_HPU and torch.hpu.device_count < world_size:
|
||||
if TEST_HPU and torch.hpu.device_count() < world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
|
||||
if TEST_XPU and torch.xpu.device_count() < world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
|
||||
if TEST_XPU and torch.xpu.device_count < world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-xpu-{world_size}"].exit_code)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@ -189,7 +198,13 @@ def import_transformers_or_skip():
|
||||
|
||||
|
||||
def at_least_x_gpu(x):
|
||||
return torch.cuda.is_available() and torch.cuda.device_count() >= x
|
||||
if TEST_CUDA and torch.cuda.device_count() >= x:
|
||||
return True
|
||||
if TEST_HPU and torch.hpu.device_count() >= x:
|
||||
return True
|
||||
if TEST_XPU and torch.xpu.device_count() >= x:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def skip_if_lt_x_gpu(x):
|
||||
@@ -355,6 +370,35 @@ def requires_mpi():
    )


def requires_accelerator_dist_backend(backends=None):
    """
    Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.

    Args:
        backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
            If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).

    Returns:
        callable: A decorator that skips the test if no specified accelerator backend is available.
    """
    if backends is None:
        backends = ACCELERATOR_DIST_BACKENDS

    backend_available = any(
        {
            "nccl": c10d.is_nccl_available,
            "xccl": c10d.is_xccl_available,
            "hccl": lambda: TEST_HPU,
        }.get(backend, lambda: False)()
        for backend in backends
    )

    return skip_but_pass_in_sandcastle_if(
        not backend_available,
        f"No accelerator communication backend available among {backends}",
    )


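A brief usage sketch (illustrative; the test class and method names are hypothetical): the decorator replaces the CUDA-only `requires_nccl()` used previously and can optionally be narrowed to specific backends.

```python
from torch.testing._internal.common_distributed import (
    DistributedTestBase,
    requires_accelerator_dist_backend,
)


class ExampleBackendGatedTest(DistributedTestBase):
    @requires_accelerator_dist_backend()  # skip unless nccl, xccl, or hccl exists
    def test_runs_on_any_accelerator(self):
        ...

    @requires_accelerator_dist_backend(["nccl"])  # NCCL-specific behavior only
    def test_runs_on_nccl_only(self):
        ...
```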
def requires_multicast_support():
|
||||
has_multicast_support = (
|
||||
torch.cuda.is_available()
|
||||
@ -968,9 +1012,14 @@ class MultiProcessTestCase(TestCase):
|
||||
class DistributedTestBase(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["WORLD_SIZE"] = str(self.world_size)
|
||||
self._spawn_processes()
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
torch.distributed.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
@@ -986,12 +1035,14 @@ class DistributedTestBase(MultiProcessTestCase):
        else:
            return "gloo"

    def create_pg(self, device):
    def create_pg(self, device, world_size=None):
        if world_size is None:
            world_size = self.world_size
        num_visible_devices = torch.get_device_module(device).device_count()
        store = torch.distributed.FileStore(self.file_name, num_visible_devices)
        torch.distributed.init_process_group(
            backend=self.backend(device),
            world_size=self.world_size,
            world_size=world_size,
            rank=self.rank,
            store=store,
        )
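To show why `create_pg` grew an optional `world_size` (a sketch under the assumption that the harness spawns `self.world_size` processes; the class and test names are invented): tests like `test_zero_model_parallel` above spawn four processes but only ranks 0 and 1 join the group.

```python
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype

device_type = str(get_devtype())


class TwoRankSubsetTest(DistributedTestBase):
    @property
    def world_size(self):
        return 4  # processes spawned by the test harness

    def test_only_first_two_ranks(self):
        if self.rank >= 2:
            return  # surplus ranks exit early, as in test_zero_model_parallel
        # The process group is sized independently of the number of spawned
        # processes via the new world_size argument.
        self.create_pg(device_type, world_size=2)
```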
@@ -1404,7 +1455,9 @@ class SaveForwardInputsModel(nn.Module):


@contextmanager
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
def _dynamo_dist_per_rank_init(
    rank, world_size, backend="nccl", init_pg=True, fake_pg=False
):
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # Just manually implement the most important part of the dynamo behavior to reset/clear.
    if not fake_pg:
@@ -1421,7 +1474,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
            store=store,
        )
    else:
        c10d.init_process_group("nccl", rank=rank, world_size=world_size)
        c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    try:
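The call-site pattern this enables, mirrored from the compute/comm reordering tests above (treat the test class and method as an illustrative sketch, not part of the commit):

```python
import torch
import torch.distributed as dist

from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    at_least_x_gpu,
    DynamoDistributedMultiProcTestCase,
)
from torch.testing._internal.common_fsdp import get_devtype

device_type = str(get_devtype())


class ExamplePerRankInitTest(DynamoDistributedMultiProcTestCase):
    def test_per_rank_init_with_explicit_backend(self):
        # The backend is now chosen per device instead of being hard-wired to
        # NCCL; with fewer than two accelerators a fake process group is used.
        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            t = torch.ones(4, 4, device=device_type) + self.rank
            dist.all_reduce(t)
```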
@@ -1465,7 +1518,7 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
        super().tearDownClass()


class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
class DynamoDistributedMultiProcTestCase(DistributedTestBase):
    """
    Use this for tests that actually run on multiple GPUs.

@@ -1476,20 +1529,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
    sparingly for integration tests.
    """

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()
        return torch.accelerator.device_count()

    @classmethod
    def _run(