[Reland][2/N] Port several test files under test/distributed to Intel GPU (#159473)

For https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU. This PR covers several test files under test/distributed. Intel GPU is enabled with the following methods, keeping the original code style as much as possible (a minimal sketch of the resulting pattern follows the list):

- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- use requires_accelerator_dist_backend to allow both the nccl and xccl backends
- enable XPU for some test paths
- derive the previously hardcoded world_size from device_count
- unify common code under torch/testing/_internal across backends, for example by
  adding xpu to Backend.backend_capability and dist.Backend.register_backend()
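
The sketch below illustrates the pattern these changes converge on. It is not part of the diff: the class and test names are hypothetical, and it only uses helpers that appear in the hunks below (requires_accelerator_dist_backend, MultiProcessTestCase, skip_if_lt_x_gpu, dist.get_default_backend_for_device, and the torch.accelerator APIs).

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_accelerator_dist_backend,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests

# Resolve the device type once instead of hardcoding "cuda".
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


@requires_accelerator_dist_backend(["nccl", "xccl"])  # run under NCCL or XCCL
class ExamplePortedTest(MultiProcessTestCase):  # hypothetical test class
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        # Follow the visible device count instead of hardcoding the world size
        # (kept at a minimum of 2 for this sketch).
        return max(torch.accelerator.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_all_reduce(self):
        torch.accelerator.set_device_index(self.rank)
        store = dist.FileStore(self.file_name, self.world_size)
        backend = dist.get_default_backend_for_device(device_type)  # "nccl" or "xccl"
        dist.init_process_group(
            backend, rank=self.rank, world_size=self.world_size, store=store
        )
        t = torch.ones(2, 2, device=f"{device_type}:{self.rank}")
        dist.all_reduce(t)
        self.assertEqual(t, torch.full_like(t, float(self.world_size)))
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```

The hunks below apply the same substitutions in place: torch.cuda.set_device becomes torch.accelerator.set_device_index, .cuda() becomes .to(device_type), and backend="nccl" becomes the default backend for the current accelerator.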

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159473
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Deng, Daisy
2025-09-17 06:42:24 +00:00
committed by PyTorch MergeBot
parent 71b272e4a3
commit c9485f8ff3
11 changed files with 345 additions and 233 deletions

View File

@ -43,6 +43,7 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TEST_XPU,
TestCase,
)
from torch.utils.checkpoint import checkpoint
@ -63,6 +64,8 @@ else:
torch.backends.cuda.matmul.allow_tf32 = False
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def gpus_for_rank(world_size):
"""Multigpu tests are designed to simulate the multi nodes with multi
@ -70,8 +73,9 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
device_count = torch.accelerator.device_count()
visible_devices = list(range(device_count))
gpus_per_process = device_count // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@ -401,7 +405,7 @@ class CommonDistributedDataParallelTest:
gradient_as_bucket_view=gradient_as_bucket_view,
)
input = torch.randn(global_batch_size, 2).cuda(devices[0])
input = torch.randn(global_batch_size, 2).to(devices[0])
target = torch.randn(global_batch_size, 4)
return model, ddp_model, input, target
@ -435,10 +439,10 @@ class CommonDistributedDataParallelTest:
allow_none_grads=False,
):
# to reproduce the same training results
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
torch.manual_seed(31415)
model = copy.deepcopy(input_model).cuda()
ddp_model = copy.deepcopy(input_model).cuda()
model = copy.deepcopy(input_model).to(device_type)
ddp_model = copy.deepcopy(input_model).to(device_type)
ddp_model = nn.parallel.DistributedDataParallel(
ddp_model,
bucket_cap_mb=1,
@ -554,8 +558,8 @@ class CommonDistributedDataParallelTest:
def _prepare_dummy_data(self):
ddp_bs = 16
bs = ddp_bs * self.world_size
input = torch.rand((bs, 20), device="cuda", requires_grad=True)
target = torch.randn((bs, 20), device="cuda")
input = torch.rand((bs, 20), device=device_type, requires_grad=True)
target = torch.randn((bs, 20), device=device_type)
offset = self.rank * ddp_bs
ddp_input = input[offset : offset + ddp_bs]
ddp_target = target[offset : offset + ddp_bs]
@ -715,7 +719,7 @@ class CommonDistributedDataParallelTest:
Test that checkpointing with weight sharing works.
"""
process_group = self._get_process_group()
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
for use_bucket_view, static_graph in product((False, True), (False, True)):
torch.manual_seed(31415)
l1 = nn.Linear(20, 20)
@ -738,7 +742,7 @@ class CommonDistributedDataParallelTest:
same layer twice and having weights shared across layers.
"""
process_group = self._get_process_group()
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
for use_bucket_view in (True, False):
self._test_ddp_checkpointing(
self.CheckpointTwiceModuleWeightSharing(),
@ -1162,7 +1166,7 @@ class AbstractCommTest:
# Verify sequence numbers are appropriately incremented
for i in range(10):
t = torch.ones(1, device=torch.cuda.current_device())
t = torch.ones(1, device=device_type)
dist.all_reduce(t, group=process_group)
if not c10d._rank_not_in_group(process_group):
seq_num = self._verify_sequence_number_across_pg(
@ -1193,7 +1197,7 @@ class AbstractCommTest:
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
def _test_sequence_num_incremented_default_group(self, backend_name):
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend_name,
@ -1207,7 +1211,7 @@ class AbstractCommTest:
)
def _test_sequence_num_incremented_subgroup(self, backend_name):
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend_name,
@ -1262,8 +1266,8 @@ class AbstractCommTest:
in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
group = dist.new_group(in_group_ranks)
x = torch.zeros(2, 2).cuda(self.rank)
xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
x = torch.zeros(2, 2).to(self.rank)
xs = [torch.zeros(2, 2).to(self.rank) for _ in range(len(in_group_ranks))]
if self.rank not in in_group_ranks:
msg = ".*{}.*does not belong to.*"
with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
@ -1392,7 +1396,7 @@ class AbstractCommTest:
rank=self.rank,
store=store,
)
device = "cuda" if backend == "nccl" else "cpu"
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
# test alltoall_base
tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
@ -1574,8 +1578,8 @@ class CommTest(AbstractCommTest, MultiProcessTestCase):
class DummyWork(dist._Work):
def wait(self, timeout=5.0):
if torch.cuda.is_available():
torch.cuda.current_stream().synchronize()
if torch.accelerator.is_available():
torch.accelerator.current_stream().synchronize()
return True
@ -1790,6 +1794,18 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
]
if TEST_XPU:
# Override backend_config_strings_and_expected_values for Intel GPU.
backend_config_strings_and_expected_values[4:10] = [
(dist.Backend.DUMMY, "cpu:dummy,cuda:dummy,xpu:dummy"),
("DUMMY", "cpu:dummy,cuda:dummy,xpu:dummy"),
("dummy", "cpu:dummy,cuda:dummy,xpu:dummy"),
("cpu:dummy,xpu:dummy", "cpu:dummy,xpu:dummy"),
("cpu:dummy,xpu:xccl", "cpu:dummy,xpu:xccl"),
("cpu:gloo,xpu:dummy", "cpu:gloo,xpu:dummy"),
("cpu:gloo,xpu:xccl", "cpu:gloo,xpu:xccl"),
]
for config_str, expected_value in backend_config_strings_and_expected_values:
with self.subTest(config_str):
# ensures these configs strings are valid and no ValueError is raised
@ -1800,6 +1816,8 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
invalid_backend_config_strings = [
"cpu:gloo,cuda:nccl,", # trailing comma
"cpu:gloo,cuda:nccl,cpu:dummy", # duplicate device
"cpu:gloo,xpu:xccl,", # trailing comma
"cpu:gloo,xpu:xccl,cpu:dummy", # duplicate device
]
for config_str in invalid_backend_config_strings:
with self.subTest(config_str):
@ -1814,7 +1832,7 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "6789"
dist.init_process_group(
"cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
"cpu:dummy,cuda:dummy,xpu:dummy", rank=self.rank, world_size=self.world_size
)
# test all_gather
@ -2053,7 +2071,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
# correctly dispatched
# TODO: this will be updated in the future to not be backend specific
device = "cuda" if backend == "nccl" else "cpu"
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
# ensure supported devices (cpu, cuda, xpu) succeed during dispatch call
tensor = torch.zeros(2, 2, device=torch.device(device))
# multi tensor collectives
@ -2119,7 +2137,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
rank=self.rank,
store=store,
)
device = "cuda" if backend == "nccl" else "cpu"
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
# test alltoall_base
input_tensor = torch.ones(2, 2, device=torch.device(device))
output_tensor = torch.zeros(2, 2, device=torch.device(device))
@ -2251,8 +2269,9 @@ class LocalRankTest(MultiProcessTestCase):
if __name__ == "__main__":
assert not torch.cuda._initialized, (
"test_distributed must not have initialized CUDA context on main process"
)
if device_type != "cpu":
assert not torch.get_device_module()._initialized, (
"test_distributed must not have initialized {device_type} context on main process"
)
run_tests()

View File

@ -25,7 +25,7 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
@ -59,7 +59,7 @@ if not dist.is_available():
sys.exit(0)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestWithNCCL(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
@ -75,13 +75,15 @@ class TestWithNCCL(MultiProcessTestCase):
@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")
return torch.device(self.rank)
def _init_process_group(self) -> None:
torch.cuda.set_device(self.device)
torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
backend = dist.get_default_backend_for_device(self.device.type)
dist.init_process_group(
backend="nccl",
backend=backend,
world_size=self.world_size,
rank=self.rank,
store=store,
@ -273,7 +275,7 @@ class TestWithNCCL(MultiProcessTestCase):
)
# check memory leak
for i in range(1, 10):
mem_usage[i] = torch.cuda.max_memory_allocated()
mem_usage[i] = torch.accelerator.max_memory_allocated()
compiled(arg)
assert mem_usage[9] == mem_usage[8]
@ -370,14 +372,16 @@ class TestWithNCCL(MultiProcessTestCase):
@skip_if_lt_x_gpu(2)
def test_all_to_all_single(self) -> None:
self._init_process_group()
torch.cuda.set_device(self.device)
torch.accelerator.set_device_index(self.rank)
torch.manual_seed(42)
send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))
input_split_sizes = send_sz_matrix[self.rank].tolist()
output_split_sizes = send_sz_matrix[:, self.rank].tolist()
input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()
input = torch.full((sum(input_split_sizes),), float(self.rank)).to(
self.device.type
)
output = torch.ops._c10d_functional.all_to_all_single(
input,
@ -388,7 +392,7 @@ class TestWithNCCL(MultiProcessTestCase):
output = torch.ops._c10d_functional.wait_tensor(output)
expect = torch.cat(
[
torch.full((sz,), float(rank)).cuda()
torch.full((sz,), float(rank)).to(self.device.type)
for rank, sz in enumerate(output_split_sizes)
]
)
@ -464,7 +468,7 @@ class TestWithNCCL(MultiProcessTestCase):
@fresh_cache()
def test_threading(self):
self._init_process_group()
device = torch.device(f"cuda:{self.rank}")
device = self.device
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42
@ -544,9 +548,9 @@ class TestWithNCCL(MultiProcessTestCase):
return in_grad, w_grad
m, n, k = 128, 256, 64
in_ = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
out_grad = torch.randn((m, n), device="cuda", dtype=torch.bfloat16)
in_ = torch.randn((m, k), device=self.device.type, dtype=torch.bfloat16)
w = torch.randn((n, k), device=self.device.type, dtype=torch.bfloat16)
out_grad = torch.randn((m, n), device=self.device.type, dtype=torch.bfloat16)
eager_in_grad, eager_w_grad = fp8_rowwise_backward(in_, w, out_grad)
compile_in_grad, compile_w_grad = torch.compile(fp8_rowwise_backward)(
@ -775,7 +779,8 @@ class CompileTest(TestCase):
self.rank = 0
self.world_size = 2
torch.cuda.set_device("cuda:0")
torch.accelerator.set_device_index(0)
self.device = torch.accelerator.current_accelerator()
store = FakeStore()
dist.init_process_group(
@ -801,7 +806,7 @@ class CompileTest(TestCase):
ar1 = funcol.wait_tensor(ar1)
return ar0, ar1
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -834,7 +839,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -849,7 +854,7 @@ class CompileTest(TestCase):
ar1 = [funcol.wait_tensor(out) for out in ar1]
return ar0, ar1
args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
args = [torch.rand(4, 4, device=self.device.type) for _ in range(2)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
buf0, buf1, buf2, buf3 = find_buffer_assignments(code)
@ -879,7 +884,7 @@ class CompileTest(TestCase):
# Test aoti
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -890,7 +895,7 @@ class CompileTest(TestCase):
ar0 = funcol.wait_tensor(ar0)
return ar0
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -915,7 +920,7 @@ class CompileTest(TestCase):
# Expect allocation
return ar0
arg = torch.rand(4, 4, device="cuda").T
arg = torch.rand(4, 4, device=self.device.type).T
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -946,7 +951,7 @@ class CompileTest(TestCase):
buf2 = torch.mm(arg, buf1)
return buf1, buf2
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
buf0, buf1 = find_buffer_assignments(code)
@ -976,7 +981,7 @@ class CompileTest(TestCase):
ag0 = funcol.wait_tensor(ag0)
return ag0
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
@ -993,7 +998,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1003,7 +1008,7 @@ class CompileTest(TestCase):
ag0 = [funcol.wait_tensor(out) for out in ag0]
return ag0
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
@ -1027,7 +1032,7 @@ class CompileTest(TestCase):
# Test aoti
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
@fresh_cache()
@ -1037,7 +1042,7 @@ class CompileTest(TestCase):
return funcol.wait_tensor(t)
# Test aoti
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
@ -1049,7 +1054,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1059,7 +1064,7 @@ class CompileTest(TestCase):
rs0 = funcol.wait_tensor(rs0)
return rs0
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
@ -1075,7 +1080,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1087,7 +1092,7 @@ class CompileTest(TestCase):
rs0 = [funcol.wait_tensor(out) for out in rs0]
return rs0
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
@ -1111,7 +1116,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (args,))
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1140,7 +1145,9 @@ class CompileTest(TestCase):
input_split_sizes = send_sz_matrix[self.rank]
output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).to(
self.device.type
)
with torch._dynamo.config.patch(
dynamic_shapes=True,
@ -1174,7 +1181,7 @@ class CompileTest(TestCase):
br1 = funcol.wait_tensor(br1)
return br0, br1
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -1197,7 +1204,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.cuda.synchronize()
torch.accelerator.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1212,7 +1219,7 @@ class CompileTest(TestCase):
ar1 = funcol.wait_tensor(ar1)
return ar0, ar1
arg = torch.rand(4, 4, device="cuda")
arg = torch.rand(4, 4, device=self.device.type)
compiled = torch.compile(func, fullgraph=True)
code = run_and_get_triton_code(compiled, arg)

View File

@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import unittest
import torch
import torch.distributed as dist
@ -26,7 +27,7 @@ from torch.distributed.tensor._collective_utils import (
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, TEST_XPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -35,6 +36,10 @@ from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeSt
from torch.utils._typing_utils import not_none
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = port
@ -44,6 +49,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
os.environ["LOCAL_RANK"] = f"{local_rank}"
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
class DeviceMeshTestGlooBackend(DTensorTestBase):
@property
def backend(self):
@ -73,14 +79,16 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
# Set the device on each process before DeviceMesh constructor,
# and device to be different than the default world rank
torch.cuda.set_device((self.rank + 2) % self.world_size)
torch.accelerator.set_device_index((self.rank + 2) % self.world_size)
_set_env_var(world_size=self.world_size, rank=self.rank)
DeviceMesh(self.device_type, mesh_tensor)
self.assertTrue(is_initialized())
# check that the device is set to the correct device
# and respect the previous set_device calls
self.assertEqual(torch.cuda.current_device(), (self.rank + 2) % self.world_size)
self.assertEqual(
torch.accelerator.current_device_idx(), (self.rank + 2) % self.world_size
)
self.destroy_pg()
@skip_if_lt_x_gpu(4)
@ -101,7 +109,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
# check that the device is set to the correct device
# and respect the LOCAL_RANK env var
self.assertEqual(torch.cuda.current_device(), local_rank)
self.assertEqual(torch.accelerator.current_device_idx(), local_rank)
self.destroy_pg()
@skip_if_lt_x_gpu(4)
@ -120,7 +128,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
self.assertTrue(is_initialized())
# check that the device is set to the correct device
self.assertEqual(torch.cuda.current_device(), self.rank)
self.assertEqual(torch.accelerator.current_device_idx(), self.rank)
self.destroy_pg()
@ -222,7 +230,7 @@ class DeviceMeshTest(DTensorTestBase):
@with_comms
def test_device_mesh_2d(self):
mesh_tensor = torch.arange(4).reshape(2, 2)
# construct a cuda device mesh
# construct a device mesh for self.device_type
mesh = DeviceMesh(self.device_type, mesh_tensor)
# check all dim groups
@ -260,7 +268,11 @@ class DeviceMeshTest(DTensorTestBase):
def test_fake_pg_device_mesh(self):
fake_store = FakeStore()
init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device_type = (
torch.accelerator.current_accelerator().type
if torch.accelerator.is_available()
else "cpu"
)
mesh = DeviceMesh(device_type, torch.arange(self.world_size))
local_tensor = torch.randn(2, 8)
@ -300,7 +312,7 @@ class DeviceMeshTest(DTensorTestBase):
regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
with self.assertRaisesRegex(ValueError, regex):
DeviceMesh.from_group(
global_pg, "cuda", invalid_mesh, mesh_dim_names=("dim0", "dim1")
global_pg, device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1")
)
device_mesh = init_device_mesh(self.device_type, (2, 2))
@ -320,12 +332,16 @@ class DeviceMeshTest(DTensorTestBase):
# test init_device_mesh with an invalid device type that contains a GPU index
mesh_shape = (2, self.world_size // 2)
init_device_mesh(
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
)
@with_comms
def test_set_mesh_dim_group_options(self):
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device_type = (
torch.accelerator.current_accelerator().type
if torch.accelerator.is_available()
else "cpu"
)
_mesh_resources._set_mesh_dim_group_options(1, "fake", None)
mesh_tensor = torch.arange(4).reshape(2, 2)
@ -341,7 +357,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
@with_comms
def test_device_mesh_nd(self):
# construct a cuda device mesh
# construct a device mesh for self.device_type
mesh_tensor = torch.arange(8).reshape(2, 2, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
@ -710,7 +726,9 @@ class TestDeviceMeshGetItem(DTensorTestBase):
with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
mesh_dim_names = ("DP", "TP")
mesh = init_device_mesh(
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
self.device_type,
(2, 4),
mesh_dim_names=mesh_dim_names,
)
mesh[child_mesh_dim_name]
@ -938,7 +956,9 @@ class TestMeshEnv(DTensorTestBase):
@with_comms
def test_get_root_mesh(self):
mesh_3d = init_device_mesh(
self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
self.device_type,
(2, 2, 2),
mesh_dim_names=("dp", "cp", "tp"),
)
dp_cp_mesh = mesh_3d["dp", "cp"]
@ -986,7 +1006,9 @@ class TestMeshEnv(DTensorTestBase):
@with_comms
def test_get_all_submeshes(self):
mesh_2d = init_device_mesh(
self.device_type, (2, 4), mesh_dim_names=("replicate", "shard")
self.device_type,
(2, 4),
mesh_dim_names=("replicate", "shard"),
)
all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate")
self.assertEqual(len(all_submeshes), 4)

View File

@ -43,11 +43,12 @@ from torch.testing._internal.common_distributed import (
DynamoDistributedMultiProcTestCase,
DynamoDistributedSingleProcTestCase,
import_transformers_or_skip,
requires_nccl,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import requires_cuda
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton
def reset_rng_state():
@ -270,7 +271,15 @@ def get_hf_bert(rank):
except ImportError as e:
raise unittest.SkipTest("Unable to import transformers") from e
batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
)
batch_size, max_length, config, device = (
4,
512,
BertConfig(),
f"{device_type}:{rank}",
)
model = AutoModelForMaskedLM.from_config(config).to(device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(
@ -550,8 +559,8 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestMultiProc(DynamoDistributedMultiProcTestCase):
"""
Note: MultiProcTestCase spawns processes per test and is slow.
@ -559,12 +568,16 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
sparingly for integration tests.
"""
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
)
@skip_if_lt_x_gpu(2)
@config.patch(optimize_ddp=False, enable_compiler_collectives=True)
def test_ddp_baseline_aot_eager_multiprocess(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
self.assertFalse(config.optimize_ddp)
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
m = DDP(m, device_ids=[self.rank])
m = torch.compile(m, backend="aot_eager")
outputs = m(inputs)
@ -632,7 +645,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
self.assertFalse(config.optimize_ddp)
model = MyModel().to(device="cuda")
model = MyModel().to(device=self.device_type)
# Activation checkpointing for Linear layers.
non_reentrant_wrapper = functools.partial(
@ -647,7 +660,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
)
model = DDP(model)
x = torch.randn(10, 64).cuda()
x = torch.randn(10, 64).to(self.device_type)
correct_outputs = model(x)
opt_model = torch.compile(model)
@ -659,14 +672,14 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
def test_fsdp_aot_eager(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
# Test with basic FSDP wrapping (outer wrap around whole model)
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
fsdp_m = FSDP(m, use_orig_params=True)
fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
# Test with recursive wrapping, nested FSDP around each Linear
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
fsdp_m = FSDP(
m,
auto_wrap_policy=functools.partial(
@ -680,6 +693,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@skip_if_lt_x_gpu(2)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@requires_cuda_and_triton
def test_ddp_optimizer_cudagraph(self):
class Net(nn.Module):
def __init__(self):
@ -730,7 +744,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
from torch._dynamo.utils import counters
counters.clear()
m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_mutating_model(
f"{self.device_type}:{self.rank}"
)
fsdp_m = FSDP(m, use_orig_params=True)
fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
outputs = fsdp_m(inputs)
@ -748,7 +764,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
from torch._dynamo.utils import counters
counters.clear()
m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_forced_getattr_module(
f"{self.device_type}:{self.rank}"
)
fsdp_m = FSDP(m, use_orig_params=True)
fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
outputs = fsdp_m(inputs)
@ -762,7 +780,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
from torch._dynamo.utils import counters
counters.clear()
m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_forced_getattr_module(
f"{self.device_type}:{self.rank}"
)
fsdp_m = FSDP(m, use_orig_params=True)
fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
outputs = fsdp_m(inputs)
@ -774,14 +794,14 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
def test_fsdp_inductor(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
# Test with basic FSDP wrapping (outer wrap around whole model)
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
fsdp_m = FSDP(m, use_orig_params=True)
fsdp_m = torch.compile(fsdp_m, backend="inductor")
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
# Test with recursive wrapping, nested FSDP around each Linear
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
fsdp_m = FSDP(
m,
auto_wrap_policy=functools.partial(
@ -799,7 +819,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
def test_fsdp_activation_checkpointing(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model, inputs = get_toy_model_for_activation_checkpointing(
f"cuda:{self.rank}"
f"{self.device_type}:{self.rank}"
)
is_inner = lambda module: isinstance(module, ToyInnerModel) # noqa: E731
wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
@ -961,7 +981,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
torch._dynamo.utils.clear_compilation_metrics()
# TODO: This should be possible to do inside the function, but
device = f"cuda:{self.rank}"
device = f"{self.device_type}:{self.rank}"
@torch.compile()
def f(x, y):
@ -1181,7 +1201,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
pg = dist.distributed_c10d._get_default_group()
device = f"cuda:{self.rank}"
device = f"{self.device_type}:{self.rank}"
@torch.compile(fullgraph=True)
def f(x):
@ -1196,6 +1216,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
pg = dist.distributed_c10d.GroupMember.NON_GROUP_MEMBER
self.assertEqual(f(x), x + 1)
@skipIfXpu # ProcessGroupXCCL doesn't support _set_default_timeout yet.
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", False)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
@ -1205,7 +1226,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
device = f"cuda:{self.rank}"
device = f"{self.device_type}:{self.rank}"
pg = dist.distributed_c10d._get_default_group()
@ -1238,7 +1259,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
w = pg.allreduce(x)
w.wait()
torch.cuda.synchronize(device)
torch.accelerator.synchronize(device)
metrics = torch._dynamo.utils.get_compilation_metrics()
# Number of compiles same on all nodes
@ -1247,6 +1268,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
for r in res[1:]:
self.assertEqual(res[0], r)
@skipIfXpu # ProcessGroupXCCL doesn't support _set_default_timeout yet.
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", True)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
@ -1258,7 +1280,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
device = f"cuda:{self.rank}"
device = f"{self.device_type}:{self.rank}"
pg = dist.distributed_c10d._get_default_group()
@ -1281,7 +1303,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
w = pg.allreduce(x)
w.wait()
torch.cuda.synchronize(device)
torch.accelerator.synchronize(device)
torch._dynamo.reset()
if self.rank == 0:
@ -1298,11 +1320,11 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
w = pg.allreduce(x)
w.wait()
torch.cuda.synchronize(device)
torch.accelerator.synchronize(device)
@requires_nccl()
@requires_cuda
@requires_accelerator_dist_backend(["nccl", "xccl"])
@unittest.skipUnless(torch.accelerator.is_available(), "Requires accelerator")
class TestSingleProc(DynamoDistributedSingleProcTestCase):
"""
Test harness initializes dist process group.
@ -1311,6 +1333,10 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
Use TestMultiProc for things that really need to run on multiple nodes
"""
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
)
def get_model(
self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
@ -1428,6 +1454,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
self.assertEqual(len(break_reasons), 4)
self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
@skipIfXpu # XPU device doesn't support flex_attention yet.
@patch.object(config, "optimize_ddp", True)
def test_compiled_flex_attention_full_model_ddp(self):
class Model(torch.nn.Module):
@ -1474,16 +1501,16 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
S = 512
D = 64
device = "cuda"
model = Model(S, H, D)
model.to(device)
model.to(self.device_type)
model = torch.compile(model)
model = DDP(model, device_ids=self.device_ids)
hidden_states = torch.randn(B, S, H * D).to(device)
hidden_states = torch.randn(B, S, H * D).to(self.device_type)
model(hidden_states)
torch.cuda.synchronize()
torch.accelerator.synchronize()
@skipIfXpu # XPU device doesn't support flex_attention yet.
@patch.object(config, "optimize_ddp", True)
def test_compiled_flex_attention_local_ddp(self):
class Model(torch.nn.Module):
@ -1530,15 +1557,14 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
S = 512
D = 64
device = "cuda"
model = Model(S, H, D)
model.to(device)
model.to(self.device_type)
model = torch.compile(model)
model = DDP(model, device_ids=self.device_ids)
hidden_states = torch.randn(B, S, H * D).to(device)
hidden_states = torch.randn(B, S, H * D).to(self.device_type)
model(hidden_states)
torch.cuda.synchronize()
torch.accelerator.synchronize()
@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@ -1787,9 +1813,9 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
a = torch.cos(a)
return a
mod = MockModule().cuda()
mod = MockModule().to(self.device_type)
mod = DDP(mod, bucket_cap_mb=1)
x = torch.randn(N, N, device="cuda", requires_grad=True)
x = torch.randn(N, N, device=self.device_type, requires_grad=True)
args = (x,)
backend = "aot_eager"
@ -1799,7 +1825,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
def test_fsdp_orig_params_assert(self):
# Test with basic FSDP wrapping (outer wrap around whole model)
m, inputs, _ = get_model(f"cuda:{self.rank}")
m, inputs, _ = get_model(f"{self.device_type}:{self.rank}")
fsdp_m = FSDP(m, use_orig_params=False)
# Test is that this function call does not throw an exception.
fsdp_m = torch.compile(fsdp_m)
@ -1845,7 +1871,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
return out
device = f"cuda:{self.rank}"
device = f"{self.device_type}:{self.rank}"
m = ToyModel(
in_feat=10,
hidden_feat=5000,
@ -1892,7 +1918,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
torch._dynamo.reset()
device = f"cuda:{self.rank}"
device = f"{self.device_type}:{self.rank}"
m = ToyModel(
in_feat=10,
hidden_feat=5000,
@ -1933,9 +1959,14 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
class DuplicateModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self._param = torch.randn((3,), device="cuda")
device_type = (
acc.type
if (acc := torch.accelerator.current_accelerator())
else "cpu"
)
self._param = torch.randn((3,), device=device_type)
self._buf = torch.nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda")
torch.randn((3,), requires_grad=False, device=device_type)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -1948,7 +1979,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
model = DuplicateModule()
fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
fsdp_model = torch.compile(fsdp_model, backend="aot_eager")
inp = torch.randn((2, 3), device="cuda")
inp = torch.randn((2, 3), device=self.device_type)
local_out = model(inp)
fsdp_out = fsdp_model(inp)
self.assertEqual(local_out, fsdp_out)
@ -1965,8 +1996,13 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
class BufModule(nn.Module):
def __init__(self) -> None:
super().__init__()
device_type = (
acc.type
if (acc := torch.accelerator.current_accelerator())
else "cpu"
)
self._buf = nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda")
torch.randn((3,), requires_grad=False, device=device_type)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -1975,7 +2011,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self._param = nn.Parameter(torch.randn((1,), device="cuda"))
device_type = (
acc.type
if (acc := torch.accelerator.current_accelerator())
else "cpu"
)
self._param = nn.Parameter(torch.randn((1,), device=device_type))
self._buf_module = BufModule()
# Share the buffer, meaning same tensor but different source
self._buf = self._buf_module._buf
@ -1992,7 +2033,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
fsdp_model = FSDP(Model(), use_orig_params=True)
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
fsdp_model = torch.compile(fsdp_model, backend=cnt)
inp = torch.randn((2, 3), device="cuda")
inp = torch.randn((2, 3), device=self.device_type)
for _ in range(15):
fsdp_model(inp)
# Check for no recompiles (if there were incorrect de-dup guards, then
@ -2011,7 +2052,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
super().__init__()
self._use_self = use_self
torch.manual_seed(42) # force `_param` to be deterministic
self._param = nn.Parameter(torch.randn((3,), device="cuda"))
device_type = (
acc.type
if (acc := torch.accelerator.current_accelerator())
else "cpu"
)
self._param = nn.Parameter(torch.randn((3,), device=device_type))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._use_self:
@ -2026,7 +2072,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
return x + y
model = ModuleWithStaticMethod(False)
x = torch.randn((2, 3), device="cuda")
x = torch.randn((2, 3), device=self.device_type)
ref_out = model(x)
test_outs: list[torch.Tensor] = []

View File

@ -10,6 +10,7 @@ import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as c10d
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
@ -37,14 +38,14 @@ from torch.testing._internal.common_distributed import (
DynamoDistributedMultiProcTestCase,
DynamoDistributedSingleProcTestCase,
MultiProcessTestCase,
requires_nccl,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
skipIfRocm,
skipIfXpu,
TEST_XPU,
xfailIf,
)
@ -59,13 +60,15 @@ def _tolist_with_constrain_as_size(tensor):
return lst
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@instantiate_parametrized_tests
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
"""
device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def get_world_trs(self):
return {
"tag": "",
@ -102,8 +105,11 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
example,
**self.get_world_trs(),
)
t = torch.randn(4, 4, device="cuda")
inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
t = torch.randn(4, 4, device=self.device)
inputs = (
t if self.rank == 0 else torch.zeros(4, 4, device=self.device),
0,
)
eager_out = example(*inputs)
self.assertTrue(same(t, eager_out))
@ -137,7 +143,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
matmul_cat_col,
**self.get_world_trs(),
)
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6
inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 6
eager_out = matmul_cat_col(*inputs)
compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
@ -179,7 +185,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
for nelem in [1024, 2048, 4096]:
# CI (Tesla T4) does not support bfloat16 compilation natively,
# using float
x = torch.randn(nelem, device="cuda", dtype=torch.float)
x = torch.randn(nelem, device=self.device, dtype=torch.float)
golden_out = eager_func(x)
for _ in range(3):
@ -217,8 +223,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
eager_func,
**self.get_world_trs(),
)
eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
eager_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 4
inductor_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2
eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
compiled_inductor_func = compile(
@ -256,8 +262,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
inductor_func,
**self.get_world_trs(),
)
inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
inductor_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 4
eager_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2
eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
compiled_inductor_func = compile(inductor_func, inductor_inputs)
@ -270,6 +276,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@skip_if_lt_x_gpu(2)
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728
@skipIfRocm
def test_eager_async_allreduce_inductor_wait(self):
import torch.distributed as dist
from torch._inductor.utils import run_and_get_code
@ -292,7 +299,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
return y * y
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
x = torch.ones(12800, 12800, device="cuda") + self.rank
x = torch.ones(12800, 12800, device=self.device) + self.rank
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
# NOTE: We run for 10 iterations each, to ensure that the GPU execution is way behind CPU
@ -363,7 +370,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
return (e,)
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
inputs = torch.ones(4, 4, device="cuda") + self.rank
inputs = torch.ones(4, 4, device=self.device) + self.rank
compiled = torch.compile(func)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
@ -380,7 +387,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
inputs = (
# rank0: [0., 1.], rank1: [2., 3.]
torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
torch.arange(2, dtype=torch.float32, device=self.device)
+ 2 * self.rank,
[1, 0],
)
compiled = torch.compile(func)
@ -389,7 +397,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
self.assertTrue(same(out, correct))
# rank0: [2., 3.], rank1: [0., 1.]
expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
expected = torch.arange(2, dtype=torch.float32, device=self.device) + 2 * (
(self.rank - 1 + self.world_size) % self.world_size
)
self.assertEqual(out, expected)
@ -412,9 +420,9 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
return out
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model = Model().cuda()
model = Model().to(self.device)
model_compiled = torch.compile(model)
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
out = model_compiled(inp, self.world_size, **self.get_world_trs())
correct = model(inp, self.world_size, **self.get_world_trs())
self.assertTrue(same(out, correct))
@ -429,7 +437,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
func_compiled = torch.compile(func)
inp = torch.tensor(self.rank, dtype=torch.long, device="cuda")
inp = torch.tensor(self.rank, dtype=torch.long, device=self.device)
out = func_compiled(inp, self.world_size)
correct = func(inp, self.world_size)
self.assertTrue(same(out, correct))
@ -451,9 +459,9 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
return out
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model = Model().cuda()
model = Model().to(self.device)
model_compiled = torch.compile(model)
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
out = model_compiled(inp, self.world_size, **self.get_world_trs())
correct = model(inp, self.world_size, **self.get_world_trs())
self.assertTrue(same(out, correct))
@ -482,7 +490,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
example,
**self.get_world_trs(),
)
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2
eager_out = example(*inputs)
compiled_matmul_cat_col = compile(example, inputs)
@ -509,7 +517,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
example,
**self.get_world_trs(),
)
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2
eager_out = example(*inputs)
compiled_fn = compile(example, inputs)
@ -563,7 +571,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
dtype=torch.int64,
)
inputs = (
torch.ones(int(row), 5, device="cuda") * (self.rank + 1),
torch.ones(int(row), 5, device=self.device) * (self.rank + 1),
input_split_sizes_tensor,
output_split_sizes_tensor,
)
@ -732,7 +740,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
dtype=torch.int64,
)
inputs = (
torch.ones(int(row), 5, device="cuda", requires_grad=True)
torch.ones(int(row), 5, device=self.device, requires_grad=True)
* (self.rank + 1),
input_split_sizes_tensor,
output_split_sizes_tensor,
@ -795,7 +803,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
inputs = (
torch.ones(self.world_size, self.world_size, device="cuda")
torch.ones(self.world_size, self.world_size, device=self.device)
* (self.rank + 1),
)
trs = self.get_world_trs()
@ -819,8 +827,11 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@instantiate_parametrized_tests
@requires_nccl()
@requires_cuda
@requires_accelerator_dist_backend(["nccl", "xccl"])
@unittest.skipIf(
not torch.accelerator.is_available(),
"No accelerator is available",
)
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
"""
Prefer single-proc test runner for basic tests as it is easier to work with.
@ -843,7 +854,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ar = torch.ops.c10d_functional.wait_tensor(ar)
return ar
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
compiled = torch.compile(func)
out = compiled(inputs, **self.get_world_trs())
@ -878,7 +889,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar, other
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@ -911,7 +922,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar, y, other
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@ -952,7 +963,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ar = _functional_collectives.all_reduce(inp, "sum", "0")
return ar
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs)
@ -963,12 +974,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_all_gather_tensor(self):
def func(inp):
ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
return ar
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs)
@ -979,6 +991,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_all_gather_tensor_pg(self):
def func(inp, *, pg):
ar = _functional_collectives.all_gather_tensor(inp, 0, pg)
@ -995,6 +1008,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_all_gather(self):
def func(inp, out, *, pg):
torch.distributed.all_gather_into_tensor(
@ -1020,6 +1034,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 3
assert same(outputs, correct_outputs)
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_all_gather_list(self):
def func(inp, out, *, pg):
torch.distributed.all_gather(
@ -1042,6 +1057,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.frame_count == 1
assert same(outputs, correct_outputs)
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_all_gather_args_match(self):
# Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
# except uses kwargs to ensure rewrite has matching arg names
@ -1070,6 +1086,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 3
assert same(outputs, correct_outputs)
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_reduce_scatter(self):
def func(inp, out, *, pg):
torch.distributed.reduce_scatter_tensor(
@ -1237,6 +1254,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
input = torch.ones(2, device=self.device)
compiled(input)
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_support_collective_op_with_async_op_False(self):
def func(inp, out, *, pg):
# user explicitly set the attribute `async_op` to False,
@ -1296,12 +1314,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 1
assert same(outputs, correct_outputs)
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_reduce_scatter_tensor(self):
def func(inp):
ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
return ar
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs)
@ -1312,6 +1331,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_allgather_coalesced(self):
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
@ -1319,7 +1339,10 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
)
return ar
inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
inputs = [
torch.ones(4, 4, device=self.device),
torch.ones(6, 6, device=self.device),
]
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs, **self.get_world_trs())
@ -1339,7 +1362,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ar = _functional_collectives.all_reduce(inp, "sum", "0")
return ar
input = torch.ones(4, 4, device="cuda", requires_grad=True)
input = torch.ones(4, 4, device=self.device, requires_grad=True)
compiled = torch.compile(
func, backend="aot_eager"
) # inductor bug with single-op allreduce graph
@ -1357,6 +1380,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
self.assertEqual(x.size(), out.size())
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_all_gather_coalesced(self):
@ -1376,7 +1400,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar0, y, other, ar1
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@ -1403,6 +1427,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_reduce_scatter_coalesced(self):
@ -1422,7 +1447,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar0, y, other, ar1
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@ -1449,6 +1474,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_reorder_peak_memory(self):
"""
@ -1470,7 +1496,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar0, y, other, ar1
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
# get stats directly from the internal helper without affecting the real pass's signature
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
@ -1642,10 +1668,10 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
return y, ag_0_out, ag_1_out
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
x = torch.ones(4, 384, device=self.device, dtype=torch.float32)
w = torch.ones(384, 512, device=self.device, dtype=torch.float32)
ag_0 = torch.ones(384, 512, device=self.device, dtype=torch.float32)
ag_1 = torch.ones(384, 512, device=self.device, dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
with torch._inductor.config.patch(
@ -1810,12 +1836,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
rs_3_out,
)
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
x = torch.ones(4, 384, device=self.device, dtype=torch.float32)
w = torch.ones(384, 512, device=self.device, dtype=torch.float32)
ag_0 = torch.ones(1024, 512, device=self.device, dtype=torch.float32)
ag_1 = torch.ones(512, 1024, device=self.device, dtype=torch.float32)
ag_2 = torch.ones(1024, 512, device=self.device, dtype=torch.float32)
ag_3 = torch.ones(512, 1024, device=self.device, dtype=torch.float32)
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
# get stats directly from the internal helper without affecting the real pass's signature
@ -1917,6 +1943,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
node_stat1 = next(it)
self.assertTrue("collective ordering" in node_stat1.limiting_factor)
@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_reorder_respects_wait_dep(self):
"""
@ -1939,7 +1966,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
# ensure other is not incorrectly aliasing ar's buffer
return ag_1_wait
inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
# get stats directly from the internal helper without affecting the real pass's signature
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
@ -1988,7 +2015,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(stats.moves, 0)
@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestSyncDecisionCrossRanks(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
@ -2004,16 +2031,21 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")
device_type = torch.accelerator.current_accelerator().type
return torch.device(f"{device_type}:{self.rank}")
def _init_process_group(self) -> None:
torch._inductor.config.triton.store_cubin = True
torch._inductor.config.debug = True
torch.cuda.set_device(self.device)
torch.get_device_module(self.device).set_device(self.device)
store = torch.distributed.FileStore(self.file_name, self.world_size)
backend = c10d.get_default_backend_for_device(
torch.accelerator.current_accelerator().type
)
torch.distributed.init_process_group(
backend="nccl",
backend=backend,
world_size=self.world_size,
rank=self.rank,
store=store,
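The hunks above converge on one device-agnostic setup recipe. Below is a minimal sketch of that recipe, assuming a single-node run where `rank`, `world_size`, and the file-store path come from the test harness; the helper name `init_device_agnostic_pg` is illustrative and not part of this PR.
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d


def init_device_agnostic_pg(rank: int, world_size: int, store_path: str) -> torch.device:
    # Probe the active accelerator (cuda, xpu, ...) and fall back to cpu.
    acc = torch.accelerator.current_accelerator()
    device_type = acc.type if acc is not None else "cpu"
    device = torch.device(f"{device_type}:{rank}") if device_type != "cpu" else torch.device("cpu")
    if device_type != "cpu":
        # Bind this process to its device via the matching device module.
        torch.get_device_module(device).set_device(device)
    # Default backend per device: nccl for cuda, xccl for xpu, gloo for cpu.
    backend = c10d.get_default_backend_for_device(device_type)
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        store=dist.FileStore(store_path, world_size),
    )
    return device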

View File

@ -54,6 +54,8 @@ DEFAULT_HOSTNAME = "localhost"
torch.backends.cuda.matmul.allow_tf32 = False
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def gpus_for_rank(world_size):
"""Multigpu tests are designed to simulate the multi nodes with multi
@ -61,8 +63,8 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
visible_devices = list(range(torch.accelerator.device_count()))
gpus_per_process = torch.accelerator.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@ -1174,8 +1176,8 @@ class TestClientProtocol(TestCase):
if __name__ == "__main__":
assert not torch.cuda._initialized, (
"test_distributed must not have initialized CUDA context on main process"
)
if device_type != "cpu":
assert not torch.get_device_module()._initialized, (
f"test_distributed must not have initialized {device_type} context on main process"
)
run_tests()
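For reference, a worked example of the partitioning that gpus_for_rank now performs, assuming an 8-device node and a world size of 2 (the numbers are illustrative):
import torch

world_size = 2
device_count = torch.accelerator.device_count()  # e.g. 8 on an 8-GPU/XPU node
gpus_per_process = device_count // world_size    # 4 in that example
partitions = [
    list(range(device_count))[rank * gpus_per_process : (rank + 1) * gpus_per_process]
    for rank in range(world_size)
]
# With 8 devices: partitions == [[0, 1, 2, 3], [4, 5, 6, 7]]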

View File

@ -115,10 +115,13 @@ from torch.testing._internal.common_utils import (
set_default_dtype,
set_rng_seed,
skipIfTorchDynamo,
TEST_XPU,
TestCase,
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@ -1788,18 +1791,21 @@ class TestDistributions(DistributionsTestCase):
).logpmf(sample)
self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
def test_zero_excluded_binomial(self):
vals = Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.9).cuda()
total_count=torch.tensor(1.0).to(device_type),
probs=torch.tensor(0.9).to(device_type),
).sample(torch.Size((100000000,)))
self.assertTrue((vals >= 0).all())
vals = Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.1).cuda()
total_count=torch.tensor(1.0).to(device_type),
probs=torch.tensor(0.1).to(device_type),
).sample(torch.Size((100000000,)))
self.assertTrue((vals < 2).all())
vals = Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.5).cuda()
total_count=torch.tensor(1.0).to(device_type),
probs=torch.tensor(0.5).to(device_type),
).sample(torch.Size((10000,)))
# vals should be roughly half zeroes, half ones
assert (vals == 0.0).sum() > 4000
@ -2050,15 +2056,15 @@ class TestDistributions(DistributionsTestCase):
)
torch.set_default_dtype(saved_dtype)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_gpu_sample(self):
set_rng_seed(1)
for rate in [0.12, 0.9, 4.0]:
self._check_sampler_discrete(
Poisson(torch.tensor([rate]).cuda()),
Poisson(torch.tensor([rate]).to(device_type)),
scipy.stats.poisson(rate),
f"Poisson(lambda={rate}, cuda)",
f"Poisson(lambda={rate}, {device_type})",
failure_rate=1e-3,
)
@ -3490,13 +3496,13 @@ class TestDistributions(DistributionsTestCase):
self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_gpu_shape(self):
alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
beta = torch.randn(2, 3).cuda().exp().requires_grad_()
alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
beta_1d = torch.randn(1).cuda().exp().requires_grad_()
alpha = torch.randn(2, 3).to(device_type).exp().requires_grad_()
beta = torch.randn(2, 3).to(device_type).exp().requires_grad_()
alpha_1d = torch.randn(1).to(device_type).exp().requires_grad_()
beta_1d = torch.randn(1).to(device_type).exp().requires_grad_()
self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
@ -3527,7 +3533,10 @@ class TestDistributions(DistributionsTestCase):
def test_gamma_gpu_sample(self):
set_rng_seed(0)
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
a, b = (
torch.tensor([alpha]).to(device_type),
torch.tensor([beta]).to(device_type),
)
self._check_sampler_sampler(
Gamma(a, b),
scipy.stats.gamma(alpha, scale=1.0 / beta),
@ -3973,11 +3982,11 @@ class TestDistributions(DistributionsTestCase):
self.assertEqual(frac_zeros, 0.5, atol=0.05, rtol=0)
self.assertEqual(frac_ones, 0.5, atol=0.05, rtol=0)
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
def test_beta_underflow_gpu(self):
set_rng_seed(1)
num_samples = 50000
conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
conc = torch.tensor(1e-2, dtype=torch.float64).to(device_type)
beta_samples = Beta(conc, conc).sample([num_samples])
self.assertEqual((beta_samples == 0).sum(), 0)
self.assertEqual((beta_samples == 1).sum(), 0)
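The same three-part recipe repeats across these distribution tests: gate on TEST_CUDA or TEST_XPU, resolve device_type from torch.accelerator, and move tensors with .to(device_type). A minimal self-contained sketch of that recipe follows; the test class and method names are illustrative and not from this PR.
import unittest

import torch
from torch.distributions import Poisson
from torch.testing._internal.common_utils import TEST_CUDA, TEST_XPU

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


class ExampleAcceleratorGatedTest(unittest.TestCase):
    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
    def test_poisson_sample_on_accelerator(self):
        rate = torch.tensor([4.0]).to(device_type)  # same call path for cuda and xpu
        samples = Poisson(rate).sample((1000,))
        self.assertTrue((samples >= 0).all())


if __name__ == "__main__":
    unittest.main()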

View File

@ -258,8 +258,6 @@ class TestCommAnalysis(TestCase):
finally:
dist.destroy_process_group()
# lack of profiler on XPU
@expectedFailureXPU
def test_legacy_all_reduce(self):
def fn(x):
r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
@ -268,8 +266,6 @@ class TestCommAnalysis(TestCase):
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_legacy_all_reduce_coalesced(self):
def fn(x):
rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
@ -278,8 +274,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_legacy_all_gather_into_tensor_coalesced(self):
def fn(x):
rs = c10d.all_gather_into_tensor_coalesced(
@ -293,8 +287,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_all_reduce(self):
def fn(x):
r = _c10d.all_reduce(x, "sum", "0")
@ -303,8 +295,6 @@ class TestCommAnalysis(TestCase):
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_all_reduce_coalesced(self):
def fn(x):
rs = _c10d.all_reduce_coalesced(x, "sum", "0")
@ -313,8 +303,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_all_gather_into_tensor(self):
def fn(x):
rs = _c10d.all_gather_into_tensor(
@ -327,8 +315,6 @@ class TestCommAnalysis(TestCase):
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_all_gather_into_tensor_coalesced(self):
def fn(x):
rs = _c10d.all_gather_into_tensor_coalesced(
@ -341,8 +327,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_reduce_scatter_tensor(self):
def fn(x):
rs = _c10d.reduce_scatter_tensor(
@ -356,8 +340,6 @@ class TestCommAnalysis(TestCase):
inp = T(self.WORLD_SIZE, 10)
self._verify_runtime_estimation(fn, (inp,))
# lack of profiler on XPU
@expectedFailureXPU
def test_reduce_scatter_tensor_coalesced(self):
def fn(x):
rs = _c10d.reduce_scatter_tensor_coalesced(

View File

@ -337,10 +337,12 @@ class Backend(str): # noqa: SLOT000
# assume default devices "cpu" and "cuda" (plus "xpu" when available), but warn
warnings.warn(
f"Device capability of {name} unspecified, assuming `cpu` and "
"`cuda`. Please specify it via the `devices` argument of "
"`cuda`, plus `xpu` when XPU is available. Please specify it via the `devices` argument of "
"`register_backend`."
)
Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
Backend.backend_capability[name.lower()] = (
["cpu", "cuda", "xpu"] if torch.xpu.is_available() else ["cpu", "cuda"]
)
elif isinstance(devices, str):
# Single device string specified. Simply convert to list.
Backend.backend_capability[name.lower()] = [devices]
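For comparison, a registration that passes devices explicitly never reaches the warning branch above. A short sketch with a placeholder constructor; the backend name and creator function are hypothetical, not part of this PR.
import torch.distributed as dist


def _create_demo_pg(common_opts, backend_opts):
    # Placeholder: a real third-party backend would build and return its ProcessGroup here.
    raise NotImplementedError


# Declaring `devices` up front records the capability directly, so no
# cpu/cuda(/xpu) default has to be assumed.
dist.Backend.register_backend(
    "demo", _create_demo_pg, extended_api=True, devices=["cpu", "xpu"]
)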

View File

@ -339,26 +339,17 @@ def requires_gloo():
def requires_nccl_version(version, msg):
if TEST_CUDA:
if not c10d.is_nccl_available():
return skip_but_pass_in_sandcastle(
"c10d was not compiled with the NCCL backend",
)
else:
return skip_but_pass_in_sandcastle_if(
torch.cuda.nccl.version() < version,
f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
)
if not TEST_CUDA:
return lambda f: f
if not c10d.is_nccl_available():
return skip_but_pass_in_sandcastle(
"c10d was not compiled with the NCCL backend",
)
else:
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
return skip_but_pass_in_sandcastle_if(
torch.cuda.nccl.version() < version,
f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
)
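A hypothetical usage of the simplified decorator: on CUDA builds the test is skipped (but passes in sandcastle) when NCCL is missing or older than the requested version, while on non-CUDA builds (e.g. XPU-only) it now degrades to an identity decorator.
import unittest

from torch.testing._internal.common_distributed import requires_nccl_version


class ExampleNcclGatedTest(unittest.TestCase):
    # Version is compared against torch.cuda.nccl.version(), which returns a tuple.
    @requires_nccl_version((2, 10), "this collective needs NCCL >= 2.10")
    def test_example(self):
        pass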
def requires_nccl():

View File

@ -28,5 +28,5 @@ def _create_fake_pg(common_opts, backend_opts):
dist.Backend.register_backend(
"fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu"]
"fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu", "xpu"]
)
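With "xpu" added to the devices list, the fake backend can also back a single-process group on an XPU machine. A small smoke test of the existing usage pattern, shown here only as an illustration:
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# The fake backend performs no real communication, so one process can
# pretend to be a larger world regardless of the device type.
dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())
try:
    assert dist.get_world_size() == 2
finally:
    dist.destroy_process_group()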