Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[Reland][2/N] Port several test files under test/distributed to Intel GPU (#159473)
As part of https://github.com/pytorch/pytorch/issues/114850, we are porting distributed tests to Intel GPU. This PR covers several test files under test/distributed. Intel GPU is enabled with the following methods, keeping the original code style as much as possible:

- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- use requires_accelerator_dist_backend to allow both nccl and xccl tests
- enable XPU for some test paths
- change the hardcoded world_size according to device_count
- unify some common code under torch/testing/_internal for multiple backends, e.g. add xpu to Backend.backend_capability and dist.Backend.register_backend()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159473
Approved by: https://github.com/guangyey, https://github.com/d4l3k
Committed by: PyTorch MergeBot
Parent: 71b272e4a3
Commit: c9485f8ff3
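The device-agnostic pattern the commit message describes can be illustrated with a minimal sketch (not part of this PR): resolve the accelerator at runtime instead of hard-coding "cuda"/"nccl", mirroring the `torch.accelerator.current_accelerator()` and `dist.get_default_backend_for_device()` calls that appear in the diff below. The helper name `run_smoke_test` and the single-rank FileStore setup are illustrative assumptions, and the snippet assumes a recent PyTorch build that ships the `torch.accelerator` API.

```python
import tempfile

import torch
import torch.distributed as dist


def run_smoke_test() -> None:  # illustrative helper, not from the PR
    # Resolve the accelerator at runtime; fall back to CPU when none is present.
    acc = torch.accelerator.current_accelerator()
    device_type = acc.type if acc is not None else "cpu"

    # Pick the matching process-group backend (nccl for cuda, xccl for xpu, gloo for cpu).
    backend = dist.get_default_backend_for_device(device_type)

    # Single-rank process group backed by a temporary FileStore.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        store = dist.FileStore(f.name, 1)
    dist.init_process_group(backend, store=store, rank=0, world_size=1)

    if device_type != "cpu":
        torch.accelerator.set_device_index(0)

    t = torch.ones(2, 2, device=device_type)
    dist.all_reduce(t)  # sum across the (single) rank; effectively a no-op here
    dist.destroy_process_group()


if __name__ == "__main__":
    run_smoke_test()
```

On a CUDA machine this should resolve to ("cuda", "nccl"), on an Intel GPU machine to ("xpu", "xccl"), and otherwise to ("cpu", "gloo"), which is the substitution the diff performs throughout the test files.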
@@ -43,6 +43,7 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
+ TEST_XPU,
TestCase,
)
from torch.utils.checkpoint import checkpoint

@@ -63,6 +64,8 @@ else:

torch.backends.cuda.matmul.allow_tf32 = False

+ device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


def gpus_for_rank(world_size):
"""Multigpu tests are designed to simulate the multi nodes with multi

@@ -70,8 +73,9 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
- visible_devices = list(range(torch.cuda.device_count()))
- gpus_per_process = torch.cuda.device_count() // world_size
+ device_count = torch.accelerator.device_count()
+ visible_devices = list(range(device_count))
+ gpus_per_process = device_count // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@@ -401,7 +405,7 @@ class CommonDistributedDataParallelTest:
gradient_as_bucket_view=gradient_as_bucket_view,
)

- input = torch.randn(global_batch_size, 2).cuda(devices[0])
+ input = torch.randn(global_batch_size, 2).to(devices[0])
target = torch.randn(global_batch_size, 4)

return model, ddp_model, input, target

@@ -435,10 +439,10 @@ class CommonDistributedDataParallelTest:
allow_none_grads=False,
):
# to reproduce the same training results
- torch.cuda.set_device(self.rank)
+ torch.accelerator.set_device_index(self.rank)
torch.manual_seed(31415)
- model = copy.deepcopy(input_model).cuda()
- ddp_model = copy.deepcopy(input_model).cuda()
+ model = copy.deepcopy(input_model).to(device_type)
+ ddp_model = copy.deepcopy(input_model).to(device_type)
ddp_model = nn.parallel.DistributedDataParallel(
ddp_model,
bucket_cap_mb=1,

@@ -554,8 +558,8 @@ class CommonDistributedDataParallelTest:
def _prepare_dummy_data(self):
ddp_bs = 16
bs = ddp_bs * self.world_size
- input = torch.rand((bs, 20), device="cuda", requires_grad=True)
- target = torch.randn((bs, 20), device="cuda")
+ input = torch.rand((bs, 20), device=device_type, requires_grad=True)
+ target = torch.randn((bs, 20), device=device_type)
offset = self.rank * ddp_bs
ddp_input = input[offset : offset + ddp_bs]
ddp_target = target[offset : offset + ddp_bs]

@@ -715,7 +719,7 @@ class CommonDistributedDataParallelTest:
Test that checkpointing with weight sharing works.
"""
process_group = self._get_process_group()
- torch.cuda.set_device(self.rank)
+ torch.accelerator.set_device_index(self.rank)
for use_bucket_view, static_graph in product((False, True), (False, True)):
torch.manual_seed(31415)
l1 = nn.Linear(20, 20)

@@ -738,7 +742,7 @@ class CommonDistributedDataParallelTest:
same layer twice and having weights shared across layers.
"""
process_group = self._get_process_group()
- torch.cuda.set_device(self.rank)
+ torch.accelerator.set_device_index(self.rank)
for use_bucket_view in (True, False):
self._test_ddp_checkpointing(
self.CheckpointTwiceModuleWeightSharing(),
@@ -1162,7 +1166,7 @@ class AbstractCommTest:

# Verify sequence numbers are appropriately incremented
for i in range(10):
- t = torch.ones(1, device=torch.cuda.current_device())
+ t = torch.ones(1, device=device_type)
dist.all_reduce(t, group=process_group)
if not c10d._rank_not_in_group(process_group):
seq_num = self._verify_sequence_number_across_pg(

@@ -1193,7 +1197,7 @@ class AbstractCommTest:
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])

def _test_sequence_num_incremented_default_group(self, backend_name):
- torch.cuda.set_device(self.rank)
+ torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend_name,

@@ -1207,7 +1211,7 @@ class AbstractCommTest:
)

def _test_sequence_num_incremented_subgroup(self, backend_name):
- torch.cuda.set_device(self.rank)
+ torch.accelerator.set_device_index(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend_name,

@@ -1262,8 +1266,8 @@ class AbstractCommTest:
in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
group = dist.new_group(in_group_ranks)

- x = torch.zeros(2, 2).cuda(self.rank)
- xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
+ x = torch.zeros(2, 2).to(self.rank)
+ xs = [torch.zeros(2, 2).to(self.rank) for _ in range(len(in_group_ranks))]
if self.rank not in in_group_ranks:
msg = ".*{}.*does not belong to.*"
with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):

@@ -1392,7 +1396,7 @@ class AbstractCommTest:
rank=self.rank,
store=store,
)
- device = "cuda" if backend == "nccl" else "cpu"
+ device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
# test alltoall_base
tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)

@@ -1574,8 +1578,8 @@ class CommTest(AbstractCommTest, MultiProcessTestCase):

class DummyWork(dist._Work):
def wait(self, timeout=5.0):
- if torch.cuda.is_available():
- torch.cuda.current_stream().synchronize()
+ if torch.accelerator.is_available():
+ torch.accelerator.current_stream().synchronize()
return True

@@ -1790,6 +1794,18 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
]

+ if TEST_XPU:
+ # Override backend_config_strings_and_expected_values for Intel GPU.
+ backend_config_strings_and_expected_values[4:10] = [
+ (dist.Backend.DUMMY, "cpu:dummy,cuda:dummy,xpu:dummy"),
+ ("DUMMY", "cpu:dummy,cuda:dummy,xpu:dummy"),
+ ("dummy", "cpu:dummy,cuda:dummy,xpu:dummy"),
+ ("cpu:dummy,xpu:dummy", "cpu:dummy,xpu:dummy"),
+ ("cpu:dummy,xpu:xccl", "cpu:dummy,xpu:xccl"),
+ ("cpu:gloo,xpu:dummy", "cpu:gloo,xpu:dummy"),
+ ("cpu:gloo,xpu:xccl", "cpu:gloo,xpu:xccl"),
+ ]

for config_str, expected_value in backend_config_strings_and_expected_values:
with self.subTest(config_str):
# ensures these configs strings are valid and no ValueError is raised

@@ -1800,6 +1816,8 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
invalid_backend_config_strings = [
"cpu:gloo,cuda:nccl,",  # trailing comma
"cpu:gloo,cuda:nccl,cpu:dummy",  # duplicate device
+ "cpu:gloo,xpu:xccl,",  # trailing comma
+ "cpu:gloo,xpu:xccl,cpu:dummy",  # duplicate device
]
for config_str in invalid_backend_config_strings:
with self.subTest(config_str):

@@ -1814,7 +1832,7 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "6789"
dist.init_process_group(
- "cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
+ "cpu:dummy,cuda:dummy,xpu:dummy", rank=self.rank, world_size=self.world_size
)

# test all_gather

@@ -2053,7 +2071,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
# correctly dispatched

# TODO: this will be updated in the future to not be backend specific
- device = "cuda" if backend == "nccl" else "cpu"
+ device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
# ensure supported devices (cpu, cuda) succeeds during dispatch call
tensor = torch.zeros(2, 2, device=torch.device(device))
# multi tensor collectives

@@ -2119,7 +2137,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
rank=self.rank,
store=store,
)
- device = "cuda" if backend == "nccl" else "cpu"
+ device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
# test alltoall_base
input_tensor = torch.ones(2, 2, device=torch.device(device))
output_tensor = torch.zeros(2, 2, device=torch.device(device))

@@ -2251,8 +2269,9 @@ class LocalRankTest(MultiProcessTestCase):


if __name__ == "__main__":
- assert not torch.cuda._initialized, (
- "test_distributed must not have initialized CUDA context on main process"
- )
+ if device_type != "cpu":
+ assert not torch.get_device_module()._initialized, (
+ "test_distributed must not have initialized {device_type} context on main process"
+ )

run_tests()

@ -25,7 +25,7 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
|
||||
from torch.testing._internal.common_device_type import e4m3_type
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
requires_nccl,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
||||
@ -59,7 +59,7 @@ if not dist.is_available():
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
class TestWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
@ -75,13 +75,15 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(f"cuda:{self.rank}")
|
||||
return torch.device(self.rank)
|
||||
|
||||
def _init_process_group(self) -> None:
|
||||
torch.cuda.set_device(self.device)
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
backend = dist.get_default_backend_for_device(self.device.type)
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
backend=backend,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
@ -273,7 +275,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
)
|
||||
# check memory leak
|
||||
for i in range(1, 10):
|
||||
mem_usage[i] = torch.cuda.max_memory_allocated()
|
||||
mem_usage[i] = torch.accelerator.max_memory_allocated()
|
||||
compiled(arg)
|
||||
|
||||
assert mem_usage[9] == mem_usage[8]
|
||||
@ -370,14 +372,16 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_all_to_all_single(self) -> None:
|
||||
self._init_process_group()
|
||||
torch.cuda.set_device(self.device)
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
|
||||
torch.manual_seed(42)
|
||||
send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))
|
||||
|
||||
input_split_sizes = send_sz_matrix[self.rank].tolist()
|
||||
output_split_sizes = send_sz_matrix[:, self.rank].tolist()
|
||||
input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()
|
||||
input = torch.full((sum(input_split_sizes),), float(self.rank)).to(
|
||||
self.device.type
|
||||
)
|
||||
|
||||
output = torch.ops._c10d_functional.all_to_all_single(
|
||||
input,
|
||||
@ -388,7 +392,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
output = torch.ops._c10d_functional.wait_tensor(output)
|
||||
expect = torch.cat(
|
||||
[
|
||||
torch.full((sz,), float(rank)).cuda()
|
||||
torch.full((sz,), float(rank)).to(self.device.type)
|
||||
for rank, sz in enumerate(output_split_sizes)
|
||||
]
|
||||
)
|
||||
@ -464,7 +468,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
@fresh_cache()
|
||||
def test_threading(self):
|
||||
self._init_process_group()
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
device = self.device
|
||||
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
buf0 = arg + 42
|
||||
@ -544,9 +548,9 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
return in_grad, w_grad
|
||||
|
||||
m, n, k = 128, 256, 64
|
||||
in_ = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
out_grad = torch.randn((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
in_ = torch.randn((m, k), device=self.device.type, dtype=torch.bfloat16)
|
||||
w = torch.randn((n, k), device=self.device.type, dtype=torch.bfloat16)
|
||||
out_grad = torch.randn((m, n), device=self.device.type, dtype=torch.bfloat16)
|
||||
|
||||
eager_in_grad, eager_w_grad = fp8_rowwise_backward(in_, w, out_grad)
|
||||
compile_in_grad, compile_w_grad = torch.compile(fp8_rowwise_backward)(
|
||||
@ -775,7 +779,8 @@ class CompileTest(TestCase):
|
||||
|
||||
self.rank = 0
|
||||
self.world_size = 2
|
||||
torch.cuda.set_device("cuda:0")
|
||||
torch.accelerator.set_device_index(0)
|
||||
self.device = torch.accelerator.current_accelerator()
|
||||
|
||||
store = FakeStore()
|
||||
dist.init_process_group(
|
||||
@ -801,7 +806,7 @@ class CompileTest(TestCase):
|
||||
ar1 = funcol.wait_tensor(ar1)
|
||||
return ar0, ar1
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device)
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -834,7 +839,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -849,7 +854,7 @@ class CompileTest(TestCase):
|
||||
ar1 = [funcol.wait_tensor(out) for out in ar1]
|
||||
return ar0, ar1
|
||||
|
||||
args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
|
||||
args = [torch.rand(4, 4, device=self.device.type) for _ in range(2)]
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, args)
|
||||
buf0, buf1, buf2, buf3 = find_buffer_assignments(code)
|
||||
@ -879,7 +884,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -890,7 +895,7 @@ class CompileTest(TestCase):
|
||||
ar0 = funcol.wait_tensor(ar0)
|
||||
return ar0
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -915,7 +920,7 @@ class CompileTest(TestCase):
|
||||
# Expect allocation
|
||||
return ar0
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda").T
|
||||
arg = torch.rand(4, 4, device=self.device.type).T
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -946,7 +951,7 @@ class CompileTest(TestCase):
|
||||
buf2 = torch.mm(arg, buf1)
|
||||
return buf1, buf2
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
buf0, buf1 = find_buffer_assignments(code)
|
||||
@ -976,7 +981,7 @@ class CompileTest(TestCase):
|
||||
ag0 = funcol.wait_tensor(ag0)
|
||||
return ag0
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
(
|
||||
@ -993,7 +998,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1003,7 +1008,7 @@ class CompileTest(TestCase):
|
||||
ag0 = [funcol.wait_tensor(out) for out in ag0]
|
||||
return ag0
|
||||
|
||||
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
|
||||
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, args)
|
||||
(
|
||||
@ -1027,7 +1032,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
|
||||
@fresh_cache()
|
||||
@ -1037,7 +1042,7 @@ class CompileTest(TestCase):
|
||||
return funcol.wait_tensor(t)
|
||||
|
||||
# Test aoti
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
(
|
||||
@ -1049,7 +1054,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1059,7 +1064,7 @@ class CompileTest(TestCase):
|
||||
rs0 = funcol.wait_tensor(rs0)
|
||||
return rs0
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
(
|
||||
@ -1075,7 +1080,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1087,7 +1092,7 @@ class CompileTest(TestCase):
|
||||
rs0 = [funcol.wait_tensor(out) for out in rs0]
|
||||
return rs0
|
||||
|
||||
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
|
||||
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, args)
|
||||
(
|
||||
@ -1111,7 +1116,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (args,))
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1140,7 +1145,9 @@ class CompileTest(TestCase):
|
||||
|
||||
input_split_sizes = send_sz_matrix[self.rank]
|
||||
output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
|
||||
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()
|
||||
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).to(
|
||||
self.device.type
|
||||
)
|
||||
|
||||
with torch._dynamo.config.patch(
|
||||
dynamic_shapes=True,
|
||||
@ -1174,7 +1181,7 @@ class CompileTest(TestCase):
|
||||
br1 = funcol.wait_tensor(br1)
|
||||
return br0, br1
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -1197,7 +1204,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1212,7 +1219,7 @@ class CompileTest(TestCase):
|
||||
ar1 = funcol.wait_tensor(ar1)
|
||||
return ar0, ar1
|
||||
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
compiled = torch.compile(func, fullgraph=True)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -26,7 +27,7 @@ from torch.distributed.tensor._collective_utils import (
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import _Partial, Shard
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_XPU
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -35,6 +36,10 @@ from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeSt
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
device_count = torch.accelerator.device_count()
|
||||
|
||||
|
||||
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
|
||||
os.environ["MASTER_ADDR"] = addr
|
||||
os.environ["MASTER_PORT"] = port
|
||||
@ -44,6 +49,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
|
||||
os.environ["LOCAL_RANK"] = f"{local_rank}"
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
|
||||
class DeviceMeshTestGlooBackend(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
@ -73,14 +79,16 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
|
||||
|
||||
# Set the device on each process before DeviceMesh constructor,
|
||||
# and device to be different than the default world rank
|
||||
torch.cuda.set_device((self.rank + 2) % self.world_size)
|
||||
torch.accelerator.set_device_index((self.rank + 2) % self.world_size)
|
||||
_set_env_var(world_size=self.world_size, rank=self.rank)
|
||||
DeviceMesh(self.device_type, mesh_tensor)
|
||||
self.assertTrue(is_initialized())
|
||||
|
||||
# check that the device is set to the correct device
|
||||
# and respect the previous set_device calls
|
||||
self.assertEqual(torch.cuda.current_device(), (self.rank + 2) % self.world_size)
|
||||
self.assertEqual(
|
||||
torch.accelerator.current_device_idx(), (self.rank + 2) % self.world_size
|
||||
)
|
||||
self.destroy_pg()
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@ -101,7 +109,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
|
||||
|
||||
# check that the device is set to the correct device
|
||||
# and respect the LOCAL_RANK env var
|
||||
self.assertEqual(torch.cuda.current_device(), local_rank)
|
||||
self.assertEqual(torch.accelerator.current_device_idx(), local_rank)
|
||||
self.destroy_pg()
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@ -120,7 +128,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
|
||||
self.assertTrue(is_initialized())
|
||||
|
||||
# check that the device is set to the correct device
|
||||
self.assertEqual(torch.cuda.current_device(), self.rank)
|
||||
self.assertEqual(torch.accelerator.current_device_idx(), self.rank)
|
||||
self.destroy_pg()
|
||||
|
||||
|
||||
@ -222,7 +230,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_device_mesh_2d(self):
|
||||
mesh_tensor = torch.arange(4).reshape(2, 2)
|
||||
# construct a cuda device mesh
|
||||
# construct a device mesh for self.device_type
|
||||
mesh = DeviceMesh(self.device_type, mesh_tensor)
|
||||
|
||||
# check all dim groups
|
||||
@ -260,7 +268,11 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
def test_fake_pg_device_mesh(self):
|
||||
fake_store = FakeStore()
|
||||
init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device_type = (
|
||||
torch.accelerator.current_accelerator().type
|
||||
if torch.accelerator.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
mesh = DeviceMesh(device_type, torch.arange(self.world_size))
|
||||
|
||||
local_tensor = torch.randn(2, 8)
|
||||
@ -300,7 +312,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
|
||||
with self.assertRaisesRegex(ValueError, regex):
|
||||
DeviceMesh.from_group(
|
||||
global_pg, "cuda", invalid_mesh, mesh_dim_names=("dim0", "dim1")
|
||||
global_pg, device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1")
|
||||
)
|
||||
|
||||
device_mesh = init_device_mesh(self.device_type, (2, 2))
|
||||
@ -320,12 +332,16 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
# test init_device_mesh with an invalid device type that contains a GPU index
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
init_device_mesh(
|
||||
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
||||
f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_set_mesh_dim_group_options(self):
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device_type = (
|
||||
torch.accelerator.current_accelerator().type
|
||||
if torch.accelerator.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
_mesh_resources._set_mesh_dim_group_options(1, "fake", None)
|
||||
|
||||
mesh_tensor = torch.arange(4).reshape(2, 2)
|
||||
@ -341,7 +357,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
def test_device_mesh_nd(self):
|
||||
# construct a cuda device mesh
|
||||
# construct a device mesh for self.device_type
|
||||
mesh_tensor = torch.arange(8).reshape(2, 2, 2)
|
||||
mesh = DeviceMesh(self.device_type, mesh_tensor)
|
||||
|
||||
@ -710,7 +726,9 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||
with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
|
||||
mesh_dim_names = ("DP", "TP")
|
||||
mesh = init_device_mesh(
|
||||
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
|
||||
self.device_type,
|
||||
(2, 4),
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
)
|
||||
mesh[child_mesh_dim_name]
|
||||
|
||||
@ -938,7 +956,9 @@ class TestMeshEnv(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_get_root_mesh(self):
|
||||
mesh_3d = init_device_mesh(
|
||||
self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
|
||||
self.device_type,
|
||||
(2, 2, 2),
|
||||
mesh_dim_names=("dp", "cp", "tp"),
|
||||
)
|
||||
|
||||
dp_cp_mesh = mesh_3d["dp", "cp"]
|
||||
@ -986,7 +1006,9 @@ class TestMeshEnv(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_get_all_submeshes(self):
|
||||
mesh_2d = init_device_mesh(
|
||||
self.device_type, (2, 4), mesh_dim_names=("replicate", "shard")
|
||||
self.device_type,
|
||||
(2, 4),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
)
|
||||
all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate")
|
||||
self.assertEqual(len(all_submeshes), 4)
|
||||
|
@ -43,11 +43,12 @@ from torch.testing._internal.common_distributed import (
|
||||
DynamoDistributedMultiProcTestCase,
|
||||
DynamoDistributedSingleProcTestCase,
|
||||
import_transformers_or_skip,
|
||||
requires_nccl,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import requires_cuda
|
||||
from torch.testing._internal.common_utils import skipIfXpu
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
def reset_rng_state():
|
||||
@ -270,7 +271,15 @@ def get_hf_bert(rank):
|
||||
except ImportError as e:
|
||||
raise unittest.SkipTest("Unable to import transformers") from e
|
||||
|
||||
batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
|
||||
device_type = (
|
||||
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
)
|
||||
batch_size, max_length, config, device = (
|
||||
4,
|
||||
512,
|
||||
BertConfig(),
|
||||
f"{device_type}:{rank}",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_config(config).to(device)
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
|
||||
decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(
|
||||
@ -550,8 +559,8 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
|
||||
|
||||
# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
|
||||
# single process version; if it's just a problem in the Dynamo distributed
|
||||
# optimizer, you should be able to repro it single process!
|
||||
@requires_nccl()
|
||||
# # optimizer, you should be able to repro it single process!
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
"""
|
||||
Note: MultiProcTestCase spawns processes per test and is slow.
|
||||
@ -559,12 +568,16 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
sparingly for integration tests.
|
||||
"""
|
||||
|
||||
device_type = (
|
||||
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@config.patch(optimize_ddp=False, enable_compiler_collectives=True)
|
||||
def test_ddp_baseline_aot_eager_multiprocess(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
self.assertFalse(config.optimize_ddp)
|
||||
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
|
||||
m = DDP(m, device_ids=[self.rank])
|
||||
m = torch.compile(m, backend="aot_eager")
|
||||
outputs = m(inputs)
|
||||
@ -632,7 +645,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
self.assertFalse(config.optimize_ddp)
|
||||
model = MyModel().to(device="cuda")
|
||||
model = MyModel().to(device=self.device_type)
|
||||
|
||||
# Activation checkpointing for Linear layers.
|
||||
non_reentrant_wrapper = functools.partial(
|
||||
@ -647,7 +660,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
)
|
||||
|
||||
model = DDP(model)
|
||||
x = torch.randn(10, 64).cuda()
|
||||
x = torch.randn(10, 64).to(self.device_type)
|
||||
correct_outputs = model(x)
|
||||
|
||||
opt_model = torch.compile(model)
|
||||
@ -659,14 +672,14 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
def test_fsdp_aot_eager(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
# Test with basic FSDP wrapping (outer wrap around whole model)
|
||||
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
|
||||
fsdp_m = FSDP(m, use_orig_params=True)
|
||||
fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
|
||||
outputs = fsdp_m(inputs)
|
||||
self.assertTrue(same(correct_outputs, outputs))
|
||||
|
||||
# Test with recursive wrapping, nested FSDP around each Linear
|
||||
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
|
||||
fsdp_m = FSDP(
|
||||
m,
|
||||
auto_wrap_policy=functools.partial(
|
||||
@ -680,6 +693,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@requires_cuda_and_triton
|
||||
def test_ddp_optimizer_cudagraph(self):
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
@ -730,7 +744,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
from torch._dynamo.utils import counters
|
||||
|
||||
counters.clear()
|
||||
m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_mutating_model(
|
||||
f"{self.device_type}:{self.rank}"
|
||||
)
|
||||
fsdp_m = FSDP(m, use_orig_params=True)
|
||||
fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
|
||||
outputs = fsdp_m(inputs)
|
||||
@ -748,7 +764,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
from torch._dynamo.utils import counters
|
||||
|
||||
counters.clear()
|
||||
m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_forced_getattr_module(
|
||||
f"{self.device_type}:{self.rank}"
|
||||
)
|
||||
fsdp_m = FSDP(m, use_orig_params=True)
|
||||
fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
|
||||
outputs = fsdp_m(inputs)
|
||||
@ -762,7 +780,9 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
from torch._dynamo.utils import counters
|
||||
|
||||
counters.clear()
|
||||
m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_forced_getattr_module(
|
||||
f"{self.device_type}:{self.rank}"
|
||||
)
|
||||
fsdp_m = FSDP(m, use_orig_params=True)
|
||||
fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
|
||||
outputs = fsdp_m(inputs)
|
||||
@ -774,14 +794,14 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
def test_fsdp_inductor(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
# Test with basic FSDP wrapping (outer wrap around whole model)
|
||||
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
|
||||
fsdp_m = FSDP(m, use_orig_params=True)
|
||||
fsdp_m = torch.compile(fsdp_m, backend="inductor")
|
||||
outputs = fsdp_m(inputs)
|
||||
self.assertTrue(same(correct_outputs, outputs))
|
||||
|
||||
# Test with recursive wrapping, nested FSDP around each Linear
|
||||
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
|
||||
m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
|
||||
fsdp_m = FSDP(
|
||||
m,
|
||||
auto_wrap_policy=functools.partial(
|
||||
@ -799,7 +819,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
def test_fsdp_activation_checkpointing(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
model, inputs = get_toy_model_for_activation_checkpointing(
|
||||
f"cuda:{self.rank}"
|
||||
f"{self.device_type}:{self.rank}"
|
||||
)
|
||||
is_inner = lambda module: isinstance(module, ToyInnerModel) # noqa: E731
|
||||
wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
|
||||
@ -961,7 +981,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
torch._dynamo.utils.clear_compilation_metrics()
|
||||
|
||||
# TODO: This should be possible to do inside the function, but
|
||||
device = f"cuda:{self.rank}"
|
||||
device = f"{self.device_type}:{self.rank}"
|
||||
|
||||
@torch.compile()
|
||||
def f(x, y):
|
||||
@ -1181,7 +1201,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
|
||||
device = f"cuda:{self.rank}"
|
||||
device = f"{self.device_type}:{self.rank}"
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
@ -1196,6 +1216,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
pg = dist.distributed_c10d.GroupMember.NON_GROUP_MEMBER
|
||||
self.assertEqual(f(x), x + 1)
|
||||
|
||||
@skipIfXpu # ProcessGroupXCCL doesn't support _set_default_timeout yet.
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@patch.object(torch._inductor.config, "fx_graph_cache", False)
|
||||
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
|
||||
@ -1205,7 +1226,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
torch._dynamo.utils.clear_compilation_metrics()
|
||||
|
||||
device = f"cuda:{self.rank}"
|
||||
device = f"{self.device_type}:{self.rank}"
|
||||
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
|
||||
@ -1238,7 +1259,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
w = pg.allreduce(x)
|
||||
w.wait()
|
||||
torch.cuda.synchronize(device)
|
||||
torch.accelerator.synchronize(device)
|
||||
|
||||
metrics = torch._dynamo.utils.get_compilation_metrics()
|
||||
# Number of compiles same on all nodes
|
||||
@ -1247,6 +1268,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
for r in res[1:]:
|
||||
self.assertEqual(res[0], r)
|
||||
|
||||
@skipIfXpu # ProcessGroupXCCL doesn't support _set_default_timeout yet.
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@patch.object(torch._inductor.config, "fx_graph_cache", True)
|
||||
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
|
||||
@ -1258,7 +1280,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
torch._dynamo.utils.clear_compilation_metrics()
|
||||
|
||||
device = f"cuda:{self.rank}"
|
||||
device = f"{self.device_type}:{self.rank}"
|
||||
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
|
||||
@ -1281,7 +1303,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
w = pg.allreduce(x)
|
||||
w.wait()
|
||||
torch.cuda.synchronize(device)
|
||||
torch.accelerator.synchronize(device)
|
||||
torch._dynamo.reset()
|
||||
|
||||
if self.rank == 0:
|
||||
@ -1298,11 +1320,11 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
w = pg.allreduce(x)
|
||||
w.wait()
|
||||
torch.cuda.synchronize(device)
|
||||
torch.accelerator.synchronize(device)
|
||||
|
||||
|
||||
@requires_nccl()
|
||||
@requires_cuda
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@unittest.skipUnless(torch.accelerator.is_available(), "Requires accelerator")
|
||||
class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
"""
|
||||
Test harness initializes dist process group.
|
||||
@ -1311,6 +1333,10 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
Use TestMultiProc for things that really need to run on multiple nodes
|
||||
"""
|
||||
|
||||
device_type = (
|
||||
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
)
|
||||
|
||||
def get_model(
|
||||
self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
|
||||
):
|
||||
@ -1428,6 +1454,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
self.assertEqual(len(break_reasons), 4)
|
||||
self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
|
||||
|
||||
@skipIfXpu # XPU device doesn't support flex_attention yet.
|
||||
@patch.object(config, "optimize_ddp", True)
|
||||
def test_compiled_flex_attention_full_model_ddp(self):
|
||||
class Model(torch.nn.Module):
|
||||
@ -1474,16 +1501,16 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
S = 512
|
||||
D = 64
|
||||
|
||||
device = "cuda"
|
||||
model = Model(S, H, D)
|
||||
model.to(device)
|
||||
model.to(self.device_type)
|
||||
model = torch.compile(model)
|
||||
model = DDP(model, device_ids=self.device_ids)
|
||||
|
||||
hidden_states = torch.randn(B, S, H * D).to(device)
|
||||
hidden_states = torch.randn(B, S, H * D).to(self.device_type)
|
||||
model(hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@skipIfXpu # XPU device doesn't support flex_attention yet.
|
||||
@patch.object(config, "optimize_ddp", True)
|
||||
def test_compiled_flex_attention_local_ddp(self):
|
||||
class Model(torch.nn.Module):
|
||||
@ -1530,15 +1557,14 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
S = 512
|
||||
D = 64
|
||||
|
||||
device = "cuda"
|
||||
model = Model(S, H, D)
|
||||
model.to(device)
|
||||
model.to(self.device_type)
|
||||
model = torch.compile(model)
|
||||
model = DDP(model, device_ids=self.device_ids)
|
||||
|
||||
hidden_states = torch.randn(B, S, H * D).to(device)
|
||||
hidden_states = torch.randn(B, S, H * D).to(self.device_type)
|
||||
model(hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@patch.object(config, "optimize_ddp", True)
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@ -1787,9 +1813,9 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
a = torch.cos(a)
|
||||
return a
|
||||
|
||||
mod = MockModule().cuda()
|
||||
mod = MockModule().to(self.device_type)
|
||||
mod = DDP(mod, bucket_cap_mb=1)
|
||||
x = torch.randn(N, N, device="cuda", requires_grad=True)
|
||||
x = torch.randn(N, N, device=self.device_type, requires_grad=True)
|
||||
args = (x,)
|
||||
|
||||
backend = "aot_eager"
|
||||
@ -1799,7 +1825,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
|
||||
def test_fsdp_orig_params_assert(self):
|
||||
# Test with basic FSDP wrapping (outer wrap around whole model)
|
||||
m, inputs, _ = get_model(f"cuda:{self.rank}")
|
||||
m, inputs, _ = get_model(f"{self.device_type}:{self.rank}")
|
||||
fsdp_m = FSDP(m, use_orig_params=False)
|
||||
# Test is that this function call does not throw an exception.
|
||||
fsdp_m = torch.compile(fsdp_m)
|
||||
@ -1845,7 +1871,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
|
||||
return out
|
||||
|
||||
device = f"cuda:{self.rank}"
|
||||
device = f"{self.device_type}:{self.rank}"
|
||||
m = ToyModel(
|
||||
in_feat=10,
|
||||
hidden_feat=5000,
|
||||
@ -1892,7 +1918,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
|
||||
torch._dynamo.reset()
|
||||
|
||||
device = f"cuda:{self.rank}"
|
||||
device = f"{self.device_type}:{self.rank}"
|
||||
m = ToyModel(
|
||||
in_feat=10,
|
||||
hidden_feat=5000,
|
||||
@ -1933,9 +1959,14 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
class DuplicateModule(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._param = torch.randn((3,), device="cuda")
|
||||
device_type = (
|
||||
acc.type
|
||||
if (acc := torch.accelerator.current_accelerator())
|
||||
else "cpu"
|
||||
)
|
||||
self._param = torch.randn((3,), device=device_type)
|
||||
self._buf = torch.nn.Buffer(
|
||||
torch.randn((3,), requires_grad=False, device="cuda")
|
||||
torch.randn((3,), requires_grad=False, device=device_type)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -1948,7 +1979,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
model = DuplicateModule()
|
||||
fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
|
||||
fsdp_model = torch.compile(fsdp_model, backend="aot_eager")
|
||||
inp = torch.randn((2, 3), device="cuda")
|
||||
inp = torch.randn((2, 3), device=self.device_type)
|
||||
local_out = model(inp)
|
||||
fsdp_out = fsdp_model(inp)
|
||||
self.assertEqual(local_out, fsdp_out)
|
||||
@ -1965,8 +1996,13 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
class BufModule(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
device_type = (
|
||||
acc.type
|
||||
if (acc := torch.accelerator.current_accelerator())
|
||||
else "cpu"
|
||||
)
|
||||
self._buf = nn.Buffer(
|
||||
torch.randn((3,), requires_grad=False, device="cuda")
|
||||
torch.randn((3,), requires_grad=False, device=device_type)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -1975,7 +2011,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
class Model(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._param = nn.Parameter(torch.randn((1,), device="cuda"))
|
||||
device_type = (
|
||||
acc.type
|
||||
if (acc := torch.accelerator.current_accelerator())
|
||||
else "cpu"
|
||||
)
|
||||
self._param = nn.Parameter(torch.randn((1,), device=device_type))
|
||||
self._buf_module = BufModule()
|
||||
# Share the buffer, meaning same tensor but different source
|
||||
self._buf = self._buf_module._buf
|
||||
@ -1992,7 +2033,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
fsdp_model = FSDP(Model(), use_orig_params=True)
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
fsdp_model = torch.compile(fsdp_model, backend=cnt)
|
||||
inp = torch.randn((2, 3), device="cuda")
|
||||
inp = torch.randn((2, 3), device=self.device_type)
|
||||
for _ in range(15):
|
||||
fsdp_model(inp)
|
||||
# Check for no recompiles (if there were incorrect de-dup guards, then
|
||||
@ -2011,7 +2052,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
super().__init__()
|
||||
self._use_self = use_self
|
||||
torch.manual_seed(42) # force `_param` to be deterministic
|
||||
self._param = nn.Parameter(torch.randn((3,), device="cuda"))
|
||||
device_type = (
|
||||
acc.type
|
||||
if (acc := torch.accelerator.current_accelerator())
|
||||
else "cpu"
|
||||
)
|
||||
self._param = nn.Parameter(torch.randn((3,), device=device_type))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self._use_self:
|
||||
@ -2026,7 +2072,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
||||
return x + y
|
||||
|
||||
model = ModuleWithStaticMethod(False)
|
||||
x = torch.randn((2, 3), device="cuda")
|
||||
x = torch.randn((2, 3), device=self.device_type)
|
||||
ref_out = model(x)
|
||||
test_outs: list[torch.Tensor] = []
|
||||
|
||||
|
@ -10,6 +10,7 @@ import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.logging
|
||||
import torch._dynamo.test_case
|
||||
import torch.distributed as c10d
|
||||
|
||||
# for some reason importing functional collectives after dynamo breaks collectives handling!
|
||||
import torch.distributed._functional_collectives as _functional_collectives
|
||||
@ -37,14 +38,14 @@ from torch.testing._internal.common_distributed import (
|
||||
DynamoDistributedMultiProcTestCase,
|
||||
DynamoDistributedSingleProcTestCase,
|
||||
MultiProcessTestCase,
|
||||
requires_nccl,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
requires_cuda,
|
||||
skipIfRocm,
|
||||
skipIfXpu,
|
||||
TEST_XPU,
|
||||
xfailIf,
|
||||
)
|
||||
@ -59,13 +60,15 @@ def _tolist_with_constrain_as_size(tensor):
|
||||
return lst
|
||||
|
||||
|
||||
@requires_nccl()
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@instantiate_parametrized_tests
|
||||
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
"""
|
||||
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
|
||||
"""
|
||||
|
||||
device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
def get_world_trs(self):
|
||||
return {
|
||||
"tag": "",
|
||||
@ -102,8 +105,11 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
example,
|
||||
**self.get_world_trs(),
|
||||
)
|
||||
t = torch.randn(4, 4, device="cuda")
|
||||
inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
|
||||
t = torch.randn(4, 4, device=self.device)
|
||||
inputs = (
|
||||
t if self.rank == 0 else torch.zeros(4, 4, device=self.device),
|
||||
0,
|
||||
)
|
||||
eager_out = example(*inputs)
|
||||
self.assertTrue(same(t, eager_out))
|
||||
|
||||
@ -137,7 +143,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
matmul_cat_col,
|
||||
**self.get_world_trs(),
|
||||
)
|
||||
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6
|
||||
inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 6
|
||||
|
||||
eager_out = matmul_cat_col(*inputs)
|
||||
compiled_matmul_cat_col = compile(matmul_cat_col, inputs)
|
||||
@ -179,7 +185,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
for nelem in [1024, 2048, 4096]:
|
||||
# CI (Tesla T4) does not support bfloat16 compilation natively,
|
||||
# using float
|
||||
x = torch.randn(nelem, device="cuda", dtype=torch.float)
|
||||
x = torch.randn(nelem, device=self.device, dtype=torch.float)
|
||||
golden_out = eager_func(x)
|
||||
|
||||
for _ in range(3):
|
||||
@ -217,8 +223,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
eager_func,
|
||||
**self.get_world_trs(),
|
||||
)
|
||||
eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
|
||||
inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
|
||||
eager_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 4
|
||||
inductor_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2
|
||||
|
||||
eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
|
||||
compiled_inductor_func = compile(
|
||||
@ -256,8 +262,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
inductor_func,
|
||||
**self.get_world_trs(),
|
||||
)
|
||||
inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
|
||||
eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
|
||||
inductor_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 4
|
||||
eager_inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2
|
||||
|
||||
eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
|
||||
compiled_inductor_func = compile(inductor_func, inductor_inputs)
|
||||
@ -270,6 +276,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728
|
||||
@skipIfRocm
|
||||
@xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1728
|
||||
def test_eager_async_allreduce_inductor_wait(self):
|
||||
import torch.distributed as dist
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
@ -292,7 +299,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return y * y
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
x = torch.ones(12800, 12800, device="cuda") + self.rank
|
||||
x = torch.ones(12800, 12800, device=self.device) + self.rank
|
||||
self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
|
||||
|
||||
# NOTE: We run for 10 iterations each, to ensure that the GPU execution is way behind CPU
|
||||
@ -363,7 +370,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
return (e,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
inputs = torch.ones(4, 4, device="cuda") + self.rank
|
||||
inputs = torch.ones(4, 4, device=self.device) + self.rank
|
||||
compiled = torch.compile(func)
|
||||
out = compiled(inputs, **self.get_world_trs())
|
||||
correct = func(inputs, **self.get_world_trs())
|
||||
@ -380,7 +387,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
inputs = (
|
||||
# rank0: [0., 1.], rank1: [2., 3.]
|
||||
torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
|
||||
torch.arange(2, dtype=torch.float32, device=self.device)
|
||||
+ 2 * self.rank,
|
||||
[1, 0],
|
||||
)
|
||||
compiled = torch.compile(func)
|
||||
@ -389,7 +397,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
self.assertTrue(same(out, correct))

# rank0: [2., 3.], rank1: [0., 1.]
expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
expected = torch.arange(2, dtype=torch.float32, device=self.device) + 2 * (
(self.rank - 1 + self.world_size) % self.world_size
)
self.assertEqual(out, expected)

@ -412,9 +420,9 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
return out

with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model = Model().cuda()
model = Model().to(self.device)
model_compiled = torch.compile(model)
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
out = model_compiled(inp, self.world_size, **self.get_world_trs())
correct = model(inp, self.world_size, **self.get_world_trs())
self.assertTrue(same(out, correct))

@ -429,7 +437,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):

with _dynamo_dist_per_rank_init(self.rank, self.world_size):
func_compiled = torch.compile(func)
inp = torch.tensor(self.rank, dtype=torch.long, device="cuda")
inp = torch.tensor(self.rank, dtype=torch.long, device=self.device)
out = func_compiled(inp, self.world_size)
correct = func(inp, self.world_size)
self.assertTrue(same(out, correct))

@ -451,9 +459,9 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
return out

with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model = Model().cuda()
model = Model().to(self.device)
model_compiled = torch.compile(model)
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
out = model_compiled(inp, self.world_size, **self.get_world_trs())
correct = model(inp, self.world_size, **self.get_world_trs())
self.assertTrue(same(out, correct))

@ -482,7 +490,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
example,
**self.get_world_trs(),
)
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2

eager_out = example(*inputs)
compiled_matmul_cat_col = compile(example, inputs)

@ -509,7 +517,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
example,
**self.get_world_trs(),
)
inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2
inputs = (torch.ones(4, 4, device=self.device) + self.rank,) * 2

eager_out = example(*inputs)
compiled_fn = compile(example, inputs)

@ -563,7 +571,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
dtype=torch.int64,
)
inputs = (
torch.ones(int(row), 5, device="cuda") * (self.rank + 1),
torch.ones(int(row), 5, device=self.device) * (self.rank + 1),
input_split_sizes_tensor,
output_split_sizes_tensor,
)

@ -732,7 +740,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
dtype=torch.int64,
)
inputs = (
torch.ones(int(row), 5, device="cuda", requires_grad=True)
torch.ones(int(row), 5, device=self.device, requires_grad=True)
* (self.rank + 1),
input_split_sizes_tensor,
output_split_sizes_tensor,

@ -795,7 +803,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):

with _dynamo_dist_per_rank_init(self.rank, self.world_size):
inputs = (
torch.ones(self.world_size, self.world_size, device="cuda")
torch.ones(self.world_size, self.world_size, device=self.device)
* (self.rank + 1),
)
trs = self.get_world_trs()
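The hunks above all apply the same substitution: derive the target device from the running accelerator instead of hardcoding CUDA. A minimal sketch of that pattern, not part of the diff; `pick_device` is an illustrative helper name, not something the test files define.

```python
import torch

def pick_device() -> torch.device:
    # Prefer whichever accelerator backend is present (CUDA, XPU, ...); fall back to CPU.
    acc = torch.accelerator.current_accelerator()
    return torch.device(acc.type) if acc is not None else torch.device("cpu")

device = pick_device()
# Backend-neutral replacements for `.cuda()` / device="cuda":
model_input = torch.ones(4, 4, device=device)
weights = torch.randn(4, 4).to(device)
```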

@ -819,8 +827,11 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):


@instantiate_parametrized_tests
@requires_nccl()
@requires_cuda
@requires_accelerator_dist_backend(["nccl", "xccl"])
@unittest.skipIf(
not torch.accelerator.is_available(),
"No accelerator is available",
)
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
"""
Prefer single-proc test runner for basic tests as it is easier to work with.
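The decorator change above is the gating pattern used for the ported test classes: require a collective backend for whichever accelerator is present and skip when none is available. A hedged sketch of the same gating on a placeholder class; it assumes `requires_accelerator_dist_backend` is importable from `torch.testing._internal.common_distributed`, as the test files in this diff do.

```python
import unittest
import torch
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
from torch.testing._internal.common_utils import run_tests, TestCase

@requires_accelerator_dist_backend(["nccl", "xccl"])
@unittest.skipIf(not torch.accelerator.is_available(), "No accelerator is available")
class ExampleCollectiveTest(TestCase):
    def test_placeholder(self):
        # Runs only when a CUDA or XPU accelerator and its collective backend exist.
        self.assertTrue(torch.accelerator.is_available())

if __name__ == "__main__":
    run_tests()
```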

@ -843,7 +854,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ar = torch.ops.c10d_functional.wait_tensor(ar)
return ar

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

compiled = torch.compile(func)
out = compiled(inputs, **self.get_world_trs())
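These single-process tests compile a function that issues a functional collective and compare the result against eager mode. A standalone sketch of that flow under a one-rank gloo group; the group setup and the use of `dist.group.WORLD` in place of the harness's registered group name are assumptions of the sketch.

```python
import os
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

def func(inp):
    # Functional all_reduce returns a new tensor; on a one-rank group it equals the input.
    return _functional_collectives.all_reduce(inp, "sum", dist.group.WORLD)

inputs = torch.ones(4, 4)
compiled = torch.compile(func)
assert torch.allclose(compiled(inputs), func(inputs))
dist.destroy_process_group()
```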

@ -878,7 +889,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar, other

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

@ -911,7 +922,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar, y, other

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

@ -952,7 +963,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ar = _functional_collectives.all_reduce(inp, "sum", "0")
return ar

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs)

@ -963,12 +974,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_all_gather_tensor(self):
def func(inp):
ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
return ar

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs)

@ -979,6 +991,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_all_gather_tensor_pg(self):
def func(inp, *, pg):
ar = _functional_collectives.all_gather_tensor(inp, 0, pg)

@ -995,6 +1008,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_all_gather(self):
def func(inp, out, *, pg):
torch.distributed.all_gather_into_tensor(

@ -1020,6 +1034,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 3
assert same(outputs, correct_outputs)

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_all_gather_list(self):
def func(inp, out, *, pg):
torch.distributed.all_gather(

@ -1042,6 +1057,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.frame_count == 1
assert same(outputs, correct_outputs)

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_all_gather_args_match(self):
# Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
# except uses kwargs to ensure rewrite has matching arg names

@ -1070,6 +1086,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 3
assert same(outputs, correct_outputs)

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_rewrite_dist_reduce_scatter(self):
def func(inp, out, *, pg):
torch.distributed.reduce_scatter_tensor(

@ -1237,6 +1254,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
input = torch.ones(2, device=self.device)
compiled(input)

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_support_collective_op_with_async_op_False(self):
def func(inp, out, *, pg):
# user explicitly set the attribute `async_op` to False,

@ -1296,12 +1314,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 1
assert same(outputs, correct_outputs)

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_reduce_scatter_tensor(self):
def func(inp):
ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
return ar

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs)

@ -1312,6 +1331,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(counter.op_count, 2)
self.assertTrue(same(out, correct))

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
def test_dynamo_trace_allgather_coalesced(self):
def func(inp, *, tag, ranks, group_size):
ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(

@ -1319,7 +1339,10 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
)
return ar

inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
inputs = [
torch.ones(4, 4, device=self.device),
torch.ones(6, 6, device=self.device),
]
counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
out = compiled(inputs, **self.get_world_trs())
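The `counter.frame_count` / `counter.op_count` assertions above use Dynamo's `CompileCounter` backend to verify that a collective traces into a single graph without breaks. A small sketch of the counting mechanism on a trivial, made-up function:

```python
import torch
from torch._dynamo.testing import CompileCounter

def func(x):
    y = x + 1          # one aten op
    return y.relu()    # second aten op

counter = CompileCounter()
compiled = torch.compile(func, backend=counter)
compiled(torch.ones(4, 4))
assert counter.frame_count == 1  # a single graph, no graph breaks
assert counter.op_count == 2     # two ops captured in that graph
```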

@ -1339,7 +1362,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ar = _functional_collectives.all_reduce(inp, "sum", "0")
return ar

input = torch.ones(4, 4, device="cuda", requires_grad=True)
input = torch.ones(4, 4, device=self.device, requires_grad=True)
compiled = torch.compile(
func, backend="aot_eager"
)  # inductor bug with single-op allreduce graph

@ -1357,6 +1380,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
self.assertEqual(x.size(), out.size())

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_all_gather_coalesced(self):

@ -1376,7 +1400,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar0, y, other, ar1

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

@ -1403,6 +1427,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
assert same(out, correct), f"{out} va {correct}"

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
def test_inductor_reduce_scatter_coalesced(self):

@ -1422,7 +1447,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar0, y, other, ar1

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())

@ -1449,6 +1474,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = func(inputs, **self.get_world_trs())
assert same(out, correct), f"{out} va {correct}"

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_reorder_peak_memory(self):
"""

@ -1470,7 +1496,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
other = torch.ones_like(inp) + 22
return ar0, y, other, ar1

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

# get stats directly from the internal helper without affecting the real pass's signature
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None

@ -1642,10 +1668,10 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):

return y, ag_0_out, ag_1_out

x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
x = torch.ones(4, 384, device=self.device, dtype=torch.float32)
w = torch.ones(384, 512, device=self.device, dtype=torch.float32)
ag_0 = torch.ones(384, 512, device=self.device, dtype=torch.float32)
ag_1 = torch.ones(384, 512, device=self.device, dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]

with torch._inductor.config.patch(

@ -1810,12 +1836,12 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
rs_3_out,
)

x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32)
ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32)
x = torch.ones(4, 384, device=self.device, dtype=torch.float32)
w = torch.ones(384, 512, device=self.device, dtype=torch.float32)
ag_0 = torch.ones(1024, 512, device=self.device, dtype=torch.float32)
ag_1 = torch.ones(512, 1024, device=self.device, dtype=torch.float32)
ag_2 = torch.ones(1024, 512, device=self.device, dtype=torch.float32)
ag_3 = torch.ones(512, 1024, device=self.device, dtype=torch.float32)
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]

# get stats directly from the internal helper without affecting the real pass's signature

@ -1917,6 +1943,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
node_stat1 = next(it)
self.assertTrue("collective ordering" in node_stat1.limiting_factor)

@skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_reorder_respects_wait_dep(self):
"""

@ -1939,7 +1966,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
# ensure other is not incorrectly aliasing ar's buffer
return ag_1_wait

inputs = torch.ones(4, 4, device="cuda")
inputs = torch.ones(4, 4, device=self.device)

# get stats directly from the internal helper without affecting the real pass's signature
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None

@ -1988,7 +2015,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(stats.moves, 0)


@requires_nccl()
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestSyncDecisionCrossRanks(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()

@ -2004,16 +2031,21 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):

@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")
device_type = torch.accelerator.current_accelerator().type
return torch.device(f"{device_type}:{self.rank}")

def _init_process_group(self) -> None:
torch._inductor.config.triton.store_cubin = True
torch._inductor.config.debug = True

torch.cuda.set_device(self.device)
torch.get_device_module(self.device).set_device(self.device)
store = torch.distributed.FileStore(self.file_name, self.world_size)
backend = c10d.get_default_backend_for_device(
torch.accelerator.current_accelerator().type
)

torch.distributed.init_process_group(
backend="nccl",
backend=backend,
world_size=self.world_size,
rank=self.rank,
store=store,
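The `_init_process_group` hunk above replaces the hardcoded `backend="nccl"` with a backend derived from the device type. A sketch of that initialization for a one-rank group; it assumes `get_default_backend_for_device` is exposed on `torch.distributed` (the diff reaches it via `c10d`), and the store path is a throwaway temp file.

```python
import tempfile
import torch
import torch.distributed as dist

acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"
backend = dist.get_default_backend_for_device(device_type)  # e.g. nccl, xccl, or gloo

store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
dist.init_process_group(backend=backend, world_size=1, rank=0, store=store)
if device_type != "cpu":
    torch.get_device_module(device_type).set_device(0)
dist.destroy_process_group()
```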

@ -54,6 +54,8 @@ DEFAULT_HOSTNAME = "localhost"

torch.backends.cuda.matmul.allow_tf32 = False

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


def gpus_for_rank(world_size):
"""Multigpu tests are designed to simulate the multi nodes with multi

@ -61,8 +63,8 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
visible_devices = list(range(torch.accelerator.device_count()))
gpus_per_process = torch.accelerator.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(

@ -1174,8 +1176,8 @@ class TestClientProtocol(TestCase):


if __name__ == "__main__":
assert not torch.cuda._initialized, (
"test_distributed must not have initialized CUDA context on main process"
)

if device_type != "cpu":
assert not torch.get_device_module()._initialized, (
"test_distributed must not have initialized {device_type} context on main process"
)
run_tests()
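A worked example of the per-rank device split that `gpus_for_rank` performs, with `device_count` standing in for `torch.accelerator.device_count()`: four visible devices and a world size of two give rank 0 devices [0, 1] and rank 1 devices [2, 3]. `split_devices` is an illustrative name, not part of the diff.

```python
def split_devices(device_count: int, world_size: int) -> list[list[int]]:
    # Evenly divide the visible accelerator devices, one contiguous slice per rank.
    per_rank = device_count // world_size
    return [
        list(range(rank * per_rank, (rank + 1) * per_rank))
        for rank in range(world_size)
    ]

assert split_devices(4, 2) == [[0, 1], [2, 3]]
```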

@ -115,10 +115,13 @@ from torch.testing._internal.common_utils import (
set_default_dtype,
set_rng_seed,
skipIfTorchDynamo,
TEST_XPU,
TestCase,
)


device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

@ -1788,18 +1791,21 @@ class TestDistributions(DistributionsTestCase):
).logpmf(sample)
self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
def test_zero_excluded_binomial(self):
vals = Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.9).cuda()
total_count=torch.tensor(1.0).to(device_type),
probs=torch.tensor(0.9).to(device_type),
).sample(torch.Size((100000000,)))
self.assertTrue((vals >= 0).all())
vals = Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.1).cuda()
total_count=torch.tensor(1.0).to(device_type),
probs=torch.tensor(0.1).to(device_type),
).sample(torch.Size((100000000,)))
self.assertTrue((vals < 2).all())
vals = Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.5).cuda()
total_count=torch.tensor(1.0).to(device_type),
probs=torch.tensor(0.5).to(device_type),
).sample(torch.Size((10000,)))
# vals should be roughly half zeroes, half ones
assert (vals == 0.0).sum() > 4000

@ -2050,15 +2056,15 @@ class TestDistributions(DistributionsTestCase):
)
torch.set_default_dtype(saved_dtype)

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_gpu_sample(self):
set_rng_seed(1)
for rate in [0.12, 0.9, 4.0]:
self._check_sampler_discrete(
Poisson(torch.tensor([rate]).cuda()),
Poisson(torch.tensor([rate]).to(device_type)),
scipy.stats.poisson(rate),
f"Poisson(lambda={rate}, cuda)",
f"Poisson(lambda={rate}, {device_type})",
failure_rate=1e-3,
)

@ -3490,13 +3496,13 @@ class TestDistributions(DistributionsTestCase):

self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_gpu_shape(self):
alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
beta = torch.randn(2, 3).cuda().exp().requires_grad_()
alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
beta_1d = torch.randn(1).cuda().exp().requires_grad_()
alpha = torch.randn(2, 3).to(device_type).exp().requires_grad_()
beta = torch.randn(2, 3).to(device_type).exp().requires_grad_()
alpha_1d = torch.randn(1).to(device_type).exp().requires_grad_()
beta_1d = torch.randn(1).to(device_type).exp().requires_grad_()
self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))

@ -3527,7 +3533,10 @@ class TestDistributions(DistributionsTestCase):
def test_gamma_gpu_sample(self):
set_rng_seed(0)
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
a, b = (
torch.tensor([alpha]).to(device_type),
torch.tensor([beta]).to(device_type),
)
self._check_sampler_sampler(
Gamma(a, b),
scipy.stats.gamma(alpha, scale=1.0 / beta),

@ -3973,11 +3982,11 @@ class TestDistributions(DistributionsTestCase):
self.assertEqual(frac_zeros, 0.5, atol=0.05, rtol=0)
self.assertEqual(frac_ones, 0.5, atol=0.05, rtol=0)

@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
def test_beta_underflow_gpu(self):
set_rng_seed(1)
num_samples = 50000
conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
conc = torch.tensor(1e-2, dtype=torch.float64).to(device_type)
beta_samples = Beta(conc, conc).sample([num_samples])
self.assertEqual((beta_samples == 0).sum(), 0)
self.assertEqual((beta_samples == 1).sum(), 0)
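The distributions hunks above follow one pattern: gate GPU-only tests on either `TEST_CUDA` or `TEST_XPU` and move tensors with `.to(device_type)` instead of `.cuda()`. A hedged, self-contained sketch of that pattern; the test class, sample count, and assertion are illustrative, not taken from the diff.

```python
import unittest
import torch
from torch.distributions import Binomial
from torch.testing._internal.common_utils import TEST_CUDA, TEST_XPU, TestCase, run_tests

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

class ExampleGpuSamplingTest(TestCase):
    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA and XPU not found")
    def test_binomial_on_accelerator(self):
        dist = Binomial(
            total_count=torch.tensor(1.0).to(device_type),
            probs=torch.tensor(0.5).to(device_type),
        )
        samples = dist.sample(torch.Size((1000,)))
        # Bernoulli-like draws should only ever be 0 or 1.
        self.assertTrue(((samples == 0) | (samples == 1)).all())

if __name__ == "__main__":
    run_tests()
```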

@ -258,8 +258,6 @@ class TestCommAnalysis(TestCase):
finally:
dist.destroy_process_group()

# lack of profiler on XPU
@expectedFailureXPU
def test_legacy_all_reduce(self):
def fn(x):
r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)

@ -268,8 +266,6 @@ class TestCommAnalysis(TestCase):
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_legacy_all_reduce_coalesced(self):
def fn(x):
rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)

@ -278,8 +274,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_legacy_all_gather_into_tensor_coalesced(self):
def fn(x):
rs = c10d.all_gather_into_tensor_coalesced(

@ -293,8 +287,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_all_reduce(self):
def fn(x):
r = _c10d.all_reduce(x, "sum", "0")

@ -303,8 +295,6 @@ class TestCommAnalysis(TestCase):
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_all_reduce_coalesced(self):
def fn(x):
rs = _c10d.all_reduce_coalesced(x, "sum", "0")

@ -313,8 +303,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_all_gather_into_tensor(self):
def fn(x):
rs = _c10d.all_gather_into_tensor(

@ -327,8 +315,6 @@ class TestCommAnalysis(TestCase):
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_all_gather_into_tensor_coalesced(self):
def fn(x):
rs = _c10d.all_gather_into_tensor_coalesced(

@ -341,8 +327,6 @@ class TestCommAnalysis(TestCase):
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_reduce_scatter_tensor(self):
def fn(x):
rs = _c10d.reduce_scatter_tensor(

@ -356,8 +340,6 @@ class TestCommAnalysis(TestCase):
inp = T(self.WORLD_SIZE, 10)
self._verify_runtime_estimation(fn, (inp,))

# lack of profiler on XPU
@expectedFailureXPU
def test_reduce_scatter_tensor_coalesced(self):
def fn(x):
rs = _c10d.reduce_scatter_tensor_coalesced(
@ -337,10 +337,12 @@ class Backend(str):  # noqa: SLOT000
# assume default devices "cpu" and "cuda", but warn
warnings.warn(
f"Device capability of {name} unspecified, assuming `cpu` and "
"`cuda`. Please specify it via the `devices` argument of "
"`cuda` or `xpu`. Please specify it via the `devices` argument of "
"`register_backend`."
)
Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
Backend.backend_capability[name.lower()] = (
["cpu", "cuda", "xpu"] if torch.xpu.is_available() else ["cpu", "cuda"]
)
elif isinstance(devices, str):
# Single device string specified. Simply convert to list.
Backend.backend_capability[name.lower()] = [devices]
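A sketch of how a backend avoids the default-capability warning above: pass `devices` explicitly when calling `register_backend`. The backend name and constructor below are hypothetical placeholders, not part of the diff.

```python
import torch.distributed as dist

def _create_dummy_pg(common_opts, backend_opts):
    # Placeholder constructor for a hypothetical out-of-tree backend.
    raise NotImplementedError("placeholder backend constructor")

dist.Backend.register_backend(
    "dummy",
    _create_dummy_pg,
    extended_api=True,
    devices=["cpu", "xpu"],  # explicit, so no "assuming cpu and cuda/xpu" warning
)
```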

@ -339,26 +339,17 @@ def requires_gloo():


def requires_nccl_version(version, msg):
if TEST_CUDA:
if not c10d.is_nccl_available():
return skip_but_pass_in_sandcastle(
"c10d was not compiled with the NCCL backend",
)
else:
return skip_but_pass_in_sandcastle_if(
torch.cuda.nccl.version() < version,
f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
)
if not TEST_CUDA:
return lambda f: f
if not c10d.is_nccl_available():
return skip_but_pass_in_sandcastle(
"c10d was not compiled with the NCCL backend",
)
else:

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper

return decorator
return skip_but_pass_in_sandcastle_if(
torch.cuda.nccl.version() < version,
f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
)


def requires_nccl():
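Illustrative usage of the refactored decorator above: on a CUDA-less build it is now a plain pass-through (`lambda f: f`), so the decorated test simply runs; with CUDA it skips unless NCCL is new enough. The version threshold and test name here are made up.

```python
from torch.testing._internal.common_distributed import requires_nccl, requires_nccl_version

@requires_nccl()
@requires_nccl_version((2, 10), "needs NCCL >= 2.10 (illustrative threshold)")
def test_allreduce_with_new_nccl():
    ...
```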

@ -28,5 +28,5 @@ def _create_fake_pg(common_opts, backend_opts):


dist.Backend.register_backend(
"fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu"]
"fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu", "xpu"]
)
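A sketch of how the `fake` backend registered above is commonly used in tests: it lets collective calls complete on a single process without real communication. The `FakeStore` import path is an assumption based on `torch.testing._internal`, and the tensor device choice is illustrative.

```python
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
dist.init_process_group("fake", store=store, rank=0, world_size=2)
t = torch.ones(4, device="xpu" if torch.xpu.is_available() else "cpu")
dist.all_reduce(t)  # completes immediately; no real communication happens
dist.destroy_process_group()
```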