Add option for TorchDispatchMode to ignore torch.compile internals (#161648)

If TorchDispatchMode.ignore_compile_internals() is True, then we turn
off the TorchDispatchMode during the compilation process and turn it
back on during the runtime of the compiled artifact.
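
For illustration, a minimal sketch of the opt-in. LoggingMode is a
hypothetical user-defined mode (only the ignore_compile_internals()
classmethod comes from this PR):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LoggingMode(TorchDispatchMode):
        # Opt in: the mode is disabled while torch.compile traces and
        # compiles, and re-enabled when the compiled artifact runs.
        @classmethod
        def ignore_compile_internals(cls):
            return True

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(func)
            return func(*args, **(kwargs or {}))

    @torch.compile
    def f(x):
        return x.sin().cos()

    with LoggingMode():
        f(torch.randn(3))  # fused compile internals are not reported to the mode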

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161648
Approved by: https://github.com/bdhirsh
Author: rzou
Date: 2025-08-27 15:15:40 -07:00
Committed by: PyTorch MergeBot
Parent: 199c3633bf
Commit: 5edc3d814f
3 changed files with 115 additions and 1 deletion

test/dynamo/test_modes.py

@@ -11,6 +11,7 @@ from torch._C import (
     _pop_torch_function_stack,
     _push_on_torch_function_stack,
 )
+from torch._dynamo.utils import counters
 from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
 from torch.testing._internal.common_utils import skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE
@@ -61,6 +62,53 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
     def tearDownClass(cls):
         super().tearDownClass()
 
+    def test_torch_dispatch_ignore_compile_internals(self):
+        counters.clear()
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        @torch.library.custom_op("mylib::foo", mutates_args=())
+        def foo(x: torch.Tensor) -> torch.Tensor:
+            return x.clone()
+
+        def checksum(x):
+            return x.abs().sum()
+
+        _checksums = []
+
+        class ChecksumFoo(TorchDispatchMode):
+            @classmethod
+            def ignore_compile_internals(cls):
+                return True
+
+            def __init__(self) -> None:
+                super().__init__()
+
+            def __torch_dispatch__(self, func, types, args, kwargs=None):
+                kwargs = kwargs or {}
+                if func is torch.ops.mylib.foo.default:
+                    # Do some compute, smoketest to see if there's a bad interaction
+                    _checksums.append(args[0].abs().sum())
+                return func(*args, **kwargs)
+
+        # test e2e, with Inductor, as smoketest.
+        @torch.compile(fullgraph=True, backend="inductor")
+        def g(x):
+            return 2 * x.sin().cos()
+
+        x = torch.randn(3)
+        with ChecksumFoo():
+            foo(x)
+            g(x)
+            foo(x)
+
+        self.assertEqual(len(_checksums), 2)
+        # The correct result here is 1: Dynamo should capture the `g` frame.
+        self.assertEqual(counters["frames"]["total"], 1)
+        self.assertEqual(counters["frames"]["ok"], 1)
+
     def test_skip_torch_dispatch_modes(self):
         class RewriteAddToMul(TorchDispatchMode):
             def __torch_dispatch__(self, func, types, args=(), kwargs=None):

torch/_dynamo/convert_frame.py

@@ -73,6 +73,7 @@ from torch.monitor import _WaitCounter
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils._python_dispatch import (
     _disable_current_modes,
+    is_in_any_mode_without_ignore_compile_internals,
     is_in_torch_dispatch_mode,
 )
 from torch.utils._traceback import CapturedTraceback, format_traceback_short
from torch.utils._traceback import CapturedTraceback, format_traceback_short
@@ -1775,6 +1776,10 @@ class ConvertFrameProtocol(typing.Protocol):
     ) -> ConvertFrameReturn: ...
 
 
+def should_skip_due_to_torch_dispatch_mode() -> bool:
+    return is_in_any_mode_without_ignore_compile_internals()
+
+
 class CatchErrorsWrapper:
     def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
         functools.wraps(callback)(self)
@@ -1802,7 +1807,7 @@ class CatchErrorsWrapper:
             or is_skipfile
             or config.disable
             or (
-                is_in_torch_dispatch_mode(include_infra_modes=False)
+                should_skip_due_to_torch_dispatch_mode()
                 and not getattr(self._torchdynamo_orig_backend, "_export", False)
             )
         ):

torch/utils/_python_dispatch.py

@@ -28,6 +28,8 @@ from torch._C import (
 
 _is_in_torch_dispatch_mode = False
 _is_in_non_infra_torch_dispatch_mode = False
+# If inside any mode that has ignore_compile_internals() = False
+_is_in_any_mode_without_ignore_compile_internals = False
 
 
 def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
@@ -38,6 +40,10 @@ def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
     )
 
 
+def is_in_any_mode_without_ignore_compile_internals() -> bool:
+    return _is_in_any_mode_without_ignore_compile_internals
+
+
 class TorchDispatchMode:
     """
     A ``TorchDispatchMode`` allows you to override the meaning of all
@@ -82,6 +88,9 @@ class TorchDispatchMode:
         self.old_dispatch_mode_flags: deque[bool] = deque()
         self.old_non_infra_dispatch_mode_flags: deque[bool] = deque()
+        self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[bool] = (
+            deque()
+        )
 
     def _lazy_init_old_dispatch_mode_flags(self):
         if not hasattr(self, "old_dispatch_mode_flags"):
@@ -90,12 +99,21 @@
         if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
             self.old_non_infra_dispatch_mode_flags: deque[bool] = deque()  # type: ignore[no-redef]
+        if not hasattr(
+            self, "old_without_ignore_compile_internals_dispatch_mode_flags"
+        ):
+            self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[  # type: ignore[no-redef]
+                bool
+            ] = deque()
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         raise NotImplementedError
 
     def __enter__(self):
         global _is_in_torch_dispatch_mode
         global _is_in_non_infra_torch_dispatch_mode
+        global _is_in_any_mode_without_ignore_compile_internals
+
         # Previously, there wasn't any state in this class' constructor
         # super calls were added to existing modes, but for any new modes
         # this will replicate the previous behavior of not strictly needing
@@ -109,6 +127,13 @@
         _is_in_non_infra_torch_dispatch_mode = (
             _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
         )
+        self.old_without_ignore_compile_internals_dispatch_mode_flags.append(
+            _is_in_any_mode_without_ignore_compile_internals
+        )
+        _is_in_any_mode_without_ignore_compile_internals = (
+            _is_in_any_mode_without_ignore_compile_internals
+            or not self.ignore_compile_internals()
+        )
         _push_mode(self)
         return self
@@ -124,6 +149,10 @@
         _is_in_non_infra_torch_dispatch_mode = (
             self.old_non_infra_dispatch_mode_flags.pop()
         )
+        global _is_in_any_mode_without_ignore_compile_internals
+        _is_in_any_mode_without_ignore_compile_internals = (
+            self.old_without_ignore_compile_internals_dispatch_mode_flags.pop()
+        )
         _pop_mode(mb_dk_or_mode_key)
 
     @classmethod
@@ -138,6 +167,38 @@
     def is_infra_mode(cls):
         return False
 
+    @classmethod
+    def ignore_compile_internals(cls):
+        """Ignore operators that are compiled via torch.compile.
+
+        If ``True``, then this TorchDispatchMode ignores operators that
+        are optimized by :func:`torch.compile`. Mechanically, this involves
+        turning off the TorchDispatchMode throughout the whole compilation
+        process, and turning it back on for the runtime of the compiled
+        artifact(s).
+
+        For example::
+
+            @torch.compile
+            def f(x):
+                return x.sin().cos()
+
+            with LoggingMode():
+                f(x)
+
+        The above example will not log anything if
+        ``LoggingMode.ignore_compile_internals()`` is True:
+        torch.compile will fuse sin() and cos() into a single operation,
+        so this TorchDispatchMode is never passed sin and cos.
+
+        If ``False`` (the default), :func:`torch.compile` respects
+        the eager semantics of passing this TorchDispatchMode all
+        operators that would have run during eager execution.
+        In practice, :func:`torch.compile` usually achieves this by
+        falling back to eager-mode PyTorch.
+        """
+        if cls.is_infra_mode():
+            return True
+        return False
 
 
 def _get_current_dispatch_mode():
     stack_len = _len_torch_dispatch_stack()