Enable FSDP tests on XPU device (#147518)

**Motivation:**

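Enable the FSDP tests to run on XPU devices: add "xpu" to each test file's device list, pass `allow_xpu=True` to `instantiate_device_type_tests`, and map XPU to the "xccl" backend in the shared distributed test utilities.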

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147518
Approved by: https://github.com/weifengpy
This commit is contained in:
lzhang2
2025-03-04 23:49:33 +00:00
committed by PyTorch MergeBot
parent c98c3af421
commit 84b58bd63e
20 changed files with 104 additions and 54 deletions

View File

@ -113,7 +113,7 @@ class TestApply(FSDPTest):
transformer.apply(self._init_linear_weights)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestApply, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(TestApply, globals(), only_for=devices, allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -334,7 +334,9 @@ class TestFSDPCheckpointSubmodule(FSDPTest):
self.assertTrue(p1.grad.allclose(p2.grad))
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPCheckpointSubmodule, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestFSDPCheckpointSubmodule, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -338,7 +338,9 @@ class TestClipGradNorm(FSDPTest):
self.assertEqual(total_norm, torch.tensor(0.0, device=self.device_type))
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestClipGradNorm, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestClipGradNorm, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -382,8 +382,12 @@ class TestExplicitUnshard(FSDPTest):
model.module.mlps._wait_unshard_streams_on_current_stream()
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestCommunication, globals(), only_for=devices)
instantiate_device_type_tests(TestExplicitUnshard, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestCommunication, globals(), only_for=devices, allow_xpu=True
)
instantiate_device_type_tests(
TestExplicitUnshard, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -512,11 +512,15 @@ class TestAutograd(FSDPTest):
FlatParamHandle._use_unsharded_views = orig_use_unsharded_views
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestHooks, globals(), only_for=devices)
instantiate_device_type_tests(TestParityWithDDP, globals(), only_for=devices)
instantiate_device_type_tests(TestNoGrad, globals(), only_for=devices)
instantiate_device_type_tests(TestParamInit, globals(), only_for=devices)
instantiate_device_type_tests(TestAutograd, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(TestHooks, globals(), only_for=devices, allow_xpu=True)
instantiate_device_type_tests(
TestParityWithDDP, globals(), only_for=devices, allow_xpu=True
)
instantiate_device_type_tests(TestNoGrad, globals(), only_for=devices, allow_xpu=True)
instantiate_device_type_tests(
TestParamInit, globals(), only_for=devices, allow_xpu=True
)
instantiate_device_type_tests(TestAutograd, globals(), only_for=devices, allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -285,9 +285,9 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
FSDP.optim_state_dict(model, optim)
devices = ("cuda", "hpu")
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestFSDPWithDeviceMeshAndDTensor, globals(), only_for=devices
TestFSDPWithDeviceMeshAndDTensor, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -211,7 +211,9 @@ class TestFSDPExecOrder(FSDPTest):
# an `AssertionError` will be raised above for both sharding strategies
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPExecOrder, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestFSDPExecOrder, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -404,7 +404,9 @@ class TestFSDPFineTune(FSDPTest):
self.assertEqual(param, ref_param)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPFineTune, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestFSDPFineTune, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -113,7 +113,9 @@ class TestSymbolicTracing(TestCase):
self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestSymbolicTracing, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestSymbolicTracing, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -70,7 +70,7 @@ class TestInput(FSDPTest):
optim.zero_grad()
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestInput, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(TestInput, globals(), only_for=devices, allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -73,7 +73,9 @@ class TestMultiForward(FSDPTest):
self.assertEqual(ddp_state, fsdp_state)
devices = ("cpu", "hpu")
instantiate_device_type_tests(TestMultiForward, globals(), only_for=devices)
devices = ("cpu", "hpu", "xpu")
instantiate_device_type_tests(
TestMultiForward, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -61,7 +61,9 @@ class TestMultipleWrapping(FSDPTest):
self.assertEqual(output, rewrapped_output)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestMultipleWrapping, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestMultipleWrapping, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -151,7 +151,7 @@ class TestPureFP16(FSDPTest):
self.assertEqual(param.grad.dtype, torch.float16)
devices = ("cuda", "hpu")
devices = ("cuda", "hpu", "xpu")
# "xpu" is listed in `devices`, so pass allow_xpu=True like the other FSDP
# test files do; NOTE(review): without it the XPU variants are presumably
# never generated -- confirm against instantiate_device_type_tests.
instantiate_device_type_tests(
    TestPureFP16, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -61,7 +61,7 @@ class TestTraversal(FSDPTest):
)
devices = ("cuda", "hpu")
devices = ("cuda", "hpu", "xpu")
# "xpu" is listed in `devices`, so pass allow_xpu=True like the other FSDP
# test files do; NOTE(review): without it the XPU variants are presumably
# never generated -- confirm against instantiate_device_type_tests.
instantiate_device_type_tests(
    TestTraversal, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -68,7 +68,9 @@ class TestUnevenParamShard(FSDPTest):
self.assertEqual(ref_weight_out, weight_out)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestUnevenParamShard, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestUnevenParamShard, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -324,9 +324,9 @@ class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
self.assertIsInstance(state["exp_avg_sq"], torch.Tensor)
devices = ("cuda", "hpu")
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(
TestHSDPWithDeviceMeshAndDTensor, globals(), only_for=devices
TestHSDPWithDeviceMeshAndDTensor, globals(), only_for=devices, allow_xpu=True
)
if __name__ == "__main__":
run_tests()

View File

@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import (
subtest,
TEST_HPU,
TEST_WITH_DEV_DBG_ASAN,
TEST_XPU,
TestCase,
)
@ -32,7 +33,12 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
list_device = "hpu" if TEST_HPU else "cuda"
if TEST_HPU:
list_device = "hpu"
elif TEST_XPU:
list_device = "xpu"
else:
list_device = "cuda"
class TestUtils(TestCase):
@ -129,7 +135,7 @@ class TestUtils(TestCase):
self.assertEqual(torch.sum(x), 0)
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestUtils, globals(), only_for=devices)
devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(TestUtils, globals(), only_for=devices, allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -44,6 +44,7 @@ from torch.testing._internal.common_utils import (
TestCase,
run_tests,
TEST_HPU,
TEST_XPU,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
_install_threaded_pg,
@ -105,6 +106,8 @@ class DistTestCases:
backend_feature["plugin"] = set()
if TEST_HPU:
backend_feature["hpu"] = {"hccl"}
if TEST_XPU:
backend_feature["xpu"] = {"xccl"}
def skip_if_no_gpu(func):
@ -120,6 +123,8 @@ def skip_if_no_gpu(func):
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
# Skip (exit with the registered skip code) when the HPU/XPU runtime exposes
# fewer devices than the test's world size. `device_count` is a function and
# must be called: comparing the function object to an int raises TypeError.
if TEST_HPU and torch.hpu.device_count() < world_size:
    sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
# XPU uses the same "multi-gpu-N" skip entry as the other device branches in
# this file; a "multi-xpu-N" entry is not visible here -- TODO confirm.
if TEST_XPU and torch.xpu.device_count() < world_size:
    sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
return func(*args, **kwargs)
@ -199,6 +204,8 @@ def skip_if_lt_x_gpu(x):
return func(*args, **kwargs)
if TEST_HPU and torch.hpu.device_count() >= x:
return func(*args, **kwargs)
if TEST_XPU and torch.xpu.device_count() >= x:
return func(*args, **kwargs)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
return wrapper
@ -510,7 +517,8 @@ def init_multigpu_helper(world_size: int, backend: str):
nGPUs = torch.cuda.device_count()
if TEST_HPU:
nGPUs = torch.hpu.device_count()
if TEST_XPU:
nGPUs = torch.xpu.device_count()
visible_devices = range(nGPUs)
# If rank is less than or equal to number of available GPUs
@ -941,6 +949,8 @@ class DistributedTestBase(MultiProcessTestCase):
return "nccl"
elif "hpu" in device : # intel gaudi
return "hccl"
elif "xpu" in device:
return "xccl"
else :
return "gloo"
@ -953,8 +963,8 @@ class DistributedTestBase(MultiProcessTestCase):
rank=self.rank,
store=store
)
if "nccl" in self.backend(device):
torch.cuda.set_device(self.rank)
if "nccl" in self.backend(device) or "xccl" in self.backend(device):
torch.accelerator.set_device_index(self.rank)
return torch.distributed.distributed_c10d._get_default_group()
def rank_to_device(self, device):
@ -1347,7 +1357,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# Just manually implement the most important part of the dynamo behavior to reset/clear.
if not fake_pg:
torch.cuda.set_device(rank)
torch.accelerator.set_device_index(rank)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '6789'
if init_pg:

View File

@ -59,6 +59,7 @@ from torch.testing._internal.common_utils import (
get_cycles_per_ms,
TEST_CUDA,
TEST_HPU,
TEST_XPU,
)
from torch.utils._triton import has_triton
@ -72,6 +73,10 @@ if TEST_CUDA:
elif TEST_HPU:
DEVICE_TYPE = "hpu:0"
DISTRIBUTED_BACKEND = "hccl"
elif TEST_XPU:
DEVICE_TYPE = "xpu"
DISTRIBUTED_BACKEND = "xccl"
DEVICE_COUNT = torch.xpu.device_count()
else:
DEVICE_TYPE = "cpu"
DISTRIBUTED_BACKEND = "gloo"
@ -647,7 +652,7 @@ class ModuleWithDelay(FSDPTestModel):
def get_loss(self, input, output):
loss = self.module.get_loss(input, output) # type: ignore[operator]
if self.delay_after_loss_ms > 0:
if TEST_HPU:
if TEST_HPU or TEST_XPU:
time.sleep(self.delay_after_loss_ms / 1000)
elif TEST_CUDA:
torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
@ -663,7 +668,7 @@ class ModuleWithDelay(FSDPTestModel):
torch.cuda._sleep(
int(self.delay_before_reduction_ms * get_cycles_per_ms())
)
elif TEST_HPU:
elif TEST_HPU or TEST_XPU:
time.sleep(self.delay_before_reduction_ms / 1000)
return orig_reduce_scatter(*args, **kwargs)
@ -796,7 +801,7 @@ class MixtureOfExperts(NestedWrappedModule):
torch.cuda._sleep(
int(self.delay_before_free_ms * get_cycles_per_ms())
)
elif TEST_HPU:
elif TEST_HPU or TEST_XPU:
time.sleep(self.delay_before_free_ms / 1000)
return orig_reshard(*args, **kwargs)
@ -1209,8 +1214,8 @@ class FSDPTest(MultiProcessTestCase):
device_ids = None
device_id = self.rank % DEVICE_COUNT
if TEST_CUDA:
torch.cuda.set_device(device_id)
if TEST_CUDA or TEST_XPU:
torch.accelerator.set_device_index(device_id)
device_ids = [device_id]
# Execute barrier prior to running test to ensure that every process
@ -1435,7 +1440,7 @@ class FSDPTest(MultiProcessTestCase):
self.assertRaisesRegex(
RuntimeError,
"An FSDP-managed module with parameter CPU offloading enabled "
"has parameters on cuda",
f"has parameters on {DEVICE_TYPE}",
)
if expects_device_error
else nullcontext()

View File

@ -32,6 +32,7 @@ from torch.distributed.tensor.parallel import (
from torch.testing._internal.common_utils import (
TEST_HPU,
TEST_CUDA,
TEST_XPU
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
@ -52,6 +53,10 @@ elif TEST_HPU:
DEVICE_TYPE = "hpu"
PG_BACKEND = "hccl"
DEVICE_COUNT = _get_device_module("hpu").device_count()
elif TEST_XPU:
DEVICE_TYPE = "xpu"
PG_BACKEND = "xccl"
DEVICE_COUNT = _get_device_module("xpu").device_count()
else:
DEVICE_TYPE = "cpu"
PG_BACKEND = "gloo"
@ -59,7 +64,7 @@ else:
NUM_DEVICES = 4
# We use this as a proxy for "multiple GPUs exist"
if TEST_CUDA and DEVICE_COUNT > 1:
if (TEST_CUDA or TEST_XPU) and DEVICE_COUNT > 1:
# when we actually have multiple GPUs, relax the requirement to smaller counts.
NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)
@ -321,7 +326,7 @@ class DTensorTestBase(MultiProcessTestCase):
@property
def backend(self) -> str:
backend = "nccl" if TEST_CUDA else "hccl" if TEST_HPU else "gloo"
backend = dist.get_default_backend_for_device(DEVICE_TYPE)
return backend
def build_device_mesh(self) -> DeviceMesh:
@ -331,13 +336,13 @@ class DTensorTestBase(MultiProcessTestCase):
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl"]:
if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl", "xccl"]:
raise RuntimeError(f"Backend {self.backend} not supported!")
device_id = None
if "nccl" in self.backend:
if "nccl" in self.backend or "xccl" in self.backend:
# set device for nccl pg for collectives
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
# we only need to set device_id for nccl backend with eager init
device_id = torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
# For nccl backend, bind the device to the process if device_id is not None
@ -391,7 +396,7 @@ def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
) -> None:
# if there are enough GPUs we use them, otherwise we fall back to CPU
if not TEST_CUDA or torch.cuda.device_count() < self.world_size:
if not (TEST_CUDA or TEST_XPU) or torch.accelerator.device_count() < self.world_size:
self.device_type = "cpu"
else:
self.device_type = DEVICE_TYPE