mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Merge and improve torch optim optimizer type stubs (#102593)"
This reverts commit 3279f06410032e9798e380cedf552f5b706ac6c1. Reverted https://github.com/pytorch/pytorch/pull/102593 on behalf of https://github.com/malfet due to There is nothing wrong with this PR, but it fails some internal builds that depend on outdated typing_extensions, will reland when update is done ([comment](https://github.com/pytorch/pytorch/pull/102593#issuecomment-1636062515))
This commit is contained in:
@ -3,7 +3,6 @@ from typing import List, Dict, Tuple, Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd.grad_mode import no_grad
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
def _get_foreach_kernels_supported_devices() -> List[str]:
|
||||
r"""
|
||||
@ -17,9 +16,6 @@ def _get_fused_kernels_supported_devices() -> List[str]:
|
||||
"""
|
||||
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
|
||||
|
||||
TensorListList: TypeAlias = List[List[Optional[Tensor]]]
|
||||
Indices: TypeAlias = List[int]
|
||||
|
||||
# This util function splits tensors into groups by device and dtype, which is useful before sending
|
||||
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
|
||||
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
|
||||
@ -33,9 +29,9 @@ Indices: TypeAlias = List[int]
|
||||
# may be necessary. Check out torch/optim/sgd.py for an example.
|
||||
@no_grad()
|
||||
def _group_tensors_by_device_and_dtype(
|
||||
tensorlistlist: TensorListList,
|
||||
tensorlistlist: List[List[Optional[Tensor]]],
|
||||
with_indices: bool = False,
|
||||
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
|
||||
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[int]]]:
|
||||
return {
|
||||
(device, getattr(torch, str_dtype)): value
|
||||
for (device, str_dtype), value in
|
||||
|
Reference in New Issue
Block a user