Revert "add foreach support for custom device (#102047)"

This reverts commit b088ff467794bc1125133fb0428749d5bcd6ae3a.

Reverted https://github.com/pytorch/pytorch/pull/102047 on behalf of https://github.com/malfet due to Broke inductor, see b088ff4677 ([comment](https://github.com/pytorch/pytorch/pull/102047#issuecomment-1572368942))
This commit is contained in:
PyTorch MergeBot
2023-06-01 16:33:03 +00:00
parent 74f10b9ea5
commit 9d77949b9e
4 changed files with 11 additions and 37 deletions

View File

@@ -5,17 +5,6 @@ import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
def _get_foreach_kernels_supported_devices() -> List[str]:
r"""
Return the device type list that supports foreach kernels.
"""
return ["cuda", torch._C._get_privateuse1_backend_name()]
def _get_fused_kernels_supported_devices() -> List[str]:
r"""
Return the device type list that supports fused kernels in optimizer.
"""
return ["cuda", torch._C._get_privateuse1_backend_name()]
# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
@@ -48,6 +37,6 @@ def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
return per_device_and_dtype_tensors
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
if device.type not in set(_get_foreach_kernels_supported_devices() + ["cpu"]) or torch.jit.is_scripting():
if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
return False
return all(t is None or type(t) == torch.Tensor for t in tensors)