mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Renaming all_known_overloads to all_py_loaded_overloads and add comment (#97672)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/97672 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb85b43c0b
commit
b2f1edabfe
@ -4,6 +4,8 @@ import unittest.mock
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
import itertools
|
import itertools
|
||||||
|
from typing import Iterator
|
||||||
|
import torch._ops
|
||||||
|
|
||||||
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
|
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
|
||||||
|
|
||||||
@ -25,7 +27,24 @@ def enable_python_dispatcher():
|
|||||||
|
|
||||||
CROSSREF_FUNCTIONALIZE = False
|
CROSSREF_FUNCTIONALIZE = False
|
||||||
|
|
||||||
def all_known_overloads():
|
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
|
||||||
|
"""
|
||||||
|
Warning: the set of overloads this will report is very subtle. It is precisely
|
||||||
|
the set of torch.ops functions that have actually been accessed from Python
|
||||||
|
(e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT
|
||||||
|
from the set of registered operators, which will in general be a larger set,
|
||||||
|
as this would include all operators which we ran C++ static initializers or
|
||||||
|
Python operator registration on. This does not eagerly populate the list on
|
||||||
|
torch.ops.aten; this list is lazy!
|
||||||
|
|
||||||
|
In other words, this is good for traversing over everything that has an
|
||||||
|
OpOverload object allocated in Python. We use it for cache invalidation, but
|
||||||
|
don't rely on this list being complete.
|
||||||
|
|
||||||
|
Note that even if we did report all C++ registered overloads, this isn't guaranteed
|
||||||
|
to be complete either, as a subsequent lazy load of a library which triggers more
|
||||||
|
registrations could add more things to the set.
|
||||||
|
"""
|
||||||
for ns in torch.ops:
|
for ns in torch.ops:
|
||||||
packets = getattr(torch.ops, ns)
|
packets = getattr(torch.ops, ns)
|
||||||
for op_name in packets:
|
for op_name in packets:
|
||||||
@ -131,12 +150,12 @@ def make_crossref_functionalize(op, final_key):
|
|||||||
# for debugging purposes.
|
# for debugging purposes.
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def enable_crossref_functionalize():
|
def enable_crossref_functionalize():
|
||||||
for op in all_known_overloads():
|
for op in all_py_loaded_overloads():
|
||||||
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
|
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
|
||||||
try:
|
try:
|
||||||
with enable_python_dispatcher(), unittest.mock.patch(
|
with enable_python_dispatcher(), unittest.mock.patch(
|
||||||
'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
|
'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
for op in all_known_overloads():
|
for op in all_py_loaded_overloads():
|
||||||
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
|
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
|
||||||
|
Reference in New Issue
Block a user