diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index f0814889ba2d..bbc4c277273e 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -4,6 +4,8 @@ import unittest.mock import torch import torch.utils._pytree as pytree import itertools +from typing import Iterator +import torch._ops __all__ = ['enable_python_dispatcher', 'no_python_dispatcher'] @@ -25,7 +27,24 @@ def enable_python_dispatcher(): CROSSREF_FUNCTIONALIZE = False -def all_known_overloads(): +def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: + """ + Warning: the set of overloads this will report is very subtle. It is precisely + the set of torch.ops functions that have actually been accessed from Python + (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT + from the set of registered operators, which will in general be a larger set, + as this would include all operators which we ran C++ static initializers or + Python operator registration on. This does not eagerly populate the list on + torch.ops.aten; this list is lazy! + + In other words, this is good for traversing over everything that has an + OpOverload object allocated in Python. We use it for cache invalidation, but + don't rely on this list being complete. + + Note that even if we did report all C++ registered overloads, this isn't guaranteed + to be complete either, as a subsequent lazy load of a library which triggers more + registrations could add more things to the set. + """ for ns in torch.ops: packets = getattr(torch.ops, ns) for op_name in packets: @@ -131,12 +150,12 @@ def make_crossref_functionalize(op, final_key): # for debugging purposes. 
@contextmanager def enable_crossref_functionalize(): - for op in all_known_overloads(): + for op in all_py_loaded_overloads(): op._uncache_dispatch(torch._C.DispatchKey.Functionalize) try: with enable_python_dispatcher(), unittest.mock.patch( 'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True): yield finally: - for op in all_known_overloads(): + for op in all_py_loaded_overloads(): op._uncache_dispatch(torch._C.DispatchKey.Functionalize)