add device generalisation support for distributed tests (#152471)

### MOTIVATION
Generalize distributed test cases so that they can run on non-CUDA accelerator devices (e.g., HPU and XPU) as well as CUDA.

### CHANGES

- test/distributed/optim/test_zero_redundancy_optimizer.py
- test/distributed/test_c10d_logger.py
- test/distributed/test_compute_comm_reordering.py

Replaced hard-coded device names with get_devtype from torch.testing._internal.common_fsdp, and used DistributedTestBase in place of MultiProcessTestCase so these tests can reuse its helper functions (backend selection via backend() and process-group creation via create_pg); a minimal sketch of the resulting pattern is shown below.
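
For illustration, here is a minimal sketch of that device-generic pattern. get_devtype, DistributedTestBase, and create_pg come from the diff below; the test class and body are hypothetical and not part of this PR.

```python
# Hypothetical example of the device-generic test pattern used in this PR.
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = str(get_devtype())  # e.g. "cuda", "hpu", "xpu" instead of a hard-coded "cuda"


class MyDeviceGenericTest(DistributedTestBase):
    @property
    def device(self):
        return device_type

    def test_all_reduce(self):
        # create_pg picks a matching backend (nccl/xccl/hccl, or gloo as fallback)
        # and initializes the process group, replacing the per-file dist_init helpers.
        self.create_pg(self.device)
        t = torch.ones(2, 2, device=self.device) * (self.rank + 1)
        dist.all_reduce(t)


if __name__ == "__main__":
    run_tests()
```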

- torch/testing/_internal/common_distributed.py

Extended the common utility functions to be device-agnostic: added HAS_ACCELERATOR, requires_ddp_rank, and requires_accelerator_dist_backend; made skip_if_no_gpu and at_least_x_gpu aware of HPU and XPU; allowed create_pg to take an explicit world_size; added a backend argument to _dynamo_dist_per_rank_init; and made DynamoDistributedMultiProcTestCase derive from DistributedTestBase. A usage sketch follows.
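
As a hedged sketch of how the extended helpers are intended to be used (all helper names appear in the diff below; the example test class and body are hypothetical):

```python
# Hypothetical test showing the extended, device-agnostic helpers.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
    DistributedTestBase,
    requires_accelerator_dist_backend,  # skip unless an nccl/xccl/hccl backend is available
    requires_ddp_rank,                  # True for device types that pass device_ids to DDP
    skip_if_no_gpu,                     # now also recognizes HPU and XPU devices
)
from torch.testing._internal.common_fsdp import get_devtype

device_type = str(get_devtype())


class ExampleHelperUsage(DistributedTestBase):
    @requires_accelerator_dist_backend()
    @skip_if_no_gpu
    def test_ddp_step(self):
        # create_pg(device, world_size=None) initializes the process group;
        # the world size can now be overridden explicitly when needed.
        self.create_pg(device_type)
        model = torch.nn.Linear(4, 4).to(device_type)
        device_ids = [self.rank] if requires_ddp_rank(device_type) else None
        ddp = DDP(model, device_ids=device_ids)
        ddp(torch.ones(2, 4, device=device_type)).sum().backward()
```

This mirrors the `device_ids=[self.rank] if requires_ddp_rank(device) else None` pattern used throughout the updated ZeroRedundancyOptimizer tests.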

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152471
Approved by: https://github.com/d4l3k
Hari Krishna Sai Kodali, 2025-06-20 07:35:42 +00:00, committed by PyTorch MergeBot
parent 0aed855b2b
commit e1f28fe17b
5 changed files with 243 additions and 224 deletions

test/distributed/optim/test_zero_redundancy_optimizer.py

@ -6,7 +6,6 @@
# LICENSE file in the root directory of this source tree.
import copy
import os
import sys
from contextlib import nullcontext
from typing import Any, cast
@ -30,12 +29,23 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD
from torch.testing._internal import common_distributed
from torch.testing._internal.common_distributed import (
DistributedTestBase,
logger,
requires_accelerator_dist_backend,
requires_ddp_rank,
requires_gloo,
skip_if_lt_x_gpu,
skip_if_no_gpu,
skip_if_rocm_multiprocess,
skip_if_win32,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize,
run_tests,
skipIfHpu,
)
@ -47,63 +57,24 @@ except ImportError:
HAS_TORCHVISION = False
# Use GLOO on GPU when running CUDA + Windows
def _get_backend_for_tests():
return (
dist.Backend.NCCL
if not IS_WINDOWS and torch.cuda.is_available()
# Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
# no GPUs are available.
else dist.Backend.GLOO
)
device_type = str(get_devtype())
BACKEND = _get_backend_for_tests()
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
self._spawn_processes()
class TestZeroRedundancyOptimizer(DistributedTestBase):
@property
def device(self):
return (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
return device_type
@property
def world_size(self):
return 1
def tearDown(self):
try:
torch.distributed.destroy_process_group()
except AssertionError:
pass
try:
os.remove(self.file_name)
except OSError:
pass
def dist_init(self, rank, world_size=-1, backend=BACKEND):
if world_size < 1:
world_size = self.world_size
store = dist.FileStore(self.file_name, world_size)
return dist.init_process_group(
backend=backend,
store=store,
rank=rank,
world_size=world_size,
)
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_state_dict(self):
"""Check that ZeroRedundancyOptimizer exposes the expected state dict
interface, irrespective of the sharding."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR1 = 0.1
LR2 = 0.01
MOMENTUM = 0.9
@ -171,7 +142,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_lr_scheduler(self):
"""Check that a normal PyTorch ``lr_scheduler`` is usable with
ZeroRedundancyOptimizer."""
self.dist_init(self.rank)
self.create_pg(self.device)
NUM_ITERS = 5
LR = 0.01
x = torch.tensor([1.0], device=self.device, requires_grad=True)
@ -193,7 +164,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_step_with_kwargs(self):
"""Check that the ``step(**kwargs)`` interface is properly exposed."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.1
class SGDWithStepKWArg(torch.optim.SGD):
@ -217,7 +188,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
"""Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
extra keys to ``param_groups`` exposes those keys through ZeRO's own
``param_groups``."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.1
class SGDWithNewKey(torch.optim.SGD):
@ -236,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_step_without_closure(self):
"""Check that the ``step()`` method (without closure) is handled as
expected."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.1
class SGDWithoutClosure(torch.optim.SGD):
@ -255,7 +226,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_zero_grad(self):
"""Check that the ``zero_grad`` method is properly handled."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.01
x = torch.rand(1)
m = torch.nn.Linear(1, 1)
@ -271,7 +242,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_constructor(self):
"""Check the robustness of the ZeroRedundancyOptimizer constructor by
passing different values for the ``params`` argument."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.01
m = torch.nn.Sequential(
torch.nn.Linear(5, 10),
@ -336,7 +307,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
NOTE: This test should be removed once support for sparse parameters
and varying parameter types is added.
"""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.01
inputs = [
[torch.sparse_coo_tensor(size=(2, 3))],
@ -353,25 +324,16 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
@property
def device(self):
return (
torch.device(self.rank)
if torch.cuda.is_available()
else torch.device("cpu")
)
@property
def world_size(self):
return min(4, max(2, torch.cuda.device_count()))
return min(4, max(2, torch.get_device_module(self.device).device_count()))
@property
def context(self):
return (
nullcontext()
if not torch.cuda.is_available()
else torch.cuda.device(self.rank)
)
if requires_ddp_rank(self.device):
return torch.get_device_module(self.device).device(self.rank)
else:
return nullcontext()
def _check_same_model_params(
self,
@ -396,12 +358,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message,
)
@common_distributed.skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess
@skip_if_no_gpu
@skip_if_rocm_multiprocess
def test_step(self):
"""Check that ZeroRedundancyOptimizer properly exposes the ``step()``
interface."""
self.dist_init(self.rank, world_size=self.world_size)
self.create_pg(self.device)
LR = 0.01
with self.context:
@ -436,13 +398,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
self.assertEqual(m.weight, m_zero.weight)
self.assertEqual(m.bias, m_zero.bias)
@common_distributed.skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess
@skip_if_no_gpu
@skip_if_rocm_multiprocess
def test_step_with_closure(self):
"""Check that ZeroRedundancyOptimizer properly exposes the
``step(closure)`` interface."""
self.dist_init(self.rank, world_size=self.world_size)
self.create_pg(self.device)
with self.context:
for bucket_view in [False, True]:
x_val = self.rank + 1
@ -487,11 +448,11 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
self.assertEqual(m.weight, torch.tensor([[1.1]]))
self.assertEqual(m.bias, torch.tensor([2.1]))
@common_distributed.skip_if_no_gpu
@skip_if_no_gpu
def test_lr_scheduler(self):
"""Check that a normal PyTorch ``lr_scheduler`` is usable with
ZeroRedundancyOptimizer."""
self.dist_init(self.rank)
self.create_pg(self.device)
x = torch.tensor([1.0], device=self.device, requires_grad=True)
x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
@ -519,7 +480,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
``ZeroRedundancyOptimizer._partition_parameters()`` in
zero_redundancy_optimizer.py.
"""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.01
sizes = [9, 7, 5, 3]
params = []
@ -541,7 +502,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
``ZeroRedundancyOptimizer._partition_parameters()`` in
zero_redundancy_optimizer.py.
"""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 0.01
# Test with all parameters trainable to begin with
@ -589,14 +550,14 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
all_trainable()
some_trainable()
@common_distributed.skip_if_no_gpu
@skip_if_no_gpu
def test_multiple_param_groups(self):
"""
Check parity between constructing ZeRO with multiple parameter groups
upfront versus adding parameter groups to ZeRO after construction
versus a non-sharded optimizer.
"""
self.dist_init(self.rank)
self.create_pg(self.device)
BATCH_SIZE, NUM_ITERS = 8, 3
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
WD, LR = 0.01, 0.01
@ -656,12 +617,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
torch.testing.assert_close(layer1.bias, layer2.bias)
torch.testing.assert_close(layer1.bias, layer3.bias)
@common_distributed.skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess
@skip_if_no_gpu
@skip_if_rocm_multiprocess
def test_collect_shards(self):
"""Check the state consolidation mechanism and the state dict exposed
by ZeroRedundancyOptimizer."""
self.dist_init(self.rank)
self.create_pg(self.device)
LR = 1e-3
MOMENTUM = 0.99
BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
@ -719,27 +680,25 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# trivial
MIN_WORLD_SIZE = 4
if self.world_size < MIN_WORLD_SIZE:
common_distributed.logger.info(
logger.info(
"Skipping `test_nondefault_process_group()` since world size "
"of %s is less than %s",
self.world_size,
MIN_WORLD_SIZE,
)
return
BACKEND = dist.Backend.GLOO
self.dist_init(self.rank, self.world_size, BACKEND)
# Use GPU if enough are available, or fall back to CPU otherwise, which
# is fine since Gloo backend supports both
if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
device = torch.device(self.rank)
else:
# Use GPU if enough are available, or fall back to CPU otherwise
if torch.get_device_module(self.device).device_count() < self.world_size:
device = torch.device("cpu")
else:
device = torch.device(self.device)
self.create_pg(device.type)
# Create a new process group consisting of the even ranks to exercise
# the case where the global and local ranks do not necessarily match
subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0]
process_group = dist.new_group(
ranks=subgroup_ranks,
backend=BACKEND,
backend=self.backend(device.type),
)
# Ranks not participating in the new process group are no longer needed
if self.rank not in subgroup_ranks:
@ -811,7 +770,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
check(optimizer)
@common_distributed.skip_if_no_gpu
@skip_if_no_gpu
@parametrize(
"optimizer_class_str",
["Adam", "AdamW", "SGD"],
@ -828,7 +787,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
):
"""When combined with DDP, check that a local optimizer gives the same
results as wrapping that optimizer with ZeroRedundancyOptimizer."""
self.dist_init(self.rank)
self.create_pg(self.device)
BATCHES = 20
BATCH_SIZE = 64
LR = 1e-3
@ -867,7 +826,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
sharded_ddp_model = DDP(
module=model,
device_ids=[self.rank],
device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
broadcast_buffers=True,
find_unused_parameters=True,
)
@ -879,7 +838,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
ddp_model = DDP(
local_model,
device_ids=[self.rank],
device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
broadcast_buffers=True,
find_unused_parameters=True,
)
@ -892,7 +851,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
def check_step():
input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))
input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)).to(self.device)
def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad()
@ -970,13 +929,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
NUM_EPOCHS = 2
LR = 0.01
torch.manual_seed(0)
torch.cuda.manual_seed(0)
if "cpu" not in device:
torch.get_device_module(device).manual_seed(0)
rank = self.rank
world_size = self.world_size
is_gpu = device.type == "cuda"
backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
self.dist_init(rank, world_size, backend)
self.create_pg(device)
model = torch.nn.Sequential(
torch.nn.Linear(2, 3),
@ -988,7 +946,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# DDP ensures correct gradients in data parallel training, so DDP with
# local optimizers on uneven inputs should be equivalent to ZeRO on
# uneven inputs with gradients being manually set
ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
ddp_model = (
DDP(model, device_ids=[rank]) if requires_ddp_rank(device) else DDP(model)
)
local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
zero_model = copy.deepcopy(model)
zero_model.to(device)
@ -1111,27 +1071,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
iter += 1
@common_distributed.requires_nccl()
@common_distributed.skip_if_no_gpu
@requires_accelerator_dist_backend()
@skip_if_no_gpu
def test_zero_join_gpu(self):
"""Check that the ZeRO join hook allows training with uneven inputs
on GPU."""
self._test_zero_join(self.device)
@common_distributed.requires_gloo()
@requires_gloo()
def test_zero_join_cpu(self):
"""Check that the ZeRO join hook allows training with uneven inputs
on CPU."""
self._test_zero_join(torch.device("cpu"))
self._test_zero_join("cpu")
def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
def _test_zero_model_parallel(self, parameters_as_bucket_view: bool, device: str):
# Use two processes each with two GPUs
assert self.rank < 2
NUM_EPOCHS = 2
NUM_INPUTS = 4
LR = 0.01
torch.manual_seed(0)
torch.cuda.manual_seed(0)
if "cpu" not in device:
torch.get_device_module(device).manual_seed(0)
class ModelParallelModel(torch.nn.Module):
def __init__(self, dev0, dev1):
@ -1221,7 +1182,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
atol=1e-04,
), "Models differ after a step"
@common_distributed.skip_if_lt_x_gpu(4)
@skipIfHpu
@skip_if_lt_x_gpu(4)
@parametrize(
"parameters_as_bucket_view",
[False, True],
@ -1234,8 +1196,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
layers are assigned to different devices."""
if self.rank >= 2:
return
self.dist_init(self.rank, world_size=2)
self._test_zero_model_parallel(parameters_as_bucket_view)
self.create_pg(self.device, world_size=2)
self._test_zero_model_parallel(parameters_as_bucket_view, self.device)
def _test_ddp_zero_overlap(
self,
@ -1250,12 +1212,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
SGD_WEIGHT_DECAY = 0.001
NUM_INPUTS = 5
torch.manual_seed(0)
torch.cuda.manual_seed(0)
if "cpu" not in device:
torch.get_device_module(device).manual_seed(0)
rank = self.rank
is_gpu = device.type == "cuda"
if is_gpu:
torch.cuda.set_device(device)
models_to_test = [
(
torch.nn.Sequential(
@ -1273,11 +1233,16 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
)
for model, inputs in models_to_test:
# Enable determinism in cudnn operators
with torch.backends.cudnn.flags(
# Select deterministic context based on device
det_ctx = (
torch.backends.cudnn.flags(
enabled=True, deterministic=True, benchmark=False
):
device_ids = [rank] if is_gpu else None
)
if "cuda" in device
else torch.use_deterministic_algorithms(True)
)
with det_ctx:
device_ids = [rank] if requires_ddp_rank(device) else None
# Set up the DDP model overlapping with ZeRO
ddp_model_overlap = DDP(
copy.deepcopy(model).to(device),
@ -1374,10 +1339,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# NOTE: The test is skipped if using Windows since functional optimizers
# are not currently supported.
@common_distributed.skip_if_win32()
@common_distributed.requires_nccl()
@common_distributed.skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess
@skip_if_win32()
@requires_accelerator_dist_backend()
@skip_if_no_gpu
@skip_if_rocm_multiprocess
@parametrize(
"use_gpu",
[True],
@ -1413,9 +1378,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO
and DDP arguments achieves parity with DDP using a local optimizer.
"""
device = torch.device(self.rank) if use_gpu else torch.device("cpu")
backend = _get_backend_for_tests()
self.dist_init(self.rank, self.world_size, backend)
self.create_pg(self.device)
hook_constructor = (
hook_with_zero_step
if not use_interleaved_hook
@ -1423,7 +1386,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
)
self._test_ddp_zero_overlap(
device,
self.device if use_gpu else "cpu",
hook_constructor,
gradient_as_bucket_view,
static_graph,

test/distributed/test_c10d_logger.py

@ -2,7 +2,6 @@
import json
import logging
import os
import re
import sys
from functools import partial, wraps
@ -16,10 +15,13 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = str(get_devtype())
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -27,8 +29,7 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
BACKEND = dist.Backend.NCCL
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
WORLD_SIZE = min(4, max(2, torch.get_device_module(device_type).device_count()))
def with_comms(func=None):
@ -39,30 +40,16 @@ def with_comms(func=None):
@wraps(func)
def wrapper(self, *args, **kwargs):
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
if torch.get_device_module(device_type).device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.dist_init()
self.create_pg(device_type)
func(self)
self.destroy_comms()
return wrapper
class C10dErrorLoggerTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["BACKEND"] = BACKEND
self._spawn_processes()
@property
def device(self):
return (
torch.device(self.rank)
if BACKEND == dist.Backend.NCCL
else torch.device("cpu")
)
class C10dErrorLoggerTest(DistributedTestBase):
@property
def world_size(self):
return WORLD_SIZE
@ -76,18 +63,6 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
dist.barrier()
dist.destroy_process_group()
def dist_init(self):
dist.init_process_group(
backend=BACKEND,
world_size=self.world_size,
rank=self.rank,
init_method=f"file://{self.file_name}",
)
# set device for nccl pg for collectives
if BACKEND == "nccl":
torch.cuda.set_device(self.rank)
def test_get_or_create_logger(self):
self.assertIsNotNone(_c10d_logger)
self.assertEqual(1, len(_c10d_logger.handlers))
@ -117,7 +92,11 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
)
# NCCL adds additional nccl_version data to the error_msg_dict
if self.backend(device_type) == dist.Backend.NCCL:
self.assertEqual(len(error_msg_dict), 9)
else:
self.assertEqual(len(error_msg_dict), 8)
self.assertIn("pg_name", error_msg_dict.keys())
self.assertEqual("None", error_msg_dict["pg_name"])
@ -126,8 +105,9 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
self.assertEqual("broadcast", error_msg_dict["func_name"])
self.assertIn("backend", error_msg_dict.keys())
self.assertEqual("nccl", error_msg_dict["backend"])
self.assertEqual(self.backend(device_type), error_msg_dict["backend"])
if self.backend(device_type) == dist.Backend.NCCL:
self.assertIn("nccl_version", error_msg_dict.keys())
nccl_ver = torch.cuda.nccl.version()
self.assertEqual(

test/distributed/test_compute_comm_reordering.py

@ -26,12 +26,16 @@ from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
at_least_x_gpu,
DynamoDistributedMultiProcTestCase,
requires_nccl,
requires_accelerator_dist_backend,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
device_type = str(get_devtype())
def get_snode_runtime_for_reorder_compute_test(snode):
# NOTE: custom cost model to show that the compute reordering algorithm is working
# Collective kernels
@ -74,7 +78,7 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes):
return new_snode_order
@requires_nccl()
@requires_accelerator_dist_backend()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
@ -113,9 +117,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return torch.matmul(ar, b)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs)
# Verify that the wait_tensor is sinked below the 1st matmul but
@ -154,9 +161,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return torch.matmul(d, e)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs)
# Verify that the all_reduce_ has been raised above the 2nd matmul
@ -202,9 +212,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return torch.mm(e, g)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Things to verify:
@ -255,9 +268,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return (e,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# NOTE: after scheduling the first all_reduce:
@ -312,9 +328,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return (e,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# NOTE: after scheduling the first all_reduce:
@ -362,9 +381,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return (mm,)
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Expectations:
@ -387,9 +409,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
ranks = pg_info["ranks"]
group_size = pg_info["group_size"]
g1 = torch.ones(10, 10, device="cuda")
g2 = torch.ones(11, 11, device="cuda")
g3 = torch.ones(12, 12, device="cuda")
g1 = torch.ones(10, 10, device=device_type)
g2 = torch.ones(11, 11, device=device_type)
g3 = torch.ones(12, 12, device=device_type)
def assert_pass(graph):
# all_reduces need to remain in order!
@ -429,7 +451,9 @@ graph():
grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
return grad3, grad2, grad1
with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True):
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, self.backend(device_type), fake_pg=True
):
fn(g1, g2, g3)
def test_nccl_heuristics(self):

test/distributed/test_fake_pg.py

@ -17,7 +17,9 @@ from torch.distributed.tensor.parallel import (
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_distributed import HAS_ACCELERATOR
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, skipIfHpu, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
@ -26,13 +28,16 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
HAS_CUDA = torch.cuda.is_available()
device_type = get_devtype().type
class TestFakePG(TestCase):
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
def test_all_reduce(self):
store = FakeStore()
@ -62,20 +67,21 @@ class TestFakePG(TestCase):
dist.reduce_scatter(output_tensor, to_reduce_scatter)
self.assertEqual(tuple(output_tensor.shape), (3, 3))
@unittest.skipIf(not HAS_CUDA, "No CUDA")
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_construct_fsdp(self):
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
FSDP(nn.Linear(2, 3, device="cuda"))
FSDP(nn.Linear(2, 3, device=device_type))
@unittest.skipIf(not HAS_CUDA, "No CUDA")
@skipIfHpu
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_fsdp_fake_e2e(self):
store = dist.HashStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
my_module = nn.Sequential(
nn.Linear(2, 3, device="cuda"),
nn.Linear(2, 3, device=device_type),
nn.ReLU(),
nn.Linear(3, 2, device="cuda"),
nn.Linear(3, 2, device=device_type),
)
sharded_module = FSDP(my_module, use_orig_params=True)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
@ -85,7 +91,8 @@ class TestFakePG(TestCase):
loss.backward()
optim.step()
@unittest.skipIf(not HAS_CUDA, "No CUDA")
@skipIfHpu
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_fake_pg_tracing(self):
store = dist.HashStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
@ -95,7 +102,7 @@ class TestFakePG(TestCase):
def allgather_fn(tensor):
return funcol.all_gather_tensor(tensor, 0, default_pg)
gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
gm = make_fx(allgather_fn)(torch.randn(2, 2, device=device_type))
FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))
def test_broadcast(self):
@ -165,7 +172,8 @@ class TestFakePG(TestCase):
dist.recv(output, 1)
self.assertEqual(tuple(output.shape), (3, 3))
@unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP")
@skipIfHpu
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_fsdp_tp_fake_e2e(self):
world_size = 4
tp_size = 2
@ -175,9 +183,11 @@ class TestFakePG(TestCase):
backend="fake", rank=0, world_size=world_size, store=store
)
device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))
device_mesh = DeviceMesh(
device_type, torch.arange(0, world_size).view(-1, tp_size)
)
device_mesh = init_device_mesh(
"cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
device_type, (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
)
sequence_parallelize_plan = {
@ -190,7 +200,7 @@ class TestFakePG(TestCase):
}
for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
my_module = parallelize_module(
MLPModule(device="cuda"),
MLPModule(device=device_type),
device_mesh["tp"],
parallel_plan,
)
@ -203,7 +213,7 @@ class TestFakePG(TestCase):
for i in range(10):
dp_rank = dist.get_rank()
torch.manual_seed(i + dp_rank)
input = torch.randn(20, 10).cuda(dist.get_rank())
input = torch.randn(20, 10, device=f"{device_type}:{dp_rank}")
x = sharded_module(input)
loss = x.sum()
loss.backward()

torch/testing/_internal/common_distributed.py

@ -39,6 +39,7 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_CUDA,
TEST_HPU,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
@ -55,6 +56,10 @@ from torch.testing._internal.distributed.multi_threaded_pg import (
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
DDP_RANK_DEVICES = ["cuda", "xpu"]
HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
class TestSkip(NamedTuple):
exit_code: int
@ -109,21 +114,25 @@ class DistTestCases:
backend_feature["xpu"] = {"xccl"}
def requires_ddp_rank(device):
return device in DDP_RANK_DEVICES
def skip_if_no_gpu(func):
"""Skips if the world size exceeds the number of GPUs, ensuring that if the
test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
@wraps(func)
def wrapper(*args, **kwargs):
if not torch.cuda.is_available():
if not (TEST_CUDA or TEST_HPU or TEST_XPU):
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.device_count() < world_size:
if TEST_CUDA and torch.cuda.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_HPU and torch.hpu.device_count < world_size:
if TEST_HPU and torch.hpu.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_XPU and torch.xpu.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_XPU and torch.xpu.device_count < world_size:
sys.exit(TEST_SKIPS[f"multi-xpu-{world_size}"].exit_code)
return func(*args, **kwargs)
@ -189,7 +198,13 @@ def import_transformers_or_skip():
def at_least_x_gpu(x):
return torch.cuda.is_available() and torch.cuda.device_count() >= x
if TEST_CUDA and torch.cuda.device_count() >= x:
return True
if TEST_HPU and torch.hpu.device_count() >= x:
return True
if TEST_XPU and torch.xpu.device_count() >= x:
return True
return False
def skip_if_lt_x_gpu(x):
@ -355,6 +370,35 @@ def requires_mpi():
)
def requires_accelerator_dist_backend(backends=None):
"""
Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
Args:
backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
Returns:
callable: A decorator that skips the test if no specified accelerator backend is available.
"""
if backends is None:
backends = ACCELERATOR_DIST_BACKENDS
backend_available = any(
{
"nccl": c10d.is_nccl_available,
"xccl": c10d.is_xccl_available,
"hccl": lambda: TEST_HPU,
}.get(backend, lambda: False)()
for backend in backends
)
return skip_but_pass_in_sandcastle_if(
not backend_available,
f"No accelerator communication backend available among {backends}",
)
def requires_multicast_support():
has_multicast_support = (
torch.cuda.is_available()
@ -968,9 +1012,14 @@ class MultiProcessTestCase(TestCase):
class DistributedTestBase(MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
self._spawn_processes()
def tearDown(self):
try:
torch.distributed.destroy_process_group()
except AssertionError:
pass
try:
os.remove(self.file_name)
except OSError:
@ -986,12 +1035,14 @@ class DistributedTestBase(MultiProcessTestCase):
else:
return "gloo"
def create_pg(self, device):
def create_pg(self, device, world_size=None):
if world_size is None:
world_size = self.world_size
num_visible_devices = torch.get_device_module(device).device_count()
store = torch.distributed.FileStore(self.file_name, num_visible_devices)
torch.distributed.init_process_group(
backend=self.backend(device),
world_size=self.world_size,
world_size=world_size,
rank=self.rank,
store=store,
)
@ -1404,7 +1455,9 @@ class SaveForwardInputsModel(nn.Module):
@contextmanager
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
def _dynamo_dist_per_rank_init(
rank, world_size, backend="nccl", init_pg=True, fake_pg=False
):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# Just manually implement the most important part of the dynamo behavior to reset/clear.
if not fake_pg:
@ -1421,7 +1474,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
store=store,
)
else:
c10d.init_process_group("nccl", rank=rank, world_size=world_size)
c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
try:
@ -1465,7 +1518,7 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
super().tearDownClass()
class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
class DynamoDistributedMultiProcTestCase(DistributedTestBase):
"""
Use this for tests that actually run on multiple GPUs.
@ -1476,20 +1529,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
sparingly for integration tests.
"""
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
@property
def world_size(self) -> int:
return torch.cuda.device_count()
return torch.accelerator.device_count()
@classmethod
def _run(