mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
9fa82c90f7
commit
b088ff4677
@ -5,6 +5,17 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd.grad_mode import no_grad
|
||||
|
||||
def _get_foreach_kernels_supported_devices() -> List[str]:
|
||||
r"""
|
||||
Return the device type list that supports foreach kernels.
|
||||
"""
|
||||
return ["cuda", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
def _get_fused_kernels_supported_devices() -> List[str]:
|
||||
r"""
|
||||
Return the device type list that supports fused kernels in optimizer.
|
||||
"""
|
||||
return ["cuda", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
# This util function splits tensors into groups by device and dtype, which is useful before sending
|
||||
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
|
||||
@ -37,6 +48,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
|
||||
return per_device_and_dtype_tensors
|
||||
|
||||
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
|
||||
if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
|
||||
if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
|
||||
return False
|
||||
return all(t is None or type(t) == torch.Tensor for t in tensors)
|
||||
|
Reference in New Issue
Block a user