mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
TEMP
ghstack-source-id: 7b4ffc7c494c970ad780d8c5f89209c79a09a573
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165863
@@ -242,6 +242,67 @@ class MiscTests(torch._inductor.test_case.TestCase):
         self.assertTrue(same(val4, correct1))
         self.assertEqual(counter.frame_count, 3)
 
+    def test_compile_non_infra_inside_compile(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        backend = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        class YoloMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = torch.compile(func, backend=backend, fullgraph=True)(
+                    *args, **kwargs
+                )
+                return out
+
+        x = torch.randn(5)
+        with YoloMode():
+            out = torch.add(x, x)
+
+        self.assertEqual(len(backend.graphs), 1)
+
+    def test_compile_non_infra_empty(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        backend = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        class YoloMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                return torch.ops.aten.mul.Tensor(args[0], args[1])
+
+        x = torch.ones(5)
+        with YoloMode():
+            out = torch.compile(torch.add, backend=backend, fullgraph=True)(x, x)
+
+        self.assertEqual(out.sum().item(), 5.0)
+        self.assertEqual(len(backend.graphs), 0)
+
+    def test_compile_non_infra_multiple(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        backend3 = torch._dynamo.testing.EagerAndRecordGraphs()
+        backend2 = torch._dynamo.testing.EagerAndRecordGraphs()
+        backend = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        class YoloMode2(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = torch.compile(func, backend=backend3, fullgraph=True)(
+                    *args, **kwargs
+                )
+                return out
+
+        class YoloMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                out = torch.compile(torch.add, backend=backend2, fullgraph=True)(args[0], args[1])
+                return out
+
+        x = torch.ones(5)
+        with YoloMode(), YoloMode2():
+            torch.compile(lambda x, y: torch.add(x, y), fullgraph=True, backend=backend)(x, x)
+
+        self.assertEqual(len(backend2.graphs), 1)
+        self.assertEqual(len(backend3.graphs), 0)
+        self.assertEqual(len(backend.graphs), 0)
+
     def test_dynamo_inside_custom_op(self):
         cnt = torch._dynamo.testing.InductorAndRecordGraphs()
         cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()
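A minimal standalone sketch (not part of the diff; the mode name is illustrative) of the behavior these tests pin down: when a plain, non-infra TorchDispatchMode is active, Dynamo skips the frame and the call runs eagerly, so a recording backend sees no graphs.

import torch
import torch._dynamo.testing
from torch.utils._python_dispatch import TorchDispatchMode

backend = torch._dynamo.testing.EagerAndRecordGraphs()

class LoggingMode(TorchDispatchMode):  # hypothetical user mode
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Pass every op through unchanged.
        return func(*args, **(kwargs or {}))

x = torch.ones(5)
with LoggingMode():
    out = torch.compile(torch.add, backend=backend)(x, x)

# The frame was skipped because a non-infra dispatch mode is on the stack,
# so the add ran eagerly and no graph was recorded.
assert len(backend.graphs) == 0
assert out.sum().item() == 10.0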
@@ -74,8 +74,8 @@ from torch.monitor import _WaitCounter
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils._python_dispatch import (
     _disable_current_modes,
-    is_in_any_mode_without_ignore_compile_internals,
     is_in_torch_dispatch_mode,
+    any_torch_dispatch_mode_on_stack
 )
 from torch.utils._traceback import CapturedTraceback, format_traceback_short
 
@@ -1940,10 +1940,6 @@ class ConvertFrameProtocol(typing.Protocol):
     ) -> ConvertFrameReturn: ...
 
 
-def should_skip_due_to_torch_dispatch_mode() -> bool:
-    return is_in_any_mode_without_ignore_compile_internals()
-
-
 class CatchErrorsWrapper:
     def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None:
         functools.wraps(callback)(self)
@@ -1964,13 +1960,15 @@ class CatchErrorsWrapper:
             has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code)
         else:
             has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code)
+
+        should_skip_due_to_torch_dispatch = any_torch_dispatch_mode_on_stack(include_infra_modes=False, respect_ignore_compile_internals=True)
         if (
             # TODO: the first condition is not covered by any test
             has_started_execution
             or is_skipfile
             or config.disable
             or (
-                should_skip_due_to_torch_dispatch_mode()
+                should_skip_due_to_torch_dispatch
                 and not getattr(self._torchdynamo_orig_backend, "_export", False)
             )
         ):
@@ -1979,7 +1977,7 @@ class CatchErrorsWrapper:
                     skip_reason = "traced frame already"
                 elif trace_rules.check(frame.f_code):
                     skip_reason = "in skipfiles"
-                elif is_in_torch_dispatch_mode(include_infra_modes=False):
+                elif should_skip_due_to_torch_dispatch:
                     skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile"
                 else:
                     skip_reason = "dynamo tracing is disabled"
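An illustrative sketch (not the diff's code; the function name and parameters are made up) of the skip decision as it now reads: the dispatch-mode check is evaluated once per frame and only forces a skip when the backend is not an export backend.

def frame_should_be_skipped(
    has_started_execution: bool,
    is_skipfile: bool,
    disable: bool,
    non_infra_dispatch_mode_present: bool,
    is_export_backend: bool,
) -> bool:
    # Mirrors the condition in CatchErrorsWrapper.__call__ after this change.
    return (
        has_started_execution
        or is_skipfile
        or disable
        or (non_infra_dispatch_mode_present and not is_export_backend)
    )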
@@ -48,6 +48,31 @@ def is_in_torch_dispatch_mode(include_infra_modes: bool = True) -> bool:
 def is_in_any_mode_without_ignore_compile_internals() -> bool:
     return _is_in_any_mode_without_ignore_compile_internals
 
 
+def any_torch_dispatch_mode_on_stack(
+    *,
+    include_infra_modes=True,
+    respect_ignore_compile_internals=False,
+) -> bool:
+    """
+    Check whether any torch dispatch mode is currently on the stack,
+    optionally skipping infra modes and modes that ignore compile internals.
+    """
+    stack_len = torch._C._len_torch_dispatch_stack()
+
+    for idx in range(stack_len):
+        mode = _get_dispatch_stack_at(idx)
+
+        # Apply filters first
+        if mode.is_infra_mode():
+            if not include_infra_modes:
+                continue
+        if mode.ignore_compile_internals():
+            if respect_ignore_compile_internals:
+                continue
+        return True
+    return False
+
+
 class TorchDispatchMode:
     """
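A hedged usage sketch of the new helper (the mode below is illustrative, not from the diff): an empty stack reports False, while a plain user-defined mode is reported even when infra modes are excluded, matching the call site in CatchErrorsWrapper.

from torch.utils._python_dispatch import (
    TorchDispatchMode,
    any_torch_dispatch_mode_on_stack,
)

class NoopMode(TorchDispatchMode):  # illustrative user mode
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

assert not any_torch_dispatch_mode_on_stack(include_infra_modes=False)
with NoopMode():
    # A non-infra mode is now on the stack, so compile paths would skip.
    assert any_torch_dispatch_mode_on_stack(
        include_infra_modes=False, respect_ignore_compile_internals=True
    )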