mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
TEMP
[ghstack-poisoned]
This commit is contained in:
@ -242,6 +242,67 @@ class MiscTests(torch._inductor.test_case.TestCase):
|
||||
self.assertTrue(same(val4, correct1))
|
||||
self.assertEqual(counter.frame_count, 3)
|
||||
|
||||
def test_compile_non_infra_inside_compile(self):
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
backend = torch._dynamo.testing.EagerAndRecordGraphs()
|
||||
|
||||
class YoloMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
out = torch.compile(func, backend=backend, fullgraph=True)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return out
|
||||
|
||||
x = torch.randn(5)
|
||||
with YoloMode():
|
||||
out = torch.add(x, x)
|
||||
|
||||
self.assertEqual(len(backend.graphs), 1)
|
||||
|
||||
def test_compile_non_infra_empty(self):
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
backend = torch._dynamo.testing.EagerAndRecordGraphs()
|
||||
|
||||
class YoloMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
return torch.ops.aten.mul.Tensor(args[0], args[1])
|
||||
|
||||
x = torch.ones(5)
|
||||
with YoloMode():
|
||||
out = torch.compile(torch.add, backend=backend, fullgraph=True)(x, x)
|
||||
|
||||
self.assertEqual(out.sum().item(), 5.0)
|
||||
self.assertEqual(len(backend.graphs), 0)
|
||||
|
||||
def test_compile_non_infra_multiple(self):
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
backend3 = torch._dynamo.testing.EagerAndRecordGraphs()
|
||||
backend2 = torch._dynamo.testing.EagerAndRecordGraphs()
|
||||
backend = torch._dynamo.testing.EagerAndRecordGraphs()
|
||||
|
||||
class YoloMode2(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
out = torch.compile(func, backend=backend3, fullgraph=True)(
|
||||
*args, **kwargs
|
||||
)
|
||||
return out
|
||||
|
||||
class YoloMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
out = torch.compile(torch.add, backend=backend2, fullgraph=True)(args[0], args[1])
|
||||
return out
|
||||
|
||||
x = torch.ones(5)
|
||||
with YoloMode(), YoloMode2():
|
||||
torch.compile(lambda x, y: torch.add(x, y), fullgraph=True, backend=backend)(x, x)
|
||||
|
||||
self.assertEqual(len(backend2.graphs), 1)
|
||||
self.assertEqual(len(backend3.graphs), 0)
|
||||
self.assertEqual(len(backend.graphs), 0)
|
||||
|
||||
def test_dynamo_inside_custom_op(self):
|
||||
cnt = torch._dynamo.testing.InductorAndRecordGraphs()
|
||||
cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()
|
||||
|
@ -74,8 +74,8 @@ from torch.monitor import _WaitCounter
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from torch.utils._python_dispatch import (
|
||||
_disable_current_modes,
|
||||
is_in_any_mode_without_ignore_compile_internals,
|
||||
is_in_torch_dispatch_mode,
|
||||
any_torch_dispatch_mode_on_stack
|
||||
)
|
||||
from torch.utils._traceback import CapturedTraceback, format_traceback_short
|
||||
|
||||
@ -1940,10 +1940,6 @@ class ConvertFrameProtocol(typing.Protocol):
|
||||
) -> ConvertFrameReturn: ...
|
||||
|
||||
|
||||
def should_skip_due_to_torch_dispatch_mode() -> bool:
|
||||
return is_in_any_mode_without_ignore_compile_internals()
|
||||
|
||||
|
||||
class CatchErrorsWrapper:
|
||||
def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
|
||||
functools.wraps(callback)(self)
|
||||
@ -1964,13 +1960,15 @@ class CatchErrorsWrapper:
|
||||
has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
|
||||
else:
|
||||
has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
|
||||
|
||||
should_skip_due_to_torch_dispatch = any_torch_dispatch_mode_on_stack(include_infra_modes=False, respect_ignore_compile_internals=True)
|
||||
if (
|
||||
# TODO: the first condition is not covered by any test
|
||||
has_started_execution
|
||||
or is_skipfile
|
||||
or config.disable
|
||||
or (
|
||||
should_skip_due_to_torch_dispatch_mode()
|
||||
should_skip_due_to_torch_dispatch
|
||||
and not getattr(self._torchdynamo_orig_backend, "_export", False)
|
||||
)
|
||||
):
|
||||
@ -1979,7 +1977,7 @@ class CatchErrorsWrapper:
|
||||
skip_reason = "traced frame already"
|
||||
elif trace_rules.check(frame.f_code):
|
||||
skip_reason = "in skipfiles"
|
||||
elif is_in_torch_dispatch_mode(include_infra_modes=False):
|
||||
elif should_skip_due_to_torch_dispatch:
|
||||
skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
|
||||
else:
|
||||
skip_reason = "dynamo tracing is disabled"
|
||||
|
@ -48,6 +48,35 @@ def is_in_torch_dispatch_mode(include_infra_modes: bool = True) -> bool:
|
||||
def is_in_any_mode_without_ignore_compile_internals() -> bool:
|
||||
return _is_in_any_mode_without_ignore_compile_internals
|
||||
|
||||
def any_torch_dispatch_mode_on_stack(
|
||||
*,
|
||||
include_infra_modes=True,
|
||||
respect_ignore_compile_internals=False,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if there are any ambient (non-entered) modes on the stack.
|
||||
Returns True if there are modes that are on the stack but NOT entered.
|
||||
"""
|
||||
stack_len = torch._C._len_torch_dispatch_stack()
|
||||
|
||||
for idx in range(stack_len):
|
||||
mode = _get_dispatch_stack_at(idx)
|
||||
print("MODE", mode)
|
||||
|
||||
# Apply filters first
|
||||
if mode.is_infra_mode():
|
||||
if not include_infra_modes:
|
||||
continue
|
||||
if mode.ignore_compile_internals():
|
||||
print("HUUSH")
|
||||
if respect_ignore_compile_internals:
|
||||
continue
|
||||
print("GOT HERE")
|
||||
breakpoint()
|
||||
return True
|
||||
breakpoint()
|
||||
return False
|
||||
|
||||
|
||||
class TorchDispatchMode:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user