add foreach support for custom device (#102047)

Fixes #ISSUE_NUMBER
For custom devices we also want to support foreach, so this PR adds helper functions that return the device types on which foreach (and fused optimizer) kernels are supported; the returned list defaults to CUDA plus the registered PrivateUse1 backend name.
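A minimal sketch of what the new helpers report, assuming the usual PrivateUse1 registration flow; the backend name "my_device" is hypothetical and used only for illustration:

import torch
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

# Before any renaming, the PrivateUse1 slot carries its placeholder name.
print(_get_foreach_kernels_supported_devices())  # ['cuda', 'privateuseone']

# A custom device integration registers its backend name once per process
# ("my_device" is a made-up example).
torch.utils.rename_privateuse1_backend("my_device")

# The foreach (and fused-optimizer) device lists pick the new name up
# automatically, so no per-backend patching of these utils is needed.
print(_get_foreach_kernels_supported_devices())  # ['cuda', 'my_device']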

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
Authored by shibo19 on 2023-06-01 06:22:39 +00:00
Committed by PyTorch MergeBot
parent 9fa82c90f7
commit b088ff4677
4 changed files with 37 additions and 11 deletions


@@ -5,6 +5,17 @@ import torch
 from torch import Tensor
 from torch.autograd.grad_mode import no_grad
+
+def _get_foreach_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports foreach kernels.
+    """
+    return ["cuda", torch._C._get_privateuse1_backend_name()]
+
+def _get_fused_kernels_supported_devices() -> List[str]:
+    r"""
+    Return the device type list that supports fused kernels in optimizer.
+    """
+    return ["cuda", torch._C._get_privateuse1_backend_name()]
+
 # This util function splits tensors into groups by device and dtype, which is useful before sending
 # tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
@@ -37,6 +48,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
     return per_device_and_dtype_tensors
 
 def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
-    if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
+    if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
         return False
     return all(t is None or type(t) == torch.Tensor for t in tensors)
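With the hard-coded ['cpu', 'cuda'] check replaced, _has_foreach_support now answers from the same device list. A small usage sketch (CPU-only, so it runs without any accelerator):

import torch
from torch.utils._foreach_utils import _has_foreach_support

tensors = [torch.ones(3), torch.zeros(3)]
# CPU remains supported through the explicit + ["cpu"] in the new check.
print(_has_foreach_support(tensors, device=torch.device("cpu")))   # True
# Any device type outside the supported list is rejected.
print(_has_foreach_support(tensors, device=torch.device("meta")))  # False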