Renaming all_known_overloads to all_py_loaded_overloads and add comment (#97672)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97672
Approved by: https://github.com/Skylion007
This commit is contained in:
Edward Z. Yang
2023-03-27 08:04:39 -07:00
committed by PyTorch MergeBot
parent bb85b43c0b
commit b2f1edabfe

View File

@ -4,6 +4,8 @@ import unittest.mock
import torch
import torch.utils._pytree as pytree
import itertools
from typing import Iterator
import torch._ops
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
@ -25,7 +27,24 @@ def enable_python_dispatcher():
CROSSREF_FUNCTIONALIZE = False
def all_known_overloads():
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
"""
Warning: the set of overloads this will report is very subtle. It is precisely
the set of torch.ops functions that have actually been accessed from Python
(e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT
from the set of registered operators, which will in general be a larger set,
as this would include all operators which we ran C++ static initializers or
Python operator registration on. This does not eagerly populate the list on
torch.ops.aten; this list is lazy!
In other words, this is good for traversing over everything that has an
OpOverload object allocated in Python. We use it for cache invalidation, but
don't rely on this list being complete.
Note that even if we did report all C++ registered overloads, this isn't guaranteed
to be complete either, as a subsequent lazy load of a library which triggers more
registrations could add more things to the set.
"""
for ns in torch.ops:
packets = getattr(torch.ops, ns)
for op_name in packets:
@ -131,12 +150,12 @@ def make_crossref_functionalize(op, final_key):
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
for op in all_known_overloads():
for op in all_py_loaded_overloads():
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
try:
with enable_python_dispatcher(), unittest.mock.patch(
'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
yield
finally:
for op in all_known_overloads():
for op in all_py_loaded_overloads():
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)