add device generalisation support for distributed tests (#152471)

### MOTIVATION
Generalize the distributed test cases so they run on non-CUDA accelerator devices (e.g. HPU, XPU) rather than assuming CUDA.

### CHANGES

- test/distributed/optim/test_zero_redundancy_optimizer.py
- test/distributed/test_c10d_logger.py
- test/distributed/test_compute_comm_reordering.py

Replaced hard-coded device names with `get_devtype` from `torch.testing._internal.common_fsdp`, and switched these tests from `MultiProcessTestCase` to `DistributedTestBase` so they can reuse its helper functions (`create_pg`, `backend`). The diff below also applies the same device-agnostic treatment to the fake process group tests (`TestFakePG`).
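
The same pattern recurs across these files. A minimal sketch (hypothetical test class, illustrative only) of what a test looks like after the migration, assuming a spawned multi-process run as in the real suite:

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

# "cuda", "hpu", "xpu", ... instead of a hard-coded "cuda"
device_type = str(get_devtype())


class ExampleDeviceAgnosticTest(DistributedTestBase):
    @property
    def world_size(self):
        # Scale with whatever accelerator is present, not torch.cuda.device_count()
        return min(4, max(2, torch.get_device_module(device_type).device_count()))

    def test_all_reduce(self):
        # create_pg() picks nccl/xccl/hccl/gloo from the device string, replacing
        # the hand-rolled dist_init()/FileStore boilerplate the tests used before.
        self.create_pg(device_type)
        t = torch.ones(2, 2, device=device_type) * (self.rank + 1)
        dist.all_reduce(t)
        expected = sum(range(1, self.world_size + 1))
        self.assertEqual(t, torch.full_like(t, float(expected)))
        # DistributedTestBase.tearDown() destroys the process group afterwards.


if __name__ == "__main__":
    run_tests()
```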

- torch/testing/_internal/common_distributed.py

Extended the common utility functions: added `requires_accelerator_dist_backend`, `requires_ddp_rank`, and `HAS_ACCELERATOR`; made the GPU skip/count helpers (`skip_if_no_gpu`, `at_least_x_gpu`) accelerator-aware; gave `DistributedTestBase.create_pg` an optional `world_size` argument; and let `_dynamo_dist_per_rank_init` accept a backend.
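
For context, a hedged sketch of how these new helpers are meant to compose in a test; the decorator and helper names come from the diff below, while the class and test body are purely illustrative:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
    DistributedTestBase,
    requires_accelerator_dist_backend,  # skip unless nccl/xccl/hccl is available
    requires_ddp_rank,                  # True for devices that need DDP device_ids (cuda/xpu)
    skip_if_no_gpu,
)
from torch.testing._internal.common_fsdp import get_devtype

device_type = str(get_devtype())


class ExampleDDPTest(DistributedTestBase):
    @requires_accelerator_dist_backend()
    @skip_if_no_gpu
    def test_ddp_step(self):
        self.create_pg(device_type)
        model = torch.nn.Linear(4, 4).to(device_type)
        ddp_model = DDP(
            model,
            # Only cuda/xpu ranks pass device_ids; other devices leave it as None.
            device_ids=[self.rank] if requires_ddp_rank(device_type) else None,
        )
        out = ddp_model(torch.randn(2, 4, device=device_type))
        out.sum().backward()
```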

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152471
Approved by: https://github.com/d4l3k
Author: Hari Krishna Sai Kodali
Date: 2025-06-20 07:35:42 +00:00
Committed by: PyTorch MergeBot
Parent: 0aed855b2b
Commit: e1f28fe17b
5 changed files with 243 additions and 224 deletions

test/distributed/optim/test_zero_redundancy_optimizer.py

@ -6,7 +6,6 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy import copy
import os
import sys import sys
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, cast from typing import Any, cast
@ -30,12 +29,23 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD from torch.optim import AdamW, SGD
from torch.testing._internal import common_distributed from torch.testing._internal.common_distributed import (
DistributedTestBase,
logger,
requires_accelerator_dist_backend,
requires_ddp_rank,
requires_gloo,
skip_if_lt_x_gpu,
skip_if_no_gpu,
skip_if_rocm_multiprocess,
skip_if_win32,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
IS_WINDOWS,
parametrize, parametrize,
run_tests, run_tests,
skipIfHpu,
) )
@ -47,63 +57,24 @@ except ImportError:
HAS_TORCHVISION = False HAS_TORCHVISION = False
# Use GLOO on GPU when running CUDA + Windows device_type = str(get_devtype())
def _get_backend_for_tests():
return (
dist.Backend.NCCL
if not IS_WINDOWS and torch.cuda.is_available()
# Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
# no GPUs are available.
else dist.Backend.GLOO
)
BACKEND = _get_backend_for_tests() class TestZeroRedundancyOptimizer(DistributedTestBase):
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
self._spawn_processes()
@property @property
def device(self): def device(self):
return ( return device_type
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
@property @property
def world_size(self): def world_size(self):
return 1 return 1
def tearDown(self):
try:
torch.distributed.destroy_process_group()
except AssertionError:
pass
try:
os.remove(self.file_name)
except OSError:
pass
def dist_init(self, rank, world_size=-1, backend=BACKEND):
if world_size < 1:
world_size = self.world_size
store = dist.FileStore(self.file_name, world_size)
return dist.init_process_group(
backend=backend,
store=store,
rank=rank,
world_size=world_size,
)
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer): class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_state_dict(self): def test_state_dict(self):
"""Check that ZeroRedundancyOptimizer exposes the expected state dict """Check that ZeroRedundancyOptimizer exposes the expected state dict
interface, irrespective of the sharding.""" interface, irrespective of the sharding."""
self.dist_init(self.rank) self.create_pg(self.device)
LR1 = 0.1 LR1 = 0.1
LR2 = 0.01 LR2 = 0.01
MOMENTUM = 0.9 MOMENTUM = 0.9
@ -171,7 +142,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_lr_scheduler(self): def test_lr_scheduler(self):
"""Check that a normal PyTorch ``lr_scheduler`` is usable with """Check that a normal PyTorch ``lr_scheduler`` is usable with
ZeroRedundancyOptimizer.""" ZeroRedundancyOptimizer."""
self.dist_init(self.rank) self.create_pg(self.device)
NUM_ITERS = 5 NUM_ITERS = 5
LR = 0.01 LR = 0.01
x = torch.tensor([1.0], device=self.device, requires_grad=True) x = torch.tensor([1.0], device=self.device, requires_grad=True)
@ -193,7 +164,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_step_with_kwargs(self): def test_step_with_kwargs(self):
"""Check that the ``step(**kwargs)`` interface is properly exposed.""" """Check that the ``step(**kwargs)`` interface is properly exposed."""
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.1 LR = 0.1
class SGDWithStepKWArg(torch.optim.SGD): class SGDWithStepKWArg(torch.optim.SGD):
@ -217,7 +188,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
"""Check that ZeroRedundancyOptimizer wrapping an optimizer that adds """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
extra keys to ``param_groups`` exposes those keys through ZeRO's own extra keys to ``param_groups`` exposes those keys through ZeRO's own
``param_groups``.""" ``param_groups``."""
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.1 LR = 0.1
class SGDWithNewKey(torch.optim.SGD): class SGDWithNewKey(torch.optim.SGD):
@ -236,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_step_without_closure(self): def test_step_without_closure(self):
"""Check that the ``step()`` method (without closure) is handled as """Check that the ``step()`` method (without closure) is handled as
expected.""" expected."""
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.1 LR = 0.1
class SGDWithoutClosure(torch.optim.SGD): class SGDWithoutClosure(torch.optim.SGD):
@ -255,7 +226,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_zero_grad(self): def test_zero_grad(self):
"""Check that the ``zero_grad`` method is properly handled.""" """Check that the ``zero_grad`` method is properly handled."""
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.01 LR = 0.01
x = torch.rand(1) x = torch.rand(1)
m = torch.nn.Linear(1, 1) m = torch.nn.Linear(1, 1)
@ -271,7 +242,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
def test_constructor(self): def test_constructor(self):
"""Check the robustness of the ZeroRedundancyOptimizer constructor by """Check the robustness of the ZeroRedundancyOptimizer constructor by
passing different values for the ``params`` argument.""" passing different values for the ``params`` argument."""
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.01 LR = 0.01
m = torch.nn.Sequential( m = torch.nn.Sequential(
torch.nn.Linear(5, 10), torch.nn.Linear(5, 10),
@ -336,7 +307,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
NOTE: This test should be removed once support for sparse parameters NOTE: This test should be removed once support for sparse parameters
and varying parameter types is added. and varying parameter types is added.
""" """
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.01 LR = 0.01
inputs = [ inputs = [
[torch.sparse_coo_tensor(size=(2, 3))], [torch.sparse_coo_tensor(size=(2, 3))],
@ -353,25 +324,16 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
@property
def device(self):
return (
torch.device(self.rank)
if torch.cuda.is_available()
else torch.device("cpu")
)
@property @property
def world_size(self): def world_size(self):
return min(4, max(2, torch.cuda.device_count())) return min(4, max(2, torch.get_device_module(self.device).device_count()))
@property @property
def context(self): def context(self):
return ( if requires_ddp_rank(self.device):
nullcontext() return torch.get_device_module(self.device).device(self.rank)
if not torch.cuda.is_available() else:
else torch.cuda.device(self.rank) return nullcontext()
)
def _check_same_model_params( def _check_same_model_params(
self, self,
@ -396,12 +358,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message, msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message,
) )
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess @skip_if_rocm_multiprocess
def test_step(self): def test_step(self):
"""Check that ZeroRedundancyOptimizer properly exposes the ``step()`` """Check that ZeroRedundancyOptimizer properly exposes the ``step()``
interface.""" interface."""
self.dist_init(self.rank, world_size=self.world_size) self.create_pg(self.device)
LR = 0.01 LR = 0.01
with self.context: with self.context:
@ -436,13 +398,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
self.assertEqual(m.weight, m_zero.weight) self.assertEqual(m.weight, m_zero.weight)
self.assertEqual(m.bias, m_zero.bias) self.assertEqual(m.bias, m_zero.bias)
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess @skip_if_rocm_multiprocess
def test_step_with_closure(self): def test_step_with_closure(self):
"""Check that ZeroRedundancyOptimizer properly exposes the """Check that ZeroRedundancyOptimizer properly exposes the
``step(closure)`` interface.""" ``step(closure)`` interface."""
self.dist_init(self.rank, world_size=self.world_size) self.create_pg(self.device)
with self.context: with self.context:
for bucket_view in [False, True]: for bucket_view in [False, True]:
x_val = self.rank + 1 x_val = self.rank + 1
@ -487,11 +448,11 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
self.assertEqual(m.weight, torch.tensor([[1.1]])) self.assertEqual(m.weight, torch.tensor([[1.1]]))
self.assertEqual(m.bias, torch.tensor([2.1])) self.assertEqual(m.bias, torch.tensor([2.1]))
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
def test_lr_scheduler(self): def test_lr_scheduler(self):
"""Check that a normal PyTorch ``lr_scheduler`` is usable with """Check that a normal PyTorch ``lr_scheduler`` is usable with
ZeroRedundancyOptimizer.""" ZeroRedundancyOptimizer."""
self.dist_init(self.rank) self.create_pg(self.device)
x = torch.tensor([1.0], device=self.device, requires_grad=True) x = torch.tensor([1.0], device=self.device, requires_grad=True)
x2 = torch.tensor([1.0], device=self.device, requires_grad=True) x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01) o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
@ -519,7 +480,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
``ZeroRedundancyOptimizer._partition_parameters()`` in ``ZeroRedundancyOptimizer._partition_parameters()`` in
zero_redundancy_optimizer.py. zero_redundancy_optimizer.py.
""" """
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.01 LR = 0.01
sizes = [9, 7, 5, 3] sizes = [9, 7, 5, 3]
params = [] params = []
@ -541,7 +502,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
``ZeroRedundancyOptimizer._partition_parameters()`` in ``ZeroRedundancyOptimizer._partition_parameters()`` in
zero_redundancy_optimizer.py. zero_redundancy_optimizer.py.
""" """
self.dist_init(self.rank) self.create_pg(self.device)
LR = 0.01 LR = 0.01
# Test with all parameters trainable to begin with # Test with all parameters trainable to begin with
@ -589,14 +550,14 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
all_trainable() all_trainable()
some_trainable() some_trainable()
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
def test_multiple_param_groups(self): def test_multiple_param_groups(self):
""" """
Check parity between constructing ZeRO with multiple parameter groups Check parity between constructing ZeRO with multiple parameter groups
upfront versus adding parameter groups to ZeRO after construction upfront versus adding parameter groups to ZeRO after construction
versus a non-sharded optimizer. versus a non-sharded optimizer.
""" """
self.dist_init(self.rank) self.create_pg(self.device)
BATCH_SIZE, NUM_ITERS = 8, 3 BATCH_SIZE, NUM_ITERS = 8, 3
INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5 INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
WD, LR = 0.01, 0.01 WD, LR = 0.01, 0.01
@ -656,12 +617,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
torch.testing.assert_close(layer1.bias, layer2.bias) torch.testing.assert_close(layer1.bias, layer2.bias)
torch.testing.assert_close(layer1.bias, layer3.bias) torch.testing.assert_close(layer1.bias, layer3.bias)
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess @skip_if_rocm_multiprocess
def test_collect_shards(self): def test_collect_shards(self):
"""Check the state consolidation mechanism and the state dict exposed """Check the state consolidation mechanism and the state dict exposed
by ZeroRedundancyOptimizer.""" by ZeroRedundancyOptimizer."""
self.dist_init(self.rank) self.create_pg(self.device)
LR = 1e-3 LR = 1e-3
MOMENTUM = 0.99 MOMENTUM = 0.99
BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5 BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
@ -719,27 +680,25 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# trivial # trivial
MIN_WORLD_SIZE = 4 MIN_WORLD_SIZE = 4
if self.world_size < MIN_WORLD_SIZE: if self.world_size < MIN_WORLD_SIZE:
common_distributed.logger.info( logger.info(
"Skipping `test_nondefault_process_group()` since world size " "Skipping `test_nondefault_process_group()` since world size "
"of %s is less than %s", "of %s is less than %s",
self.world_size, self.world_size,
MIN_WORLD_SIZE, MIN_WORLD_SIZE,
) )
return return
BACKEND = dist.Backend.GLOO # Use GPU if enough are available, or fall back to CPU otherwise
self.dist_init(self.rank, self.world_size, BACKEND) if torch.get_device_module(self.device).device_count() < self.world_size:
# Use GPU if enough are available, or fall back to CPU otherwise, which
# is fine since Gloo backend supports both
if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size:
device = torch.device(self.rank)
else:
device = torch.device("cpu") device = torch.device("cpu")
else:
device = torch.device(self.device)
self.create_pg(device.type)
# Create a new process group consisting of the even ranks to exercise # Create a new process group consisting of the even ranks to exercise
# the case where the global and local ranks do not necessarily match # the case where the global and local ranks do not necessarily match
subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0] subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0]
process_group = dist.new_group( process_group = dist.new_group(
ranks=subgroup_ranks, ranks=subgroup_ranks,
backend=BACKEND, backend=self.backend(device.type),
) )
# Ranks not participating in the new process group are no longer needed # Ranks not participating in the new process group are no longer needed
if self.rank not in subgroup_ranks: if self.rank not in subgroup_ranks:
@ -811,7 +770,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
check(optimizer) check(optimizer)
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
@parametrize( @parametrize(
"optimizer_class_str", "optimizer_class_str",
["Adam", "AdamW", "SGD"], ["Adam", "AdamW", "SGD"],
@ -828,7 +787,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
): ):
"""When combined with DDP, check that a local optimizer gives the same """When combined with DDP, check that a local optimizer gives the same
results as wrapping that optimizer with ZeroRedundancyOptimizer.""" results as wrapping that optimizer with ZeroRedundancyOptimizer."""
self.dist_init(self.rank) self.create_pg(self.device)
BATCHES = 20 BATCHES = 20
BATCH_SIZE = 64 BATCH_SIZE = 64
LR = 1e-3 LR = 1e-3
@ -867,7 +826,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
sharded_ddp_model = DDP( sharded_ddp_model = DDP(
module=model, module=model,
device_ids=[self.rank], device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
broadcast_buffers=True, broadcast_buffers=True,
find_unused_parameters=True, find_unused_parameters=True,
) )
@ -879,7 +838,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
ddp_model = DDP( ddp_model = DDP(
local_model, local_model,
device_ids=[self.rank], device_ids=[self.rank] if requires_ddp_rank(self.device) else None,
broadcast_buffers=True, broadcast_buffers=True,
find_unused_parameters=True, find_unused_parameters=True,
) )
@ -892,7 +851,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
def check_step(): def check_step():
input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)) input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM)).to(self.device)
def closure_ddp(input_tensor=input_tensor): def closure_ddp(input_tensor=input_tensor):
ddp_optimizer.zero_grad() ddp_optimizer.zero_grad()
@ -970,13 +929,12 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
NUM_EPOCHS = 2 NUM_EPOCHS = 2
LR = 0.01 LR = 0.01
torch.manual_seed(0) torch.manual_seed(0)
torch.cuda.manual_seed(0) if "cpu" not in device:
torch.get_device_module(device).manual_seed(0)
rank = self.rank rank = self.rank
world_size = self.world_size world_size = self.world_size
is_gpu = device.type == "cuda" self.create_pg(device)
backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
self.dist_init(rank, world_size, backend)
model = torch.nn.Sequential( model = torch.nn.Sequential(
torch.nn.Linear(2, 3), torch.nn.Linear(2, 3),
@ -988,7 +946,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# DDP ensures correct gradients in data parallel training, so DDP with # DDP ensures correct gradients in data parallel training, so DDP with
# local optimizers on uneven inputs should be equivalent to ZeRO on # local optimizers on uneven inputs should be equivalent to ZeRO on
# uneven inputs with gradients being manually set # uneven inputs with gradients being manually set
ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model) ddp_model = (
DDP(model, device_ids=[rank]) if requires_ddp_rank(device) else DDP(model)
)
local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR) local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
zero_model = copy.deepcopy(model) zero_model = copy.deepcopy(model)
zero_model.to(device) zero_model.to(device)
@ -1111,27 +1071,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
iter += 1 iter += 1
@common_distributed.requires_nccl() @requires_accelerator_dist_backend()
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
def test_zero_join_gpu(self): def test_zero_join_gpu(self):
"""Check that the ZeRO join hook allows training with uneven inputs """Check that the ZeRO join hook allows training with uneven inputs
on GPU.""" on GPU."""
self._test_zero_join(self.device) self._test_zero_join(self.device)
@common_distributed.requires_gloo() @requires_gloo()
def test_zero_join_cpu(self): def test_zero_join_cpu(self):
"""Check that the ZeRO join hook allows training with uneven inputs """Check that the ZeRO join hook allows training with uneven inputs
on CPU.""" on CPU."""
self._test_zero_join(torch.device("cpu")) self._test_zero_join("cpu")
def _test_zero_model_parallel(self, parameters_as_bucket_view: bool): def _test_zero_model_parallel(self, parameters_as_bucket_view: bool, device: str):
# Use two processes each with two GPUs # Use two processes each with two GPUs
assert self.rank < 2 assert self.rank < 2
NUM_EPOCHS = 2 NUM_EPOCHS = 2
NUM_INPUTS = 4 NUM_INPUTS = 4
LR = 0.01 LR = 0.01
torch.manual_seed(0) torch.manual_seed(0)
torch.cuda.manual_seed(0) if "cpu" not in device:
torch.get_device_module(device).manual_seed(0)
class ModelParallelModel(torch.nn.Module): class ModelParallelModel(torch.nn.Module):
def __init__(self, dev0, dev1): def __init__(self, dev0, dev1):
@ -1221,7 +1182,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
atol=1e-04, atol=1e-04,
), "Models differ after a step" ), "Models differ after a step"
@common_distributed.skip_if_lt_x_gpu(4) @skipIfHpu
@skip_if_lt_x_gpu(4)
@parametrize( @parametrize(
"parameters_as_bucket_view", "parameters_as_bucket_view",
[False, True], [False, True],
@ -1234,8 +1196,8 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
layers are assigned to different devices.""" layers are assigned to different devices."""
if self.rank >= 2: if self.rank >= 2:
return return
self.dist_init(self.rank, world_size=2) self.create_pg(self.device, world_size=2)
self._test_zero_model_parallel(parameters_as_bucket_view) self._test_zero_model_parallel(parameters_as_bucket_view, self.device)
def _test_ddp_zero_overlap( def _test_ddp_zero_overlap(
self, self,
@ -1250,12 +1212,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
SGD_WEIGHT_DECAY = 0.001 SGD_WEIGHT_DECAY = 0.001
NUM_INPUTS = 5 NUM_INPUTS = 5
torch.manual_seed(0) torch.manual_seed(0)
torch.cuda.manual_seed(0) if "cpu" not in device:
torch.get_device_module(device).manual_seed(0)
rank = self.rank rank = self.rank
is_gpu = device.type == "cuda"
if is_gpu:
torch.cuda.set_device(device)
models_to_test = [ models_to_test = [
( (
torch.nn.Sequential( torch.nn.Sequential(
@ -1273,11 +1233,16 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
) )
for model, inputs in models_to_test: for model, inputs in models_to_test:
# Enable determinism in cudnn operators # Select deterministic context based on device
with torch.backends.cudnn.flags( det_ctx = (
enabled=True, deterministic=True, benchmark=False torch.backends.cudnn.flags(
): enabled=True, deterministic=True, benchmark=False
device_ids = [rank] if is_gpu else None )
if "cuda" in device
else torch.use_deterministic_algorithms(True)
)
with det_ctx:
device_ids = [rank] if requires_ddp_rank(device) else None
# Set up the DDP model overlapping with ZeRO # Set up the DDP model overlapping with ZeRO
ddp_model_overlap = DDP( ddp_model_overlap = DDP(
copy.deepcopy(model).to(device), copy.deepcopy(model).to(device),
@ -1374,10 +1339,10 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# NOTE: The test is skipped if using Windows since functional optimizers # NOTE: The test is skipped if using Windows since functional optimizers
# are not currently supported. # are not currently supported.
@common_distributed.skip_if_win32() @skip_if_win32()
@common_distributed.requires_nccl() @requires_accelerator_dist_backend()
@common_distributed.skip_if_no_gpu @skip_if_no_gpu
@common_distributed.skip_if_rocm_multiprocess @skip_if_rocm_multiprocess
@parametrize( @parametrize(
"use_gpu", "use_gpu",
[True], [True],
@ -1413,9 +1378,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO by ``hook_constructor`` and ``shard_buckets`` and using the given ZeRO
and DDP arguments achieves parity with DDP using a local optimizer. and DDP arguments achieves parity with DDP using a local optimizer.
""" """
device = torch.device(self.rank) if use_gpu else torch.device("cpu") self.create_pg(self.device)
backend = _get_backend_for_tests()
self.dist_init(self.rank, self.world_size, backend)
hook_constructor = ( hook_constructor = (
hook_with_zero_step hook_with_zero_step
if not use_interleaved_hook if not use_interleaved_hook
@ -1423,7 +1386,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
) )
self._test_ddp_zero_overlap( self._test_ddp_zero_overlap(
device, self.device if use_gpu else "cpu",
hook_constructor, hook_constructor,
gradient_as_bucket_view, gradient_as_bucket_view,
static_graph, static_graph,

test/distributed/test_c10d_logger.py

@ -2,7 +2,6 @@
import json import json
import logging import logging
import os
import re import re
import sys import sys
from functools import partial, wraps from functools import partial, wraps
@ -16,10 +15,13 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr) print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0) sys.exit(0)
from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
device_type = str(get_devtype())
if TEST_WITH_DEV_DBG_ASAN: if TEST_WITH_DEV_DBG_ASAN:
print( print(
"Skip dev-asan as torch + multiprocessing spawn have known issues", "Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -27,8 +29,7 @@ if TEST_WITH_DEV_DBG_ASAN:
) )
sys.exit(0) sys.exit(0)
BACKEND = dist.Backend.NCCL WORLD_SIZE = min(4, max(2, torch.get_device_module(device_type).device_count()))
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
def with_comms(func=None): def with_comms(func=None):
@ -39,30 +40,16 @@ def with_comms(func=None):
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: if torch.get_device_module(device_type).device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.dist_init() self.create_pg(device_type)
func(self) func(self)
self.destroy_comms() self.destroy_comms()
return wrapper return wrapper
class C10dErrorLoggerTest(MultiProcessTestCase): class C10dErrorLoggerTest(DistributedTestBase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["BACKEND"] = BACKEND
self._spawn_processes()
@property
def device(self):
return (
torch.device(self.rank)
if BACKEND == dist.Backend.NCCL
else torch.device("cpu")
)
@property @property
def world_size(self): def world_size(self):
return WORLD_SIZE return WORLD_SIZE
@ -76,18 +63,6 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
def dist_init(self):
dist.init_process_group(
backend=BACKEND,
world_size=self.world_size,
rank=self.rank,
init_method=f"file://{self.file_name}",
)
# set device for nccl pg for collectives
if BACKEND == "nccl":
torch.cuda.set_device(self.rank)
def test_get_or_create_logger(self): def test_get_or_create_logger(self):
self.assertIsNotNone(_c10d_logger) self.assertIsNotNone(_c10d_logger)
self.assertEqual(1, len(_c10d_logger.handlers)) self.assertEqual(1, len(_c10d_logger.handlers))
@ -117,7 +92,11 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
re.search("({.+})", captured.output[0]).group(0).replace("'", '"') re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
) )
self.assertEqual(len(error_msg_dict), 9) # NCCL adds additional nccl_version data to the error_msg_dict
if self.backend(device_type) == dist.Backend.NCCL:
self.assertEqual(len(error_msg_dict), 9)
else:
self.assertEqual(len(error_msg_dict), 8)
self.assertIn("pg_name", error_msg_dict.keys()) self.assertIn("pg_name", error_msg_dict.keys())
self.assertEqual("None", error_msg_dict["pg_name"]) self.assertEqual("None", error_msg_dict["pg_name"])
@ -126,13 +105,14 @@ class C10dErrorLoggerTest(MultiProcessTestCase):
self.assertEqual("broadcast", error_msg_dict["func_name"]) self.assertEqual("broadcast", error_msg_dict["func_name"])
self.assertIn("backend", error_msg_dict.keys()) self.assertIn("backend", error_msg_dict.keys())
self.assertEqual("nccl", error_msg_dict["backend"]) self.assertEqual(self.backend(device_type), error_msg_dict["backend"])
self.assertIn("nccl_version", error_msg_dict.keys()) if self.backend(device_type) == dist.Backend.NCCL:
nccl_ver = torch.cuda.nccl.version() self.assertIn("nccl_version", error_msg_dict.keys())
self.assertEqual( nccl_ver = torch.cuda.nccl.version()
".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"] self.assertEqual(
) ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
)
# In this test case, group_size = world_size, since we don't have multiple processes on one node. # In this test case, group_size = world_size, since we don't have multiple processes on one node.
self.assertIn("group_size", error_msg_dict.keys()) self.assertIn("group_size", error_msg_dict.keys())

test/distributed/test_compute_comm_reordering.py

@ -26,12 +26,16 @@ from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init, _dynamo_dist_per_rank_init,
at_least_x_gpu, at_least_x_gpu,
DynamoDistributedMultiProcTestCase, DynamoDistributedMultiProcTestCase,
requires_nccl, requires_accelerator_dist_backend,
) )
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU from torch.testing._internal.inductor_utils import HAS_GPU
device_type = str(get_devtype())
def get_snode_runtime_for_reorder_compute_test(snode): def get_snode_runtime_for_reorder_compute_test(snode):
# NOTE: custom cost model to show that the compute reordering algorithm is working # NOTE: custom cost model to show that the compute reordering algorithm is working
# Collective kernels # Collective kernels
@ -74,7 +78,7 @@ def create_grouped_node_for_allreduce_and_its_deps(snodes):
return new_snode_order return new_snode_order
@requires_nccl() @requires_accelerator_dist_backend()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase): class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
""" """
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
@ -113,9 +117,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return torch.matmul(ar, b) return torch.matmul(ar, b)
with _dynamo_dist_per_rank_init( with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
): ):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs) code = run_and_get_triton_code(compiled, inputs)
# Verify that the wait_tensor is sinked below the 1st matmul but # Verify that the wait_tensor is sinked below the 1st matmul but
@ -154,9 +161,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return torch.matmul(d, e) return torch.matmul(d, e)
with _dynamo_dist_per_rank_init( with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
): ):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs) code = run_and_get_triton_code(compiled, inputs)
# Verify that the all_reduce_ has been raised above the 2nd matmul # Verify that the all_reduce_ has been raised above the 2nd matmul
@ -202,9 +212,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return torch.mm(e, g) return torch.mm(e, g)
with _dynamo_dist_per_rank_init( with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
): ):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Things to verify: # Things to verify:
@ -255,9 +268,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return (e,) return (e,)
with _dynamo_dist_per_rank_init( with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
): ):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# NOTE: after scheduling the first all_reduce: # NOTE: after scheduling the first all_reduce:
@ -312,9 +328,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return (e,) return (e,)
with _dynamo_dist_per_rank_init( with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
): ):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# NOTE: after scheduling the first all_reduce: # NOTE: after scheduling the first all_reduce:
@ -362,9 +381,12 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
return (mm,) return (mm,)
with _dynamo_dist_per_rank_init( with _dynamo_dist_per_rank_init(
self.rank, self.world_size, fake_pg=not at_least_x_gpu(2) self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
): ):
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Expectations: # Expectations:
@ -387,9 +409,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
ranks = pg_info["ranks"] ranks = pg_info["ranks"]
group_size = pg_info["group_size"] group_size = pg_info["group_size"]
g1 = torch.ones(10, 10, device="cuda") g1 = torch.ones(10, 10, device=device_type)
g2 = torch.ones(11, 11, device="cuda") g2 = torch.ones(11, 11, device=device_type)
g3 = torch.ones(12, 12, device="cuda") g3 = torch.ones(12, 12, device=device_type)
def assert_pass(graph): def assert_pass(graph):
# all_reduces need to remain in order! # all_reduces need to remain in order!
@ -429,7 +451,9 @@ graph():
grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1) grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
return grad3, grad2, grad1 return grad3, grad2, grad1
with _dynamo_dist_per_rank_init(self.rank, self.world_size, fake_pg=True): with _dynamo_dist_per_rank_init(
self.rank, self.world_size, self.backend(device_type), fake_pg=True
):
fn(g1, g2, g3) fn(g1, g2, g3)
def test_nccl_heuristics(self): def test_nccl_heuristics(self):

test/distributed/test_fake_pg.py

@ -17,7 +17,9 @@ from torch.distributed.tensor.parallel import (
) )
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_distributed import HAS_ACCELERATOR
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, skipIfHpu, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.distributed.fake_pg import FakeStore
@ -26,13 +28,16 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr) print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0) sys.exit(0)
HAS_CUDA = torch.cuda.is_available() device_type = get_devtype().type
class TestFakePG(TestCase): class TestFakePG(TestCase):
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
dist.destroy_process_group() try:
dist.destroy_process_group()
except AssertionError:
pass
def test_all_reduce(self): def test_all_reduce(self):
store = FakeStore() store = FakeStore()
@ -62,20 +67,21 @@ class TestFakePG(TestCase):
dist.reduce_scatter(output_tensor, to_reduce_scatter) dist.reduce_scatter(output_tensor, to_reduce_scatter)
self.assertEqual(tuple(output_tensor.shape), (3, 3)) self.assertEqual(tuple(output_tensor.shape), (3, 3))
@unittest.skipIf(not HAS_CUDA, "No CUDA") @unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_construct_fsdp(self): def test_construct_fsdp(self):
store = FakeStore() store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
FSDP(nn.Linear(2, 3, device="cuda")) FSDP(nn.Linear(2, 3, device=device_type))
@unittest.skipIf(not HAS_CUDA, "No CUDA") @skipIfHpu
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_fsdp_fake_e2e(self): def test_fsdp_fake_e2e(self):
store = dist.HashStore() store = dist.HashStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
my_module = nn.Sequential( my_module = nn.Sequential(
nn.Linear(2, 3, device="cuda"), nn.Linear(2, 3, device=device_type),
nn.ReLU(), nn.ReLU(),
nn.Linear(3, 2, device="cuda"), nn.Linear(3, 2, device=device_type),
) )
sharded_module = FSDP(my_module, use_orig_params=True) sharded_module = FSDP(my_module, use_orig_params=True)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
@ -85,7 +91,8 @@ class TestFakePG(TestCase):
loss.backward() loss.backward()
optim.step() optim.step()
@unittest.skipIf(not HAS_CUDA, "No CUDA") @skipIfHpu
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_fake_pg_tracing(self): def test_fake_pg_tracing(self):
store = dist.HashStore() store = dist.HashStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
@ -95,7 +102,7 @@ class TestFakePG(TestCase):
def allgather_fn(tensor): def allgather_fn(tensor):
return funcol.all_gather_tensor(tensor, 0, default_pg) return funcol.all_gather_tensor(tensor, 0, default_pg)
gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda")) gm = make_fx(allgather_fn)(torch.randn(2, 2, device=device_type))
FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph)) FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))
def test_broadcast(self): def test_broadcast(self):
@ -165,7 +172,8 @@ class TestFakePG(TestCase):
dist.recv(output, 1) dist.recv(output, 1)
self.assertEqual(tuple(output.shape), (3, 3)) self.assertEqual(tuple(output.shape), (3, 3))
@unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP") @skipIfHpu
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_fsdp_tp_fake_e2e(self): def test_fsdp_tp_fake_e2e(self):
world_size = 4 world_size = 4
tp_size = 2 tp_size = 2
@ -175,9 +183,11 @@ class TestFakePG(TestCase):
backend="fake", rank=0, world_size=world_size, store=store backend="fake", rank=0, world_size=world_size, store=store
) )
device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size)) device_mesh = DeviceMesh(
device_type, torch.arange(0, world_size).view(-1, tp_size)
)
device_mesh = init_device_mesh( device_mesh = init_device_mesh(
"cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"] device_type, (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
) )
sequence_parallelize_plan = { sequence_parallelize_plan = {
@ -190,7 +200,7 @@ class TestFakePG(TestCase):
} }
for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]: for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
my_module = parallelize_module( my_module = parallelize_module(
MLPModule(device="cuda"), MLPModule(device=device_type),
device_mesh["tp"], device_mesh["tp"],
parallel_plan, parallel_plan,
) )
@ -203,7 +213,7 @@ class TestFakePG(TestCase):
for i in range(10): for i in range(10):
dp_rank = dist.get_rank() dp_rank = dist.get_rank()
torch.manual_seed(i + dp_rank) torch.manual_seed(i + dp_rank)
input = torch.randn(20, 10).cuda(dist.get_rank()) input = torch.randn(20, 10, device=f"{device_type}:{dp_rank}")
x = sharded_module(input) x = sharded_module(input)
loss = x.sum() loss = x.sum()
loss.backward() loss.backward()

torch/testing/_internal/common_distributed.py

@ -39,6 +39,7 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures, retry_on_connect_failures,
skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if, skip_but_pass_in_sandcastle_if,
TEST_CUDA,
TEST_HPU, TEST_HPU,
TEST_WITH_ROCM, TEST_WITH_ROCM,
TEST_WITH_TSAN, TEST_WITH_TSAN,
@ -55,6 +56,10 @@ from torch.testing._internal.distributed.multi_threaded_pg import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
DDP_RANK_DEVICES = ["cuda", "xpu"]
HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
class TestSkip(NamedTuple): class TestSkip(NamedTuple):
exit_code: int exit_code: int
@ -109,21 +114,25 @@ class DistTestCases:
backend_feature["xpu"] = {"xccl"} backend_feature["xpu"] = {"xccl"}
def requires_ddp_rank(device):
return device in DDP_RANK_DEVICES
def skip_if_no_gpu(func): def skip_if_no_gpu(func):
"""Skips if the world size exceeds the number of GPUs, ensuring that if the """Skips if the world size exceeds the number of GPUs, ensuring that if the
test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not torch.cuda.is_available(): if not (TEST_CUDA or TEST_HPU or TEST_XPU):
sys.exit(TEST_SKIPS["no_cuda"].exit_code) sys.exit(TEST_SKIPS["no_cuda"].exit_code)
world_size = int(os.environ["WORLD_SIZE"]) world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.device_count() < world_size: if TEST_CUDA and torch.cuda.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_HPU and torch.hpu.device_count < world_size: if TEST_HPU and torch.hpu.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_XPU and torch.xpu.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
if TEST_XPU and torch.xpu.device_count < world_size:
sys.exit(TEST_SKIPS[f"multi-xpu-{world_size}"].exit_code)
return func(*args, **kwargs) return func(*args, **kwargs)
@ -189,7 +198,13 @@ def import_transformers_or_skip():
def at_least_x_gpu(x): def at_least_x_gpu(x):
return torch.cuda.is_available() and torch.cuda.device_count() >= x if TEST_CUDA and torch.cuda.device_count() >= x:
return True
if TEST_HPU and torch.hpu.device_count() >= x:
return True
if TEST_XPU and torch.xpu.device_count() >= x:
return True
return False
def skip_if_lt_x_gpu(x): def skip_if_lt_x_gpu(x):
@ -355,6 +370,35 @@ def requires_mpi():
) )
def requires_accelerator_dist_backend(backends=None):
"""
Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
Args:
backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
Returns:
callable: A decorator that skips the test if no specified accelerator backend is available.
"""
if backends is None:
backends = ACCELERATOR_DIST_BACKENDS
backend_available = any(
{
"nccl": c10d.is_nccl_available,
"xccl": c10d.is_xccl_available,
"hccl": lambda: TEST_HPU,
}.get(backend, lambda: False)()
for backend in backends
)
return skip_but_pass_in_sandcastle_if(
not backend_available,
f"No accelerator communication backend available among {backends}",
)
def requires_multicast_support(): def requires_multicast_support():
has_multicast_support = ( has_multicast_support = (
torch.cuda.is_available() torch.cuda.is_available()
@ -968,9 +1012,14 @@ class MultiProcessTestCase(TestCase):
class DistributedTestBase(MultiProcessTestCase): class DistributedTestBase(MultiProcessTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
self._spawn_processes() self._spawn_processes()
def tearDown(self): def tearDown(self):
try:
torch.distributed.destroy_process_group()
except AssertionError:
pass
try: try:
os.remove(self.file_name) os.remove(self.file_name)
except OSError: except OSError:
@ -986,12 +1035,14 @@ class DistributedTestBase(MultiProcessTestCase):
else: else:
return "gloo" return "gloo"
def create_pg(self, device): def create_pg(self, device, world_size=None):
if world_size is None:
world_size = self.world_size
num_visible_devices = torch.get_device_module(device).device_count() num_visible_devices = torch.get_device_module(device).device_count()
store = torch.distributed.FileStore(self.file_name, num_visible_devices) store = torch.distributed.FileStore(self.file_name, num_visible_devices)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=self.backend(device), backend=self.backend(device),
world_size=self.world_size, world_size=world_size,
rank=self.rank, rank=self.rank,
store=store, store=store,
) )
@ -1404,7 +1455,9 @@ class SaveForwardInputsModel(nn.Module):
@contextmanager @contextmanager
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False): def _dynamo_dist_per_rank_init(
rank, world_size, backend="nccl", init_pg=True, fake_pg=False
):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase, # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# Just manually implement the most important part of the dynamo behavior to reset/clear. # Just manually implement the most important part of the dynamo behavior to reset/clear.
if not fake_pg: if not fake_pg:
@ -1421,7 +1474,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
store=store, store=store,
) )
else: else:
c10d.init_process_group("nccl", rank=rank, world_size=world_size) c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
torch._dynamo.reset() torch._dynamo.reset()
torch._dynamo.utils.counters.clear() torch._dynamo.utils.counters.clear()
try: try:
@ -1465,7 +1518,7 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
super().tearDownClass() super().tearDownClass()
class DynamoDistributedMultiProcTestCase(MultiProcessTestCase): class DynamoDistributedMultiProcTestCase(DistributedTestBase):
""" """
Use this for tests that actually run on multiple GPUs. Use this for tests that actually run on multiple GPUs.
@ -1476,20 +1529,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
sparingly for integration tests. sparingly for integration tests.
""" """
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
@property @property
def world_size(self) -> int: def world_size(self) -> int:
return torch.cuda.device_count() return torch.accelerator.device_count()
@classmethod @classmethod
def _run( def _run(