Improve torch.ops typing (#154555)

Summary:
Cloned https://github.com/pytorch/pytorch/pull/153558 from benjaminglass1 and fixed internal typing errors.

Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.
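For illustration, a minimal sketch of that pattern (not the actual `torch/_ops.py` code; `OpOverloadPacket` here is a stand-in for the real class and `_resolve_op` is a hypothetical helper):

```python
# Minimal sketch, assuming a placeholder OpOverloadPacket and a hypothetical
# _resolve_op helper; the real classes live in torch/_ops.py.
class OpOverloadPacket:
    def __init__(self, qualified_name: str) -> None:
        self.qualified_name = qualified_name


def _resolve_op(namespace: str, op_name: str) -> OpOverloadPacket:
    # Hypothetical lookup; PyTorch resolves operators through the dispatcher.
    return OpOverloadPacket(f"{namespace}::{op_name}")


class _OpNamespace:
    # Attributes that always exist are declared up front, so type checkers
    # (and normal attribute lookup) never route them through __getattr__.
    name: str

    def __init__(self, name: str) -> None:
        self.name = name

    def __getattr__(self, op_name: str) -> OpOverloadPacket:
        # Only reached when regular lookup fails, so it can promise a single
        # return type for every dynamically resolved operator.
        packet = _resolve_op(self.name, op_name)
        setattr(self, op_name, packet)  # cache so later lookups skip __getattr__
        return packet


aten = _OpNamespace("aten")
print(aten.add.qualified_name)  # type checkers see an OpOverloadPacket, not Any
```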

Decisions made along the way:

1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it was generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special cases in the function could be implemented as class variables instead (see the sketch after this list).
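A rough sketch of decisions 1 and 2 combined (illustrative names only, not the real `torch._ops` definitions):

```python
# Sketch only: names are illustrative, not the real torch._ops code.
class _OpNamespace:
    def __init__(self, name: str) -> None:
        self.name = name


class _HigherOrderNamespace:
    # Single-purpose class backing torch.ops.higher_order; no generic machinery.
    pass


class _Ops:
    # Constant special cases become class variables. Python resolves these
    # before ever calling __getattr__, so __getattr__ no longer needs a
    # Union return type to cover them.
    higher_order = _HigherOrderNamespace()

    def __getattr__(self, name: str) -> _OpNamespace:
        # Every remaining attribute is a dynamically created op namespace,
        # so a single, precise return type is possible.
        ns = _OpNamespace(name)
        setattr(self, name, ns)
        return ns


ops = _Ops()
ops.higher_order  # resolved as a class attribute, never hits __getattr__
ops.aten          # resolved via __getattr__, typed as _OpNamespace
```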

The remainder of this PR fixes the bugs exposed by the updated typing, along with assorted nitpicky typing issues.

Test Plan: CI

Differential Revision: D75497142

Co-authored-by: Benjamin Glass <bglass@quansight.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154555
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/zou3519, https://github.com/benjaminglass1
Author: Aaron Orenstein
Date: 2025-06-22 15:52:27 +00:00
Committed by: PyTorch MergeBot
Parent: 10fb98a004
Commit: 54b8087f63
16 changed files with 180 additions and 126 deletions


@@ -3,12 +3,15 @@ import itertools
 import unittest.mock
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import Callable, TypeVar, Union
+from typing_extensions import ParamSpec
 import torch
 import torch._C
 import torch._ops
 import torch.utils._python_dispatch
 import torch.utils._pytree as pytree
+from torch._C import DispatchKey
 __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
@@ -19,6 +22,9 @@ enable_pre_dispatch = torch._C._EnablePreDispatch
 CROSSREF_FUNCTIONALIZE = False
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
 def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
     """
@@ -103,14 +109,16 @@ def _fmt(a: object) -> object:
     return a
-def make_crossref_functionalize(op, final_key):
+def make_crossref_functionalize(
+    op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
+) -> Union[Callable[_P, _T], DispatchKey]:
     from torch._subclasses.fake_tensor import FakeTensorMode
     # This case is pretty weird, suppress it for now
     if op == torch.ops.aten.lift_fresh.default:
         return final_key
-    def handler(*args, **kwargs):
+    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
         fake_mode = FakeTensorMode()
         def fakeify_defun(t):
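
For readers unfamiliar with the pattern adopted in the hunk above, here is a minimal, generic sketch (not PyTorch code) of how `ParamSpec` and `TypeVar` let a nested wrapper like `handler` inherit the wrapped callable's exact signature:

```python
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")


def make_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    # handler reuses fn's parameter list (_P) and return type (_T), so call
    # sites of the returned wrapper are checked against fn's real signature.
    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        return fn(*args, **kwargs)

    return handler


def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


wrapped = make_wrapper(scale)
print(wrapped(3.0))        # OK: typed as float
# wrapped("3.0")           # a type checker would reject this argument
```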