[dynamo] Add recompile reason for set_stance fail_on_recompile (#165445)
Fixes #163500

### Summary
With `set_stance("fail_on_recompile")`, the failure now reports why the recompilation occurred (the guard failure(s) that triggered it).

### Impacts
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165445
Approved by: https://github.com/williamwen42
Committed by: PyTorch MergeBot
Parent: a88587348b
Commit: 8139f33fa5
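The new test below exercises this end to end. For a quick standalone illustration, here is a minimal sketch mirroring that test (the function name `fn` and the tensor sizes are only for the example, not part of the diff):

```python
import torch

@torch.compile(backend="eager", dynamic=False)
def fn(x):
    return x + 1

fn(torch.ones(4))  # first compilation; guards specialize on size 4
fn(torch.ones(5))  # recompilation; a second cache entry specializes on size 5

# Under the 'fail_on_recompile' stance, an input that matches no existing cache
# entry raises a RuntimeError instead of recompiling. With this change, the
# error message also lists the guard failure(s), e.g.
# "tensor 'x' size mismatch at index 0. expected 4, actual 7".
with torch.compiler.set_stance("fail_on_recompile"):
    fn(torch.ones(4))      # matches the first cache entry, runs without recompiling
    try:
        fn(torch.ones(7))  # would require a recompile
    except RuntimeError as e:
        print(e)
```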
@@ -2,6 +2,7 @@
 import functools
 import operator
 import os
+import re
 import unittest.mock as mock
 from unittest.mock import patch
@@ -1472,6 +1473,30 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(out1, inp + 2)
         self.assertEqual(out2, inp + 2)
 
+    def test_fail_on_recompile_shows_guard_details(self):
+        @torch.compile(backend="eager", dynamic=False)
+        def f(x):
+            return x + 1
+
+        f(torch.ones(4))
+        f(torch.ones(5))
+
+        def post_munge(s):
+            return re.sub(r"line number: \d+", "line number: N", s)
+
+        with torch.compiler.set_stance("fail_on_recompile"):
+            f(torch.ones(4))
+            self.assertExpectedInlineMunged(
+                RuntimeError,
+                lambda: f(torch.ones(7)),
+                """\
+Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: 'test_decorators.py', function name: 'f', line number: N
+    triggered by the following guard failure(s):
+    - 0/0: tensor 'x' size mismatch at index 0. expected 4, actual 7
+    - 0/1: tensor 'x' size mismatch at index 0. expected 5, actual 7""",  # noqa: B950
+                post_munge=post_munge,
+            )
+
     def test_set_stance_fail_on_recompile_with_disable(self):
         @torch.compiler.disable
         def inner(x):
@@ -235,7 +235,11 @@ def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback:
             if not convert_frame.has_tensor_in_frame(frame):
                 return ConvertFrameReturn()
 
-            from torch._C._dynamo.eval_frame import _debug_get_precompile_entries
+            from torch._C._dynamo.eval_frame import (
+                _debug_get_cache_entry_list,
+                _debug_get_precompile_entries,
+            )
+            from torch._dynamo.guards import get_and_maybe_log_recompilation_reasons
 
             message = (
                 "Detected recompile when torch.compile stance is 'fail_on_recompile'. "
@@ -243,6 +247,17 @@ def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback:
                 + f"function name: '{frame.f_code.co_name}', "
                 + f"line number: {frame.f_lineno}"
             )
+            cache_entries = _debug_get_cache_entry_list(frame.f_code)
+            if cache_entries:
+                reasons = get_and_maybe_log_recompilation_reasons(
+                    cache_entries[0], frame, skip_logging=True
+                )
+                if reasons:
+                    failures = textwrap.indent("\n".join(reasons), "- ")
+                    guard_failure_details = (
+                        f"triggered by the following guard failure(s):\n{failures}"
+                    )
+                    message += f"\n{textwrap.indent(guard_failure_details, '    ')}"
             precompile_entries = _debug_get_precompile_entries(frame.f_code)
             if len(precompile_entries) > 0:
                 message += "\nFailed on the following precompiled guards: "