Compare commits

...

4 Commits

9 changed files with 102 additions and 144 deletions

View File

@ -65,6 +65,7 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
curr_backend = dist.get_default_backend_for_device(device_type)
class SimpleModel(nn.Module):
@ -422,10 +423,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
return f"cpu:gloo,{device_type}:{curr_backend}"
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_fully_shard_tp_2d_set_full_state_dict(self):
dummy_model = SimpleModel().to(device_type)
mesh_2d = init_device_mesh(
@ -514,8 +515,8 @@ class Test2dFSDP1ParallelIntegration(DTensorTestBase):
).to_local()
self.assertEqual(param_m2, param_m1)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_ddp_integration_functionality(self) -> None:
model, twod_model, dp_pg = self.init_model(self.device_type)
optim = torch.optim.Adam(model.parameters(), lr=3e-5)
@ -566,8 +567,8 @@ class TestNew2dParallelTraining(DTensorTestBase):
p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_fsdp_state_enable_extension(self):
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
@ -642,18 +643,18 @@ class TestNew2dParallelTraining(DTensorTestBase):
# Ensure all params are still the same after optimizer update.
self._compare_params(model, model_2d)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_e2e_training_default(self):
self._test_2d_e2e_training()
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_e2e_training_use_orig_params(self):
self._test_2d_e2e_training(use_orig_params=True)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_2d_e2e_training_not_use_orig_params(self):
# TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
# self._test_2d_e2e_training(recompute_activation=True)
@ -666,10 +667,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
@property
def backend(self):
# need to specify gloo backend for testing cpu offload
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
return f"cpu:gloo,{device_type}:{curr_backend}"
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
def test_fsdp_2d_extension(self):
"""
Test whether _fsdp_extension from FSDPstate has been set correctly.
@ -700,8 +701,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
@parametrize("is_even_sharded_model", [True, False])
def test_2d_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -756,8 +757,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
@parametrize("is_even_sharded_model", [True, False])
def test_2d_load_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -811,8 +812,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
self.assertEqual(v1.device_mesh, v2.device_mesh)
self.assertEqual(v1.placements, v2.placements)
@with_comms
@skip_if_lt_x_gpu(4)
@with_comms
@parametrize("is_even_sharded_model", [True, False])
def test_2d_optim_state_dict(self, is_even_sharded_model):
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@ -899,9 +900,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
else:
self.assertEqual(new_state, state)
@skip_if_lt_x_gpu(4)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_fsdp1_tp_2d_set_full_state_dict(self):
"""
This is a workaround for loading full state dict into a FSDP1+TP 2D model.

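Note on the backend change above: instead of hard-coding "cpu:gloo,cuda:nccl" (or the XPU variant), the backend property now assembles the process-group backend string from the detected accelerator. A minimal sketch of that pattern, assuming a recent PyTorch that exposes torch.accelerator and dist.get_default_backend_for_device (both already used at the top of the test file):

import torch
import torch.distributed as dist

# Detect the active accelerator ("cuda", "xpu", ...) and fall back to CPU.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
# Default collective backend for that device, e.g. "nccl" for CUDA or "xccl" for XPU.
curr_backend = dist.get_default_backend_for_device(device_type)

# The combined "device:backend" mapping keeps a gloo group for CPU-offloaded
# tensors while the accelerator uses its native backend.
backend = f"cpu:gloo,{device_type}:{curr_backend}"
# The test base class would then initialize roughly as:
# dist.init_process_group(backend=backend, rank=rank, world_size=world_size, store=store)
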
View File

@ -29,8 +29,8 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
MultiProcessTestCase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
@ -40,7 +40,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_XPU,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@ -107,11 +106,9 @@ class ComposabilityTest(MultiProcessTestCase):
def device(self):
return self.rank
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
def test_pp_and_dcp(self):
"""
Test that pipeline parallelism and distributed checkpointing can be used together and
@ -201,11 +198,9 @@ class ComposabilityTest(MultiProcessTestCase):
_dcp_test(self)
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[
@ -355,11 +350,9 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[
@ -552,11 +545,9 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
@parametrize(
"ScheduleClass",
[

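The hunks above drop the explicit backend list from requires_accelerator_dist_backend() and replace the vendor-specific guard (not TEST_MULTIGPU and not TEST_XPU) with at_least_x_gpu(8), which also fixes the stale "4+ GPUs" message. A simplified, hypothetical sketch of the kind of check such a device-count helper performs; the real helper in torch.testing._internal.common_distributed may differ:

import torch

def at_least_x_gpu_sketch(x: int) -> bool:
    # True when the detected accelerator (CUDA, XPU, HPU, ...) exposes at least x devices.
    return torch.accelerator.is_available() and torch.accelerator.device_count() >= x

# Used the same way as in the tests above:
# @skip_but_pass_in_sandcastle_if(not at_least_x_gpu_sketch(8), "Test requires 8+ GPUs")
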
View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import torch
@ -18,8 +17,8 @@ from torch.distributed.algorithms.ddp_comm_hooks import (
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
DistributedTestBase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@ -30,9 +29,12 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def gpus_for_rank(world_size):
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
visible_devices = list(range(torch.accelerator.device_count()))
gpus_per_process = torch.accelerator.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@ -60,27 +62,7 @@ class TestDdpCommHook(nn.Module):
return self.t0(x ** (1 + rank))
class DistributedDataParallelCommHookTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
self._spawn_processes()
def tearDown(self):
try:
os.remove(self.file_name)
except OSError:
pass
def _get_process_group_nccl(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
return dist.distributed_c10d._get_default_group()
class DistributedDataParallelCommHookTest(DistributedTestBase):
@property
def world_size(self):
return 2
@ -119,14 +101,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
param = next(model.parameters())
return param.grad
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_allreduce_hook(self):
"""
This unit test verifies the ``allreduce`` hook registered case gives same result
with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -135,14 +117,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_fp16compress_hook(self):
"""
This unit test verifies the ``fp16 compress`` hook registered case
gives close result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -151,14 +133,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_tensor_hook(self):
"""
This unit test verifies the ``quantize per tensor`` hook registered case
gives close result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -167,14 +149,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_quantize_per_channel_hook(self):
"""
This unit test verifies the ``quantize per channel`` hook registered case
gives close result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -185,14 +167,14 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_ddp_comm_hook_noop_hook(self):
"""
This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
gives same result with no hook registered case.
"""
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
# No hook registered case, get the reference grads.
reference_grads = self._get_grads(process_group, None)
@ -204,10 +186,10 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
@requires_nccl()
@requires_accelerator_dist_backend()
@skip_if_lt_x_gpu(2)
def test_is_last_hook(self):
process_group = self._get_process_group_nccl()
process_group = self.create_pg(device_type)
def hook(flags, bucket):
flags.append(bucket.is_last())

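The comm-hook tests above drop their hand-written setUp/tearDown and the NCCL-only _get_process_group_nccl helper in favor of DistributedTestBase and create_pg(device_type). A minimal sketch of that pattern, assuming create_pg initializes the default process group with the backend appropriate for the given device, as the diff suggests (class and test names here are illustrative only):

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    DistributedTestBase,
    requires_accelerator_dist_backend,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

class MyCommHookTest(DistributedTestBase):
    @property
    def world_size(self):
        return 2

    @requires_accelerator_dist_backend()
    @skip_if_lt_x_gpu(2)
    def test_allreduce_sketch(self):
        # create_pg replaces the removed FileStore/init_process_group boilerplate.
        process_group = self.create_pg(device_type)
        t = torch.ones(4, device=f"{device_type}:{self.rank}")
        dist.all_reduce(t, group=process_group)
        self.assertEqual(t.cpu(), torch.full((4,), float(self.world_size)))

if __name__ == "__main__":
    run_tests()
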
View File

@ -32,7 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
class TestStateDictUtils(DTensorTestBase):
@property
def world_size(self):
return min(4, torch.cuda.device_count())
return min(4, torch.accelerator.device_count())
@with_comms
@skip_if_lt_x_gpu(2)
@ -49,7 +49,7 @@ class TestStateDictUtils(DTensorTestBase):
dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
)
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertTrue(gathered_state_dict["dtensor"].is_cuda)
self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type)
@with_comms
@skip_if_lt_x_gpu(4)
@ -69,14 +69,16 @@ class TestStateDictUtils(DTensorTestBase):
)
if dist.get_rank() in (0, 2):
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
self.assertNotEqual(
gathered_state_dict["dtensor"].device.type, self.device_type
)
else:
self.assertEqual(gathered_state_dict, {})
@with_comms
@skip_if_lt_x_gpu(4)
def test_cpu_and_ranks_only(self):
device = torch.device("cuda")
device = torch.device(self.device_type)
state_dict = {
"tensor1": torch.arange(10, device=device),
"tensor2": torch.ones(10, device=device),
@ -85,7 +87,7 @@ class TestStateDictUtils(DTensorTestBase):
cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
if dist.get_rank() in (0, 2):
for v in cpu_state_dict.values():
self.assertFalse(v.is_cuda)
self.assertNotEqual(v.device.type, self.device_type)
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
else:
@ -109,27 +111,27 @@ class TestStateDictUtils(DTensorTestBase):
for _ in range(10):
tensor, dtensor = create_dtensor()
ltensor.append(tensor)
ltensor.append(torch.ones(10, device=torch.device("cuda")))
ltensor.append(torch.ones(10, device=torch.device(self.device_type)))
ldtensor.append(dtensor)
ldtensor.append(torch.ones(10, device=torch.device("cuda")))
ldtensor.append(torch.ones(10, device=torch.device(self.device_type)))
tensor, dtensor = create_dtensor()
dist_state_dict = {
"local": dtensor,
"list": ldtensor,
"arange": torch.arange(10, device=torch.device("cuda")),
"arange": torch.arange(10, device=torch.device(self.device_type)),
}
state_dict = {
"local": tensor,
"list": ltensor,
"arange": torch.arange(10, device=torch.device("cuda")),
"arange": torch.arange(10, device=torch.device(self.device_type)),
}
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
@with_comms
@skip_if_lt_x_gpu(2)
def test_create_cpu_state_dict(self):
device = torch.device("cuda")
device = torch.device(self.device_type)
rank = dist.get_rank()
# Scale tensors based on world size
# to fit in the tensor shards accurately.
@ -149,7 +151,7 @@ class TestStateDictUtils(DTensorTestBase):
metadata=ShardMetadata(
shard_offsets=[5 * rank, 0],
shard_sizes=[5, 10],
placement=f"rank:{rank}/cuda:{rank}",
placement=f"rank:{rank}/{self.device_type}:{rank}",
),
)
],
@ -159,7 +161,7 @@ class TestStateDictUtils(DTensorTestBase):
torch.arange(50 * scale_factor, device=device).reshape(
5 * scale_factor, 10
),
init_device_mesh("cuda", mesh_shape=(self.world_size,)),
init_device_mesh(self.device_type, mesh_shape=(self.world_size,)),
[Shard(0)],
),
"non_tensor_bytes_io": copy.deepcopy(buffer),
@ -245,7 +247,7 @@ class TestStateDictUtils(DTensorTestBase):
even_tensor = torch.randn(self.world_size, 2)
uneven_tensor = torch.randn(1, 2)
mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
even_dtensor = distribute_tensor(
torch.randn(self.world_size, 2), mesh, [Shard(0)]
)
@ -273,10 +275,10 @@ class TestStateDictUtils(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
def test_cpu_offload_for_dtensor(self):
device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
sd = {
"k": DTensor.from_local(
torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)]
torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)]
)
}
cpu_sd = _create_cpu_state_dict(sd)
@ -290,12 +292,12 @@ class TestStateDictUtils(DTensorTestBase):
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
sd["k"] += 1
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
_copy_state_dict(sd, cpu_sd, non_blocking=True)
torch.cuda.synchronize()
torch.accelerator.synchronize()
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))

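The edits above are a mechanical CUDA-to-generic mapping onto the torch.accelerator frontend. A short sketch of the replacements, assuming torch.accelerator is available (everything degrades to CPU when no accelerator is present):

import torch

acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc else "cpu"

n = torch.accelerator.device_count()        # was: torch.cuda.device_count()
t = torch.ones(4, device=device_type)       # was: device="cuda" / torch.device("cuda")
on_device = t.device.type == device_type    # was: t.is_cuda
if acc is not None:
    torch.accelerator.synchronize()         # was: torch.cuda.synchronize()
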
View File

@ -7,7 +7,7 @@
import copy
import sys
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import Any, cast
import numpy as np
@ -40,7 +40,6 @@ from torch.testing._internal.common_distributed import (
skip_if_rocm_multiprocess,
skip_if_win32,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -57,7 +56,17 @@ except ImportError:
HAS_TORCHVISION = False
device_type = str(get_devtype())
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
@contextmanager
def deterministic_algorithms(enabled=True):
prev_state = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(enabled)
try:
yield
finally:
torch.use_deterministic_algorithms(prev_state)
class TestZeroRedundancyOptimizer(DistributedTestBase):
@ -1241,7 +1250,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
enabled=True, deterministic=True, benchmark=False
)
if "cuda" in device
else torch.use_deterministic_algorithms(True)
else deterministic_algorithms(True)
)
with det_ctx:
device_ids = [rank] if requires_ddp_rank(device) else None

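The new deterministic_algorithms context manager above exists because torch.use_deterministic_algorithms(True) is a plain setter that returns None, so the non-CUDA branch of det_ctx could not be used in a with statement and also leaked the global flag. A self-contained usage sketch (re-declaring the helper so it runs standalone):

from contextlib import contextmanager
import torch

@contextmanager
def deterministic_algorithms(enabled=True):
    # Save and restore the global flag around the block.
    prev_state = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(enabled)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(prev_state)

prev = torch.are_deterministic_algorithms_enabled()
with deterministic_algorithms(True):
    assert torch.are_deterministic_algorithms_enabled()
assert torch.are_deterministic_algorithms_enabled() == prev  # prior state restored
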
View File

@ -24,7 +24,7 @@ from torch.distributed._functional_collectives import (
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
DistributedTestBase,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
@ -59,12 +59,8 @@ if not dist.is_available():
sys.exit(0)
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestWithNCCL(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@requires_accelerator_dist_backend()
class TestWithNCCL(DistributedTestBase):
@property
def world_size(self) -> int:
return 2
@ -78,16 +74,7 @@ class TestWithNCCL(MultiProcessTestCase):
return torch.device(self.rank)
def _init_process_group(self) -> None:
torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
backend = dist.get_default_backend_for_device(self.device.type)
dist.init_process_group(
backend=backend,
world_size=self.world_size,
rank=self.rank,
store=store,
)
self.create_pg(self.device.type)
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
@skip_if_lt_x_gpu(2)

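The removed _init_process_group body spells out what create_pg is assumed to encapsulate: picking the default backend for the device and initializing the group from a FileStore. A sketch of that assumed equivalence, reconstructed from the removed lines:

import torch.distributed as dist

def create_pg_sketch(self, device_type: str):
    # Roughly what the deleted boilerplate did, and what create_pg is assumed to wrap.
    store = dist.FileStore(self.file_name, self.world_size)
    backend = dist.get_default_backend_for_device(device_type)
    dist.init_process_group(
        backend=backend,
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )
    return dist.distributed_c10d._get_default_group()
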
View File

@ -11,13 +11,10 @@ if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
from torch.testing._internal.common_utils import (
run_tests,
skipIfHpu,
TEST_CUDA,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
)
@ -29,16 +26,8 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
if TEST_HPU:
DEVICE = "hpu"
elif TEST_CUDA:
DEVICE = "cuda"
else:
DEVICE = "cpu"
device_module = torch.get_device_module(DEVICE)
device_count = device_module.device_count()
BACKEND = dist.get_default_backend_for_device(DEVICE)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()
def with_comms(func=None):
@ -49,11 +38,10 @@ def with_comms(func=None):
@wraps(func)
def wrapper(self, *args, **kwargs):
if DEVICE != "cpu" and device_count < self.world_size:
if device_type != "cpu" and device_count < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
kwargs["device"] = DEVICE
self.pg = self.create_pg(device=DEVICE)
self.pg = self.create_pg(device=device_type)
try:
return func(self, *args, **kwargs)
finally:
@ -64,7 +52,7 @@ def with_comms(func=None):
class TestObjectCollectives(DistributedTestBase):
@with_comms()
def test_all_gather_object(self, device):
def test_all_gather_object(self):
output = [None] * dist.get_world_size()
dist.all_gather_object(object_list=output, obj=self.rank)
@ -72,7 +60,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(i, v, f"rank: {self.rank}")
@with_comms()
def test_gather_object(self, device):
def test_gather_object(self):
output = [None] * dist.get_world_size() if self.rank == 0 else None
dist.gather_object(obj=self.rank, object_gather_list=output)
@ -82,7 +70,7 @@ class TestObjectCollectives(DistributedTestBase):
@skipIfHpu
@with_comms()
def test_send_recv_object_list(self, device):
def test_send_recv_object_list(self):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
if self.rank == 0:
@ -96,7 +84,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(None, object_list[0])
@with_comms()
def test_broadcast_object_list(self, device):
def test_broadcast_object_list(self):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
# TODO test with broadcast_object_list's device argument
@ -105,7 +93,7 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(99, object_list[0])
@with_comms()
def test_scatter_object_list(self, device):
def test_scatter_object_list(self):
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
output_list = [None]
dist.scatter_object_list(
@ -123,34 +111,30 @@ class TestObjectCollectives(DistributedTestBase):
my_pg = dist.new_group(ranks, use_local_synchronization=True)
return rank, ranks, my_pg
@skipIfHpu
@with_comms()
def test_subpg_scatter_object(self, device):
def test_subpg_scatter_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None]
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
self.assertEqual(rank, out_list[0])
@skipIfHpu
@with_comms()
def test_subpg_all_gather_object(self, device):
def test_subpg_all_gather_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None] * len(ranks)
dist.all_gather_object(out_list, rank, group=my_pg)
self.assertEqual(ranks, out_list)
@skipIfHpu
@with_comms()
def test_subpg_gather_object(self, device):
def test_subpg_gather_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None] * len(ranks) if rank == ranks[0] else None
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
if rank == ranks[0]:
self.assertEqual(ranks, out_list)
@skipIfHpu
@with_comms()
def test_subpg_broadcast_object(self, device):
def test_subpg_broadcast_object(self):
rank, ranks, my_pg = self.setup_sub_pg()
out_list = [None]
if rank == ranks[0]:
@ -159,7 +143,5 @@ class TestObjectCollectives(DistributedTestBase):
self.assertEqual(ranks[0], out_list[0])
devices = ("cpu", "cuda", "hpu")
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
if __name__ == "__main__":
run_tests()

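In the object-collectives file above, with_comms no longer injects a device kwarg (the tests drop their device parameter) and the process group is created from the detected accelerator. A sketch of the reworked decorator; the cleanup in the finally block is an assumption, since the hunk cuts off before it:

import sys
from functools import wraps

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import TEST_SKIPS

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()

def with_comms(func=None):
    def decorator(inner):
        @wraps(inner)
        def wrapper(self, *args, **kwargs):
            # Skip when the accelerator cannot host one rank per process.
            if device_type != "cpu" and device_count < self.world_size:
                sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
            self.pg = self.create_pg(device=device_type)
            try:
                return inner(self, *args, **kwargs)
            finally:
                dist.destroy_process_group(self.pg)  # assumed cleanup, not shown in the hunk
        return wrapper
    return decorator if func is None else decorator(func)
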
View File

@ -29,7 +29,7 @@ from torch.distributed.tensor._collective_utils import (
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
os.environ["LOCAL_RANK"] = f"{local_rank}"
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.")
class DeviceMeshTestGlooBackend(DTensorTestBase):
@property
def backend(self):

View File

@ -40,6 +40,7 @@ from torch.distributed.tensor.parallel import (
SequenceParallel,
)
from torch.testing._internal.common_distributed import (
ACCELERATOR_DIST_BACKENDS,
MultiProcContinuousTest,
MultiProcessTestCase,
MultiThreadedTestCase,
@ -393,14 +394,17 @@ class DTensorTestBase(MultiProcessTestCase):
return init_device_mesh(self.device_type, (self.world_size,))
def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
if backend is None:
backend = self.backend
requires_gpu = any(
gpu_backend in backend for gpu_backend in ACCELERATOR_DIST_BACKENDS
)
if requires_gpu and torch.accelerator.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
curr_backend = dist.get_default_backend_for_device(self.device_type)
if backend is None:
backend = self.backend
if backend not in [
"nccl",
"gloo",