mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Renaming all_known_overloads to all_py_loaded_overloads and add comment (#97672)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/97672 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb85b43c0b
commit
b2f1edabfe
@ -4,6 +4,8 @@ import unittest.mock
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
import itertools
|
||||
from typing import Iterator
|
||||
import torch._ops
|
||||
|
||||
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
|
||||
|
||||
@ -25,7 +27,24 @@ def enable_python_dispatcher():
|
||||
|
||||
CROSSREF_FUNCTIONALIZE = False
|
||||
|
||||
def all_known_overloads():
|
||||
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
|
||||
"""
|
||||
Warning: the set of overloads this will report is very subtle. It is precisely
|
||||
the set of torch.ops functions that have actually been accessed from Python
|
||||
(e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT
|
||||
from the set of registered operators, which will in general be a larger set,
|
||||
as this would include all operators which we ran C++ static initializers or
|
||||
Python operator registration on. This does not eagerly populate the list on
|
||||
torch.ops.aten; this list is lazy!
|
||||
|
||||
In other words, this is good for traversing over everything that has an
|
||||
OpOverload object allocated in Python. We use it for cache invalidation, but
|
||||
don't rely on this list being complete.
|
||||
|
||||
Note that even if we did report all C++ registered overloads, this isn't guaranteed
|
||||
to be complete either, as a subsequent lazy load of a library which triggers more
|
||||
registrations could add more things to the set.
|
||||
"""
|
||||
for ns in torch.ops:
|
||||
packets = getattr(torch.ops, ns)
|
||||
for op_name in packets:
|
||||
@ -131,12 +150,12 @@ def make_crossref_functionalize(op, final_key):
|
||||
# for debugging purposes.
|
||||
@contextmanager
|
||||
def enable_crossref_functionalize():
|
||||
for op in all_known_overloads():
|
||||
for op in all_py_loaded_overloads():
|
||||
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
|
||||
try:
|
||||
with enable_python_dispatcher(), unittest.mock.patch(
|
||||
'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
|
||||
yield
|
||||
finally:
|
||||
for op in all_known_overloads():
|
||||
for op in all_py_loaded_overloads():
|
||||
op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
|
||||
|
Reference in New Issue
Block a user