ghstack-source-id: 7b4ffc7c494c970ad780d8c5f89209c79a09a573
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165863
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-10-19 08:48:14 -07:00
parent 9f6ca9edeb
commit ff6dab6f76
3 changed files with 95 additions and 7 deletions

View File

@ -242,6 +242,67 @@ class MiscTests(torch._inductor.test_case.TestCase):
self.assertTrue(same(val4, correct1)) self.assertTrue(same(val4, correct1))
self.assertEqual(counter.frame_count, 3) self.assertEqual(counter.frame_count, 3)
def test_compile_non_infra_inside_compile(self):
from torch.utils._python_dispatch import TorchDispatchMode
backend = torch._dynamo.testing.EagerAndRecordGraphs()
class YoloMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
out = torch.compile(func, backend=backend, fullgraph=True)(
*args, **kwargs
)
return out
x = torch.randn(5)
with YoloMode():
out = torch.add(x, x)
self.assertEqual(len(backend.graphs), 1)
def test_compile_non_infra_empty(self):
from torch.utils._python_dispatch import TorchDispatchMode
backend = torch._dynamo.testing.EagerAndRecordGraphs()
class YoloMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return torch.ops.aten.mul.Tensor(args[0], args[1])
x = torch.ones(5)
with YoloMode():
out = torch.compile(torch.add, backend=backend, fullgraph=True)(x, x)
self.assertEqual(out.sum().item(), 5.0)
self.assertEqual(len(backend.graphs), 0)
def test_compile_non_infra_multiple(self):
from torch.utils._python_dispatch import TorchDispatchMode
backend3 = torch._dynamo.testing.EagerAndRecordGraphs()
backend2 = torch._dynamo.testing.EagerAndRecordGraphs()
backend = torch._dynamo.testing.EagerAndRecordGraphs()
class YoloMode2(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
out = torch.compile(func, backend=backend3, fullgraph=True)(
*args, **kwargs
)
return out
class YoloMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
out = torch.compile(torch.add, backend=backend2, fullgraph=True)(args[0], args[1])
return out
x = torch.ones(5)
with YoloMode(), YoloMode2():
torch.compile(lambda x, y: torch.add(x, y), fullgraph=True, backend=backend)(x, x)
self.assertEqual(len(backend2.graphs), 1)
self.assertEqual(len(backend3.graphs), 0)
self.assertEqual(len(backend.graphs), 0)
def test_dynamo_inside_custom_op(self): def test_dynamo_inside_custom_op(self):
cnt = torch._dynamo.testing.InductorAndRecordGraphs() cnt = torch._dynamo.testing.InductorAndRecordGraphs()
cnt1 = torch._dynamo.testing.InductorAndRecordGraphs() cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()

View File

@ -74,8 +74,8 @@ from torch.monitor import _WaitCounter
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import ( from torch.utils._python_dispatch import (
_disable_current_modes, _disable_current_modes,
is_in_any_mode_without_ignore_compile_internals,
is_in_torch_dispatch_mode, is_in_torch_dispatch_mode,
any_torch_dispatch_mode_on_stack
) )
from torch.utils._traceback import CapturedTraceback, format_traceback_short from torch.utils._traceback import CapturedTraceback, format_traceback_short
@ -1940,10 +1940,6 @@ class ConvertFrameProtocol(typing.Protocol):
) -> ConvertFrameReturn: ... ) -> ConvertFrameReturn: ...
def should_skip_due_to_torch_dispatch_mode() -> bool:
return is_in_any_mode_without_ignore_compile_internals()
class CatchErrorsWrapper: class CatchErrorsWrapper:
def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None: def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
functools.wraps(callback)(self) functools.wraps(callback)(self)
@ -1964,13 +1960,15 @@ class CatchErrorsWrapper:
has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code) has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
else: else:
has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code) has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
should_skip_due_to_torch_dispatch = any_torch_dispatch_mode_on_stack(include_infra_modes=False, respect_ignore_compile_internals=True)
if ( if (
# TODO: the first condition is not covered by any test # TODO: the first condition is not covered by any test
has_started_execution has_started_execution
or is_skipfile or is_skipfile
or config.disable or config.disable
or ( or (
should_skip_due_to_torch_dispatch_mode() should_skip_due_to_torch_dispatch
and not getattr(self._torchdynamo_orig_backend, "_export", False) and not getattr(self._torchdynamo_orig_backend, "_export", False)
) )
): ):
@ -1979,7 +1977,7 @@ class CatchErrorsWrapper:
skip_reason = "traced frame already" skip_reason = "traced frame already"
elif trace_rules.check(frame.f_code): elif trace_rules.check(frame.f_code):
skip_reason = "in skipfiles" skip_reason = "in skipfiles"
elif is_in_torch_dispatch_mode(include_infra_modes=False): elif should_skip_due_to_torch_dispatch:
skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile" skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
else: else:
skip_reason = "dynamo tracing is disabled" skip_reason = "dynamo tracing is disabled"

View File

@ -48,6 +48,35 @@ def is_in_torch_dispatch_mode(include_infra_modes: bool = True) -> bool:
def is_in_any_mode_without_ignore_compile_internals() -> bool: def is_in_any_mode_without_ignore_compile_internals() -> bool:
return _is_in_any_mode_without_ignore_compile_internals return _is_in_any_mode_without_ignore_compile_internals
def any_torch_dispatch_mode_on_stack(
*,
include_infra_modes=True,
respect_ignore_compile_internals=False,
) -> bool:
"""
Check if there are any ambient (non-entered) modes on the stack.
Returns True if there are modes that are on the stack but NOT entered.
"""
stack_len = torch._C._len_torch_dispatch_stack()
for idx in range(stack_len):
mode = _get_dispatch_stack_at(idx)
print("MODE", mode)
# Apply filters first
if mode.is_infra_mode():
if not include_infra_modes:
continue
if mode.ignore_compile_internals():
print("HUUSH")
if respect_ignore_compile_internals:
continue
print("GOT HERE")
breakpoint()
return True
breakpoint()
return False
class TorchDispatchMode: class TorchDispatchMode:
""" """