mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 16:04:58 +08:00
Revert "add foreach support for custom device (#102047)"
This reverts commit b088ff467794bc1125133fb0428749d5bcd6ae3a.
Reverted https://github.com/pytorch/pytorch/pull/102047 on behalf of https://github.com/malfet due to Broke inductor, see b088ff4677 ([comment](https://github.com/pytorch/pytorch/pull/102047#issuecomment-1572368942))
This commit is contained in:
@@ -5,17 +5,6 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd.grad_mode import no_grad
|
||||
|
||||
def _get_foreach_kernels_supported_devices() -> List[str]:
|
||||
r"""
|
||||
Return the device type list that supports foreach kernels.
|
||||
"""
|
||||
return ["cuda", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
def _get_fused_kernels_supported_devices() -> List[str]:
|
||||
r"""
|
||||
Return the device type list that supports fused kernels in optimizer.
|
||||
"""
|
||||
return ["cuda", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
# This util function splits tensors into groups by device and dtype, which is useful before sending
|
||||
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
|
||||
@@ -48,6 +37,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
|
||||
return per_device_and_dtype_tensors
|
||||
|
||||
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
|
||||
if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
|
||||
if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
|
||||
return False
|
||||
return all(t is None or type(t) == torch.Tensor for t in tensors)
|
||||
|
||||
Reference in New Issue
Block a user