Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Add option for TorchDispatchMode to ignore torch.compile internals (#161648)

If TorchDispatchMode.ignore_compile_internals() is True, then we turn off the TorchDispatchMode during the compilation process and turn it back on during runtime of the compiled artifact.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161648
Approved by: https://github.com/bdhirsh
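
As a rough illustration of the new knob, here is a minimal sketch (the mode name SilentDuringCompileMode is made up for this example; exactly which ops the mode observes at runtime depends on what the compiled artifact dispatches):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class SilentDuringCompileMode(TorchDispatchMode):
        @classmethod
        def ignore_compile_internals(cls):
            # Opt in: the mode is disabled while torch.compile traces and
            # compiles, and re-enabled while the compiled artifact runs.
            return True

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"runtime dispatch: {func}")
            return func(*args, **kwargs)

    @torch.compile(fullgraph=True)
    def f(x):
        return x.sin().cos()

    with SilentDuringCompileMode():
        f(torch.randn(3))  # compile-time internals are not reported to the mode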
@@ -11,6 +11,7 @@ from torch._C import (
    _pop_torch_function_stack,
    _push_on_torch_function_stack,
)
from torch._dynamo.utils import counters
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE
@@ -61,6 +62,53 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
    def tearDownClass(cls):
        super().tearDownClass()

    def test_torch_dispatch_ignore_compile_internals(self):
        counters.clear()
        from torch.utils._python_dispatch import TorchDispatchMode

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.clone()

        def checksum(x):
            return x.abs().sum()

        _checksums = []

        class ChecksumFoo(TorchDispatchMode):
            @classmethod
            def ignore_compile_internals(cls):
                return True

            def __init__(self) -> None:
                super().__init__()

            def __torch_dispatch__(self, func, types, args, kwargs=None):
                kwargs = kwargs or {}

                if func is torch.ops.mylib.foo.default:
                    # Do some compute, smoketest to see if there's a bad interaction
                    _checksums.append(args[0].abs().sum())

                return func(*args, **kwargs)

        # test e2e, with Inductor, as smoketest.
        @torch.compile(fullgraph=True, backend="inductor")
        def g(x):
            return 2 * x.sin().cos()

        x = torch.randn(3)

        with ChecksumFoo():
            foo(x)
            g(x)
            foo(x)

        self.assertEqual(len(_checksums), 2)
        # The correct result here is 1: Dynamo should capture the `g` frame.
        self.assertEqual(counters["frames"]["total"], 1)
        self.assertEqual(counters["frames"]["ok"], 1)

    def test_skip_torch_dispatch_modes(self):
        class RewriteAddToMul(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@@ -73,6 +73,7 @@ from torch.monitor import _WaitCounter
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import (
    _disable_current_modes,
    is_in_any_mode_without_ignore_compile_internals,
    is_in_torch_dispatch_mode,
)
from torch.utils._traceback import CapturedTraceback, format_traceback_short
@@ -1775,6 +1776,10 @@ class ConvertFrameProtocol(typing.Protocol):
    ) -> ConvertFrameReturn: ...


def should_skip_due_to_torch_dispatch_mode() -> bool:
    return is_in_any_mode_without_ignore_compile_internals()


class CatchErrorsWrapper:
    def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
        functools.wraps(callback)(self)
@@ -1802,7 +1807,7 @@ class CatchErrorsWrapper:
                or is_skipfile
                or config.disable
                or (
-                   is_in_torch_dispatch_mode(include_infra_modes=False)
+                   should_skip_due_to_torch_dispatch_mode()
                    and not getattr(self._torchdynamo_orig_backend, "_export", False)
                )
            ):
@@ -28,6 +28,8 @@ from torch._C import (

_is_in_torch_dispatch_mode = False
_is_in_non_infra_torch_dispatch_mode = False
# If inside any mode that has ignore_compile_internals() = False
_is_in_any_mode_without_ignore_compile_internals = False


def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
@@ -38,6 +40,10 @@ def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
    )


def is_in_any_mode_without_ignore_compile_internals() -> bool:
    return _is_in_any_mode_without_ignore_compile_internals


class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
@@ -82,6 +88,9 @@ class TorchDispatchMode:

        self.old_dispatch_mode_flags: deque[bool] = deque()
        self.old_non_infra_dispatch_mode_flags: deque[bool] = deque()
        self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[bool] = (
            deque()
        )

    def _lazy_init_old_dispatch_mode_flags(self):
        if not hasattr(self, "old_dispatch_mode_flags"):
@@ -90,12 +99,21 @@ class TorchDispatchMode:
        if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
            self.old_non_infra_dispatch_mode_flags: deque[bool] = deque()  # type: ignore[no-redef]

        if not hasattr(
            self, "old_without_ignore_compile_internals_dispatch_mode_flags"
        ):
            self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[  # type: ignore[no-redef]
                bool
            ] = deque()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError

    def __enter__(self):
        global _is_in_torch_dispatch_mode
        global _is_in_non_infra_torch_dispatch_mode
        global _is_in_any_mode_without_ignore_compile_internals

        # Previously, there wasn't any state in this class' constructor
        # super calls were added to existing modes, but for any new modes
        # this will replicate the previous behavior of not strictly needing
@@ -109,6 +127,13 @@ class TorchDispatchMode:
        _is_in_non_infra_torch_dispatch_mode = (
            _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
        )
        self.old_without_ignore_compile_internals_dispatch_mode_flags.append(
            _is_in_any_mode_without_ignore_compile_internals
        )
        _is_in_any_mode_without_ignore_compile_internals = (
            _is_in_any_mode_without_ignore_compile_internals
            or not self.ignore_compile_internals()
        )
        _push_mode(self)
        return self
@@ -124,6 +149,10 @@ class TorchDispatchMode:
        _is_in_non_infra_torch_dispatch_mode = (
            self.old_non_infra_dispatch_mode_flags.pop()
        )
        global _is_in_any_mode_without_ignore_compile_internals
        _is_in_any_mode_without_ignore_compile_internals = (
            self.old_without_ignore_compile_internals_dispatch_mode_flags.pop()
        )
        _pop_mode(mb_dk_or_mode_key)

    @classmethod
@@ -138,6 +167,38 @@ class TorchDispatchMode:
    def is_infra_mode(cls):
        return False

    @classmethod
    def ignore_compile_internals(cls):
        """Ignore operators that are compiled via torch.compile.

        If ``True``, then this TorchDispatchMode ignores operators that
        are optimized by :func:`torch.compile`. Mechanically, this involves
        turning off the TorchDispatchMode throughout the whole compilation process,
        and turning it back on for the runtime of the compiled artifact(s).
        For example,

            @torch.compile
            def f(x):
                return x.sin().cos()

            with LoggingMode():
                f(x)

        The above example will not log anything if
        ``LoggingMode.ignore_compile_internals()`` is True.
        torch.compile will fuse sin() and cos() into a single operation,
        and this TorchDispatchMode will not be passed sin and cos.

        If ``False`` (default), :func:`torch.compile` will respect
        the eager semantics of passing this TorchDispatchMode all
        operators that would have run during eager execution.
        The way this will usually happen is that :func:`torch.compile`
        will just fall back to eager-mode PyTorch.
        """
        if cls.is_infra_mode():
            return True
        return False


def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
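
For contrast, a minimal sketch of the default behavior described in the docstring above (EagerLoggingMode is an illustrative name, not part of the PR): with ignore_compile_internals() left at its default of False, entering a non-infra mode typically makes torch.compile fall back to eager, so the mode still sees each individual ATen op.

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class EagerLoggingMode(TorchDispatchMode):
        # Default ignore_compile_internals() -> False: under this mode,
        # torch.compile falls back to eager, so every op is dispatched here.
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"saw: {func}")
            return func(*args, **kwargs)

    @torch.compile
    def f(x):
        return x.sin().cos()

    with EagerLoggingMode():
        f(torch.randn(3))  # expect aten.sin and aten.cos to be observed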