Files
pytorch/torch/utils/_foreach_utils.py
Matthew Hoffman 0616952d13 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00

50 lines
2.3 KiB
Python

from typing import List, Dict, Tuple, Optional
import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias
def _get_foreach_kernels_supported_devices() -> List[str]:
r"""
Return the device type list that supports foreach kernels.
"""
return ["cuda", torch._C._get_privateuse1_backend_name()]
def _get_fused_kernels_supported_devices() -> List[str]:
r"""
Return the device type list that supports fused kernels in optimizer.
"""
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
TensorListList: TypeAlias = List[List[Optional[Tensor]]]
Indices: TypeAlias = List[int]
# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
# - tensorlists CAN be None
# - all tensors in the first specified list cannot be None
# - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
# may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(
tensorlistlist: TensorListList,
with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
return {
(device, getattr(torch, str_dtype)): value
for (device, str_dtype), value in
torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
}
def _has_foreach_support(tensors: List[Optional[Tensor]], device: torch.device) -> bool:
    r"""
    Return whether ``tensors`` on ``device`` may be handed to a foreach kernel.

    Returns False when ``device.type`` is neither ``"cpu"`` nor one of the
    types from ``_get_foreach_kernels_supported_devices()``, or when
    ``torch.jit.is_scripting()`` is true. Otherwise returns True only if
    every entry is ``None`` or has type exactly ``torch.Tensor``.
    """
    supported_device_types = _get_foreach_kernels_supported_devices() + ["cpu"]
    if device.type not in supported_device_types or torch.jit.is_scripting():
        return False
    # Exact type check (not ``isinstance``): subclasses of torch.Tensor are
    # rejected, presumably because they may not support the foreach kernels.
    return all(t is None or type(t) is torch.Tensor for t in tensors)