# mypy: allow-untyped-defs
import itertools
import unittest.mock
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Callable, TypeVar, Union
from typing_extensions import ParamSpec

import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree
from torch._C import DispatchKey


__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch
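
# Example (sketch): these bindings are context-manager classes from torch._C,
# so they are used directly in `with` statements. Under the Python dispatcher,
# Python-level registrations take priority for overloads that have them:
#
#   with enable_python_dispatcher():
#       out = torch.ops.aten.add.Tensor(torch.ones(2), torch.ones(2))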

CROSSREF_FUNCTIONALIZE = False

_P = ParamSpec("_P")
_T = TypeVar("_T")


def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as it includes all operators for which we ran C++ static initializers or
    Python operator registration. This function does not eagerly populate the
    list on torch.ops.aten; the list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't guaranteed
    to be complete either, as a subsequent lazy load of a library which triggers more
    registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
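
# Example (sketch): exhausting the iterator counts only the overloads Python
# has materialized so far; touching an overload first guarantees it appears:
#
#   _ = torch.ops.aten.mul.Tensor
#   n_loaded = sum(1 for _ in all_py_loaded_overloads())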


@contextmanager
def suspend_functionalization():
    # Temporarily disable functionalization on this thread, remembering both
    # whether it was enabled and the reapply_views setting so the exact prior
    # state can be restored on exit.
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
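
# Example (sketch): code run under the context sees plain tensors even if an
# enclosing functionalization pass is active:
#
#   with suspend_functionalization():
#       ...  # functionalization is off here; the prior state returns afterwards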


def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert same_strides, (
        f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
    )


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, _n_spec = pytree.tree_flatten(n)
    r_vals, _r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
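
# Example (sketch): two tensors with identical sizes, dtypes, and significant
# strides pass; any mismatch raises an AssertionError naming the output index:
#
#   a = torch.empty_strided((2, 2), (2, 1))
#   b = torch.empty_strided((2, 2), (2, 1))
#   check_metadata_matches([a], [b], lambda: "demo")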


class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a
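
# Example (sketch): Lit suppresses string quoting under repr(), so _fmt renders
# a tensor as a reconstruction expression instead of its values:
#
#   repr(pytree.tree_map(_fmt, {"x": torch.ones(2, 3)}))
#   # -> "{'x': torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)}"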


def make_crossref_functionalize(
    op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
) -> Union[Callable[_P, _T], DispatchKey]:
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with (
            torch.utils._python_dispatch._disable_current_modes(),
            suspend_functionalization(),
        ):
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
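
# Example (sketch): the returned handler re-runs `op` under FakeTensorMode with
# functionalization suspended, runs the real kernel at `final_key`, and asserts
# that the two agree on output metadata. The key below is illustrative only; in
# practice the dispatcher supplies the key it would otherwise have used:
#
#   fn = make_crossref_functionalize(
#       torch.ops.aten.add.Tensor, DispatchKey.CompositeExplicitAutograd
#   )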


# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with (
            enable_python_dispatcher(),
            unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True),
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
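
# Example (sketch): wrap a region to cross-check every functionalized op against
# its real kernel; Functionalize dispatch caches are invalidated on entry and
# exit so the crossref handlers take effect:
#
#   with enable_crossref_functionalize():
#       out = some_functionalized_computation()  # hypothetical workload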