From 814ba34fa61f4d95affa6ef9f7207cd3b45cbb75 Mon Sep 17 00:00:00 2001
From: "Zeng, Xiangdong"
Date: Mon, 15 Sep 2025 06:24:55 +0000
Subject: [PATCH] [2/N] Port 5 _composable distributed tests to Intel GPU
 (#159241)

As part of https://github.com/pytorch/pytorch/issues/114850, we are porting
distributed tests to Intel GPU. This is the second PR for the _composable
cases; the first is https://github.com/pytorch/pytorch/pull/159118.

We enable Intel GPU with the following methods, keeping the original code
style as much as possible:
- Use "torch.accelerator.current_accelerator()" to determine the accelerator
  backend
- Enable XPU for some test paths
- Skip test cases that Intel GPU does not support
- Add "cpu:gloo,xpu:xccl" as a distributed backend option

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159241
Approved by: https://github.com/guangyey, https://github.com/d4l3k
---
 .../_composable/test_checkpoint.py             | 21 ++--
 .../test_2d_composability.py                   | 97 +++++++++++--------
 .../test_pp_composability.py                   | 35 ++++---
 test/distributed/_composable/test_replicate.py | 39 +++++---
 .../test_replicate_with_compiler.py            |  2 +
 .../distributed/_tensor/common_dtensor.py      |  1 +
 6 files changed, 120 insertions(+), 75 deletions(-)

diff --git a/test/distributed/_composable/test_checkpoint.py b/test/distributed/_composable/test_checkpoint.py
index f30f8c34f613..7834328f1e35 100644
--- a/test/distributed/_composable/test_checkpoint.py
+++ b/test/distributed/_composable/test_checkpoint.py
@@ -10,10 +10,13 @@ import torch
 import torch.nn as nn
 from torch.distributed._composable import checkpoint
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
 from torch.utils.checkpoint import CheckpointError


+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 class MemoryDelta(ContextDecorator):
     def __init__(self, device: torch.device):
         self.device: torch.device = device
@@ -22,16 +25,16 @@ class MemoryDelta(ContextDecorator):

     def __enter__(self):
         self.active_memory_enter = (
-            torch.cuda.memory_stats()["active_bytes.all.current"]
-            if self.device.type == "cuda"
+            torch.accelerator.memory_stats()["active_bytes.all.current"]
+            if self.device.type == "cuda" or self.device.type == "xpu"
             else 0
         )
         return self

     def __exit__(self, *exc):
         self.active_memory_exit = (
-            torch.cuda.memory_stats()["active_bytes.all.current"]
-            if self.device.type == "cuda"
+            torch.accelerator.memory_stats()["active_bytes.all.current"]
+            if self.device.type == "cuda" or self.device.type == "xpu"
             else 0
         )

@@ -126,7 +129,7 @@ class TestCheckpoint(TestCase):
             loss2 = net2(x2).sum()
             loss2.backward()

-        if x.is_cuda:
+        if x.is_cuda or x.is_xpu:
             self.assertTrue(mem2.delta() < mem1.delta())

         for p1, p2 in zip(net1.parameters(), net2.parameters()):
@@ -137,10 +140,10 @@ class TestCheckpoint(TestCase):
         net = ToyModel()
         self._test_tensor_only(net, x)

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu")
     def test_tensor_only_gpu(self):
-        x = torch.randn(20, 100, device="cuda:0")
-        net = ToyModel().to("cuda:0")
+        x = torch.randn(20, 100, device=f"{device_type}:0")
+        net = ToyModel().to(f"{device_type}:0")
         self._test_tensor_only(net, x)

     def test_random_cpu(self):
diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py
index bcaf06ea947a..3fd84fbe9e73 100644
--- a/test/distributed/_composable/test_composability/test_2d_composability.py
+++ b/test/distributed/_composable/test_composability/test_2d_composability.py
@@ -47,6 +47,8 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
+    TEST_XPU,
+    xfailIf,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -58,6 +60,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 class SimpleModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -73,7 +78,7 @@ class SimpleModel(nn.Module):
         return x

     def get_input(self):
-        return torch.rand(4, 5, device="cuda")
+        return torch.rand(4, 5, device=device_type)


 class SimpleModelUneven(nn.Module):
@@ -94,7 +99,7 @@ class SimpleModelUneven(nn.Module):
         return x

     def get_input(self):
-        return torch.rand(4, 5, device="cuda")
+        return torch.rand(4, 5, device=device_type)


 class TestFullyShard2DTraining(FSDPTest):
@@ -105,13 +110,15 @@ class TestFullyShard2DTraining(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.accelerator.device_count())

     def init_global_mesh(self) -> DeviceMesh:
         # Prefer to test with >=4 GPUs, but for 2 GPUs, use 2-way TP
         dp_size = 2 if self.world_size > 2 else 1
         return init_device_mesh(
-            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+            device_type,
+            (dp_size, self.world_size // dp_size),
+            mesh_dim_names=("dp", "tp"),
         )

     @skip_if_lt_x_gpu(2)
@@ -138,7 +145,7 @@ class TestFullyShard2DTraining(FSDPTest):
         torch.manual_seed(42)
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
@@ -150,9 +157,8 @@ class TestFullyShard2DTraining(FSDPTest):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

         torch.manual_seed(42 + dp_pg.rank() + 1)
-        device = torch.device("cuda")
         for iter_idx in range(10):
-            inp = torch.randn((8, mlp_dim), device=device)
+            inp = torch.randn((8, mlp_dim), device=device_type)
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
@@ -162,6 +168,7 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(losses[0], losses[1])

     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1881
     def test_train_parity_2d_transformer(self):
         self.run_subtests(
             {"use_shard_placement_fn": [False, True]},
@@ -172,12 +179,12 @@ class TestFullyShard2DTraining(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(n_layers=3, dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)

         dp_size, tp_size = self.world_size // 2, 2
         global_mesh = init_device_mesh(
-            "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+            device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
         )
         model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)

@@ -205,7 +212,7 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(full_param, ref_param)

         torch.manual_seed(42 + global_mesh.get_local_rank("dp"))
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for iter_idx in range(5):
             ref_loss = ref_model(inp).sum()
             loss = model(inp).sum()
@@ -242,15 +249,16 @@ class TestFullyShard2DTraining(FSDPTest):
             self.assertEqual(full_param, ref_param)

     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/pytorch/pytorch/issues/156782
     def test_tp_with_fsdp_offloading(self):
         global_mesh = init_device_mesh(
-            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+            device_type, (1, self.world_size), mesh_dim_names=("dp", "tp")
         )
         dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
         torch.manual_seed(42)
         mlp_dim = 16
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         # Parallelize with N-way TP and 1-way FSDP
         model.parallelize(
@@ -268,7 +276,7 @@ class TestFullyShard2DTraining(FSDPTest):
         # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
         # called, but they will just be no-ops without issuing any kernels.
         # We prefer to keep the no-op check at the c10d level, not in FSDP.
-        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        inp = torch.randn((4, mlp_dim), device=device_type)  # same on all ranks
         for _ in range(10):
             ref_optim.zero_grad()
             optim.zero_grad()
@@ -297,6 +305,7 @@ class TestFullyShard2DTraining(FSDPTest):
             ref_optim.step()

     @skip_if_lt_x_gpu(2)
+    @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1881
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
         """
@@ -352,7 +361,7 @@ class TestFullyShard2DTraining(FSDPTest):
         )

         torch.manual_seed(42 + global_mesh["dp"].get_local_rank() + 1)
-        inp = torch.randint(0, model_args.vocab_size, (3, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (3, 16), device=device_type)
         loss_no_cp1 = train_step(model_no_cp, optim_no_cp, inp)
         loss_no_cp2 = train_step(model_no_cp, optim_no_cp, inp)
@@ -410,14 +419,14 @@ class TestFullyShard2DStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,cuda:nccl"
+        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"

     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_fully_shard_tp_2d_set_full_state_dict(self):
-        dummy_model = SimpleModel().cuda()
+        dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
-            "cuda",
+            device_type,
             (2, self.world_size // 2),
             mesh_dim_names=("dp", "tp"),
         )
@@ -561,7 +570,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
             self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         model = FSDP(
-            SimpleModel().cuda(),
+            SimpleModel().to(device_type),
             device_mesh=mesh_2d["dp"],
         )
         fsdp_state = _get_module_fsdp_state(model)
@@ -573,7 +582,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
         recompute_activation=False,
     ) -> None:
         torch.manual_seed(0)
-        model = SimpleModel().cuda(self.rank)
+        model = SimpleModel().to(f"{device_type}:{self.rank}")
         model = FSDP(model, use_orig_params=use_orig_params)
         optim = torch.optim.Adam(model.parameters(), lr=0.01)
@@ -587,7 +596,9 @@ class TestNew2dParallelTraining(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(SimpleModel().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            SimpleModel().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(
             model_2d,
             device_mesh=dp_mesh,
@@ -615,7 +626,7 @@ class TestNew2dParallelTraining(DTensorTestBase):
             # Ensure all input across TP ranks are same.
             # TODO: add a get_group_rank() to DeviceMesh.
             torch.manual_seed(i + dist.get_rank(dp_mesh.get_group(mesh_dim=0)))
-            input = torch.rand(4, 5).cuda(self.rank)
+            input = torch.rand(4, 5).to(f"{device_type}:{self.rank}")
             output = model(input)
             output_2d = model_2d(input)
             self.assertEqual(output, output_2d)
@@ -652,7 +663,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,cuda:nccl"
+        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"

     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -669,7 +680,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net3": ColwiseParallel(),
         }
         model_2d = parallelize_module(
-            SimpleModel().cuda(),
+            SimpleModel().to(device_type),
             mesh_2d["tp"],
             parallelize_plan=parallelize_plan,
         )
@@ -679,8 +690,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             isinstance(model_2d_fsdp_state._fsdp_extension, DTensorExtensions)
         )

-        mesh_1d = init_device_mesh("cuda", (self.world_size,))
-        model_1d = FSDP(SimpleModel().cuda(), device_mesh=mesh_1d, use_orig_params=True)
+        mesh_1d = init_device_mesh(device_type, (self.world_size,))
+        model_1d = FSDP(
+            SimpleModel().to(device_type), device_mesh=mesh_1d, use_orig_params=True
+        )
         model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
         self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
@@ -692,7 +705,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         # Create a model without wrapper
         torch.manual_seed(0)
-        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
         no_wrap_state_dict = no_wrap_model.state_dict()

         # Create a model and sharded it with 2D FSDP + TP
@@ -706,7 +719,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            simple_model().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)

         FSDP.set_state_dict_type(
@@ -754,7 +769,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net1": ColwiseParallel(),
             "net2": RowwiseParallel(),
         }
-        model_2d = parallelize_module(simple_model().cuda(), tp_mesh, parallelize_plan)
+        model_2d = parallelize_module(
+            simple_model().to(device_type), tp_mesh, parallelize_plan
+        )
         model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True)
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
@@ -768,7 +785,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         ref_state_dict = deepcopy(model_2d.state_dict())

         # Update the parameters so model.state_dict() will be different from ref_dtensor_sd.
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()

         # Load ref_state_dict back.
@@ -799,9 +816,11 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         # Create a model without wrapper
         torch.manual_seed(0)
-        no_wrap_model = simple_model().cuda(self.rank)
+        no_wrap_model = simple_model().to(f"{device_type}:{self.rank}")
         no_wrap_optim = torch.optim.Adam(no_wrap_model.parameters(), lr=0.01)
-        no_wrap_model(no_wrap_model.get_input().cuda(self.rank)).sum().backward()
+        no_wrap_model(
+            no_wrap_model.get_input().to(f"{device_type}:{self.rank}")
+        ).sum().backward()
         no_wrap_optim.step()
         no_wrap_osd = get_optimizer_state_dict(no_wrap_model, optimizers=no_wrap_optim)
@@ -815,7 +834,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             "net2": RowwiseParallel(),
         }
         model_2d = parallelize_module(
-            simple_model().cuda(), mesh_2d["tp"], parallelize_plan
+            simple_model().to(device_type), mesh_2d["tp"], parallelize_plan
         )
         model_2d = FSDP(model_2d, device_mesh=mesh_2d["dp"], use_orig_params=True)
         FSDP.set_state_dict_type(
@@ -823,7 +842,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
             StateDictType.SHARDED_STATE_DICT,
         )
         optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01)
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()
         optim_2d_osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d)
         ref_optim_2d_osd = deepcopy(optim_2d_osd)
@@ -842,7 +861,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
                 # compare with no_wrap state.
                 if isinstance(dist_state, DTensor):
                     dist_state = (
-                        dist_state.cuda()
+                        dist_state.to(device_type)
                         .redistribute(placements=(Replicate(), Replicate()))
                         .to_local()
                     )
@@ -850,7 +869,7 @@ class TestNew2dParallelStateDict(DTensorTestBase):
                 self.assertTrue(torch.allclose(state, dist_state))

         # Update the parameters 2d optim states will be different from ref_optim_state_dict.
-        model_2d(model_2d.get_input().cuda(self.rank)).sum().backward()
+        model_2d(model_2d.get_input().to(f"{device_type}:{self.rank}")).sum().backward()
         optim_2d.step()

         set_optimizer_state_dict(
@@ -892,8 +911,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         5) dcp.load the state dict from storage
         6) load the state dict into the 2D model
         """
-        dummy_model = SimpleModel().cuda()
-        mesh_1d = init_device_mesh("cuda", (self.world_size,))
+        dummy_model = SimpleModel().to(device_type)
+        mesh_1d = init_device_mesh(device_type, (self.world_size,))
         model = FSDP(dummy_model, device_mesh=mesh_1d)
         optim = torch.optim.Adam(model.parameters(), lr=0.01)
         model(model.get_input()).sum().backward()
@@ -911,9 +930,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
         dcp.save(state_dict, checkpoint_id=self.temp_dir)

         # initialize 2d model
-        dummy_model = SimpleModel().cuda()
+        dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
-            "cuda",
+            device_type,
             (2, self.world_size // 2),
             mesh_dim_names=("dp", "tp"),
         )
diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py
index 8f0b938da41b..e4daa81c456c 100644
--- a/test/distributed/_composable/test_composability/test_pp_composability.py
+++ b/test/distributed/_composable/test_composability/test_pp_composability.py
@@ -30,7 +30,7 @@ from torch.distributed.tensor.parallel import (
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
-    requires_nccl,
+    requires_accelerator_dist_backend,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import (
@@ -38,6 +38,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
+    TEST_XPU,
 )
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

@@ -46,6 +47,10 @@ if TYPE_CHECKING:
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE


+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+backend = torch.distributed.get_default_backend_for_device(device_type)
+
+
 # MLP Layer
 class MLPModule(torch.nn.Module):
     def __init__(self, d_hid: int):
@@ -79,7 +84,7 @@ class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
         # Testing with NCCL backend
-        return "nccl"
+        return backend

     def setUp(self):
         super().setUp()
@@ -100,9 +105,11 @@ class ComposabilityTest(MultiProcessTestCase):
     def device(self):
         return self.rank

-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(4)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 4+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
+    )
     def test_pp_and_dcp(self):
         """
         Test that pipeline parallelism and distributed checkpointing can be used together and
@@ -143,11 +150,11 @@ class ComposabilityTest(MultiProcessTestCase):
                 x = layer(x)
                 return x

-        device = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
+        device = torch.device(device_type, self.device)
+        torch.accelerator.set_device_index(self.device)
         store = torch.distributed.FileStore(self.file_name, self.world_size)
         torch.distributed.init_process_group(
-            backend="nccl",
+            backend=backend,
             store=store,
             rank=self.rank,
             world_size=self.world_size,
@@ -192,9 +199,11 @@ class ComposabilityTest(MultiProcessTestCase):

         _dcp_test(self)

-    @requires_nccl()
+    @requires_accelerator_dist_backend(["nccl", "xccl"])
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
+    @skip_but_pass_in_sandcastle_if(
+        not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
+    )
     @parametrize(
         "ScheduleClass",
         [
@@ -213,11 +222,11 @@ class ComposabilityTest(MultiProcessTestCase):
         ],
     )
     def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
-        _device_raii = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
+        _device_raii = torch.device(device_type, self.device)
+        torch.accelerator.set_device_index(self.device)
         store = torch.distributed.FileStore(self.file_name, self.world_size)
         torch.distributed.init_process_group(
-            backend="nccl",
+            backend=backend,
             store=store,
             rank=self.rank,
             world_size=self.world_size,
@@ -228,7 +237,7 @@ class ComposabilityTest(MultiProcessTestCase):
         num_microbatches = 8
         dp_size = self.world_size // (tp_size * pp_size)
         device_mesh = init_device_mesh(
-            "cuda",
+            device_type,
             mesh_shape=(dp_size, pp_size, tp_size),
             mesh_dim_names=("dp", "pp", "tp"),
         )
diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py
index a793fe2fed4c..8c1cb3d5df32 100644
--- a/test/distributed/_composable/test_replicate.py
+++ b/test/distributed/_composable/test_replicate.py
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]

 import os
+import unittest
 from copy import deepcopy

 import torch
@@ -14,7 +15,11 @@ from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import run_tests, TEST_XPU
+
+
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+device_module = torch.get_device_module(device_type)


 class Net(nn.Module):
@@ -154,6 +159,7 @@ class ReplicateTest(MultiProcessTestCase):
         self._compare_module(model, replicate_model)

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_move_args_kwargs_to_device(self):
         class MyNet(nn.Module):
             def __init__(self) -> None:
@@ -166,24 +172,25 @@ class ReplicateTest(MultiProcessTestCase):
                 return self.a(inp)

         self._init_pg()
-        torch.cuda.set_device(self.rank)
-        model = MyNet().cuda()
-        replicate(model, device_id=torch.cuda.current_device())
+        torch.accelerator.set_device_index(self.rank)
+        model = MyNet().to(device_type)
+        replicate(model, device_id=torch.accelerator.current_device_index())
         # CPU input ensures replicate can move arg and kwargs to device.
         a, b = torch.randn(2, 2), torch.randn(2, 2)
         model(a, kwarg=b).sum().backward()

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_ignore_module(self):
         self._init_pg()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         # Seed ensures diff input and thus different local grads across ranks.
         torch.manual_seed(self.rank)
-        torch.cuda.manual_seed(self.rank)
-        model = Net().cuda()
+        device_module.manual_seed(self.rank)
+        model = Net().to(device_type)
         replicate(model, ignored_modules=[model.fc1])
         # CPU input ensures that replicate can move input to GPU as DDP does.
-        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
+        inp = torch.randn(5, 2, device=device_type) * (self.rank + 1)
         out = model(inp) * 10
         out.sum().backward()
         # FC1 grads should not be synchronized, FC2 and 3 should be.
@@ -221,10 +228,11 @@ class ReplicateTest(MultiProcessTestCase):
         self._compare_module(model, replicate_model)

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_device_id(self):
         self._init_pg()
         model = Net()
-        model_cuda = deepcopy(model).cuda()
+        model_cuda = deepcopy(model).to(device_type)
         model_cuda2 = deepcopy(model_cuda)
         replicate(model, device_id=torch.device("cpu"))
         # DDP instance is attached in first pre forward
@@ -233,13 +241,15 @@ class ReplicateTest(MultiProcessTestCase):
         # Should be None for CPU training
         self.assertEqual(None, replicate_ddp_weakref.device_ids)

-        replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
+        replicate(
+            model_cuda, device_id=torch.device(torch.accelerator.current_device_index())
+        )
         # DDP instance is attached in first pre forward
         model_cuda(torch.randn(2, 2))
         replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
         self.assertEqual([0], replicate_ddp_weakref.device_ids)
         # Pass in int as device_id
-        replicate(model_cuda2, device_id=int(torch.cuda.current_device()))
+        replicate(model_cuda2, device_id=int(torch.accelerator.current_device_index()))
         # DDP instance is attached in first pre forward
         model_cuda2(torch.randn(2, 2))
         replicate_ddp_weakref = replicate.state(model_cuda2)._ddp_weakref()
@@ -256,6 +266,7 @@ class ReplicateTest(MultiProcessTestCase):

 class ReplicateFullyShardInit(ReplicateTest):
     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend")
     def test_replicate_fully_shard_init(self):
         class ToyModel(nn.Module):
             def __init__(self, dim: int):
@@ -273,14 +284,14 @@ class ReplicateFullyShardInit(ReplicateTest):
                 return y

         self._init_pg()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         dim = 3
         bz = 2
-        model = ToyModel(dim).cuda()
+        model = ToyModel(dim).to(device_type)
         for linear in model.linears:
             fully_shard(linear)
         fully_shard(model.linears)
-        replicate(model, device_id=torch.cuda.current_device())
+        replicate(model, device_id=torch.accelerator.current_device_index())
         for linear in model.linears:
             self.assertTrue(isinstance(linear.weight, DTensor))
         inp = torch.rand(bz, dim)
diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py
index 11eba3e5bb0c..291b3a426822 100644
--- a/test/distributed/_composable/test_replicate_with_compiler.py
+++ b/test/distributed/_composable/test_replicate_with_compiler.py
@@ -98,6 +98,8 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.create_pg(device)
         torch._dynamo.config.optimize_ddp = "python_reducer"
         torch.manual_seed(123)
+        if device_type == "xpu":
+            torch.use_deterministic_algorithms(True, warn_only=True)
         model = Net(checkpoint=checkpoint).to(device)
         input = torch.randn([1, DIM], device=device)
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
index e25e08fbf509..604ba9714f21 100644
--- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -388,6 +388,7 @@ class DTensorTestBase(MultiProcessTestCase):
             "hccl",
             "xccl",
             "fake",
+            "cpu:gloo,xpu:xccl",
         ]:
             raise RuntimeError(f"Backend {backend} not supported!")
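Note: the following standalone sketch is not part of the patch; it only
illustrates the device/backend-selection pattern the ported tests adopt, and
assumes a PyTorch build where the torch.accelerator API is available.

    import torch
    import torch.distributed as dist

    # Pick the active accelerator type ("cuda", "xpu", ...) or fall back to
    # "cpu", as the ported test modules do at import time.
    device_type = (
        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    )

    # Map the device type to its default distributed backend,
    # e.g. "nccl" for CUDA and "xccl" for XPU.
    backend = dist.get_default_backend_for_device(device_type)

    # Test code then allocates tensors and models on the detected device
    # instead of hard-coding "cuda".
    x = torch.randn(4, 5, device=device_type)
    print(device_type, backend, x.device)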