Files
pytorch/torch/utils/_foreach_utils.py
Nitin Singh fe4b88f6aa [HPU] Add hpu to fused kernels supported devices (#148666)
This change adds "hpu" to the list of device types that support fused kernels in the optimizer, ensuring compatibility with the HPU backend.

Without this change, running `test_all_gather_extension_outer_size_stride` from `pytorch/test/distributed/_composable/fsdp/test_fully_shard_extensions.py` on the 'hpu' backend fails with:

RuntimeError: fused=True requires all the params to be floating point Tensors
of supported devices: ['mps', 'cuda', 'xpu', 'cpu', 'privateuseone']
but torch.float32 and hpu
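
For context, the fused optimizer implementations consult this list when validating their parameters. A simplified sketch of that check (illustrative only; `params` stands in for the optimizer's parameter tensors, and the exact wording and location of the real check may differ):

    import torch
    from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

    fused_supported_devices = _get_fused_kernels_supported_devices()
    if not all(
        p.device.type in fused_supported_devices and torch.is_floating_point(p)
        for p in params
    ):
        raise RuntimeError(
            "`fused=True` requires all the params to be floating point Tensors "
            f"of supported devices: {fused_supported_devices}"
        )

With "hpu" included in the returned list, HPU parameters pass this check.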

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148666
Approved by: https://github.com/albanD
2025-03-07 04:28:33 +00:00

45 lines
2.3 KiB
Python

from typing import Optional

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias


def _get_foreach_kernels_supported_devices() -> list[str]:
    r"""Return the device type list that supports foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]


def _get_fused_kernels_supported_devices() -> list[str]:
    r"""Return the device type list that supports fused kernels in optimizer."""
    return ["mps", "cuda", "xpu", "hpu", "cpu", torch._C._get_privateuse1_backend_name()]


TensorListList: TypeAlias = list[list[Optional[Tensor]]]
Indices: TypeAlias = list[int]
_foreach_supported_types = [torch.Tensor]

# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
# - tensorlists CAN be None
# - all tensors in the first specified list cannot be None
# - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
# may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(
    tensorlistlist: TensorListList,
    with_indices: bool = False,
) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]:
    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
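
# Illustrative usage sketch (not part of the original module): group two parallel
# tensor lists by (device, dtype) and launch one foreach kernel per group. The
# function name and the in-place add are placeholders for this example.
def _example_grouped_foreach_add(params: list[Tensor], grads: list[Tensor]) -> None:
    grouped = _group_tensors_by_device_and_dtype([params, grads], with_indices=True)
    for (_device, _dtype), ((dev_params, dev_grads), _indices) in grouped.items():
        # Every tensor in dev_params/dev_grads now shares one device and dtype,
        # which is what the foreach kernels require.
        torch._foreach_add_(dev_params, dev_grads)
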
def _device_has_foreach_support(device: torch.device) -> bool:
    return (
        device.type in (_get_foreach_kernels_supported_devices() + ["cpu"])
        and not torch.jit.is_scripting()
    )


def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool:
    return _device_has_foreach_support(device) and all(
        t is None or type(t) in _foreach_supported_types for t in tensors
    )
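
# Illustrative usage sketch (not part of the original module): fall back to a plain
# per-tensor loop when foreach kernels are unavailable, in the spirit of how
# torch.optim uses these helpers. The update rule is a placeholder for this example.
def _example_apply_grads(params: list[Tensor], grads: list[Tensor], lr: float) -> None:
    if params and _has_foreach_support(params, params[0].device):
        torch._foreach_add_(params, grads, alpha=-lr)
    else:
        for p, g in zip(params, grads):
            p.add_(g, alpha=-lr)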