mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
PEP585 update - torch/utils (#145201)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
693d8c7e94
commit
2f9d378f7b
@ -1,20 +1,20 @@
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd.grad_mode import no_grad
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
def _get_foreach_kernels_supported_devices() -> List[str]:
|
||||
def _get_foreach_kernels_supported_devices() -> list[str]:
|
||||
r"""Return the device type list that supports foreach kernels."""
|
||||
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
def _get_fused_kernels_supported_devices() -> List[str]:
|
||||
def _get_fused_kernels_supported_devices() -> list[str]:
|
||||
r"""Return the device type list that supports fused kernels in optimizer."""
|
||||
return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
TensorListList: TypeAlias = List[List[Optional[Tensor]]]
|
||||
Indices: TypeAlias = List[int]
|
||||
TensorListList: TypeAlias = list[list[Optional[Tensor]]]
|
||||
Indices: TypeAlias = list[int]
|
||||
_foreach_supported_types = [torch.Tensor]
|
||||
|
||||
|
||||
@ -33,12 +33,12 @@ _foreach_supported_types = [torch.Tensor]
|
||||
def _group_tensors_by_device_and_dtype(
|
||||
tensorlistlist: TensorListList,
|
||||
with_indices: bool = False,
|
||||
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
|
||||
) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]:
|
||||
return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
|
||||
|
||||
def _device_has_foreach_support(device: torch.device) -> bool:
|
||||
return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
|
||||
|
||||
|
||||
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
|
||||
def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool:
|
||||
return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
|
||||
|
Reference in New Issue
Block a user