fix FSDP2 test case failure on XPU (#3771)
* fix FSDP2 test case failure on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
@@ -568,25 +568,18 @@ class Accelerator:
             and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
         ):
             self.native_amp = True
-            if self.device.type not in (
-                "xpu",
-                "cuda",
-                "npu",
-                "xla",
-                "mlu",
-                "musa",
-                "hpu",
-                "sdaa",
-                "mps",
-            ) or is_torch_xla_available(check_is_tpu=True):
-                raise ValueError(f"fp16 mixed precision requires a GPU or MPS device (not {self.device.type!r}).")
+            supported_device = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")
+            if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
+                raise ValueError(
+                    f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
+                )
             if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
                 raise ValueError("fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}

             # FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility
             if self.is_fsdp2:
-                self.scaler = get_fsdp2_grad_scaler(**kwargs)
+                self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)
             else:
                 self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
