fix FSDP2 test case failure on XPU (#3771)

* fix FSDP2 test case failure on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-09-12 06:05:05 -07:00
committed by GitHub
parent 8b493524c8
commit 45959d7b96

View File

@ -568,25 +568,18 @@ class Accelerator:
and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
):
self.native_amp = True
if self.device.type not in (
"xpu",
"cuda",
"npu",
"xla",
"mlu",
"musa",
"hpu",
"sdaa",
"mps",
) or is_torch_xla_available(check_is_tpu=True):
raise ValueError(f"fp16 mixed precision requires a GPU or MPS device (not {self.device.type!r}).")
supported_device = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")
if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
raise ValueError(
f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
)
if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
raise ValueError("fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0")
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
# FSDP2 doesn't use ShardedGradScaler; rather than modifying `get_grad_scaler`, a simple dedicated utility creates its scaler
if self.is_fsdp2:
self.scaler = get_fsdp2_grad_scaler(**kwargs)
self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)
else:
self.scaler = get_grad_scaler(self.distributed_type, **kwargs)