Files
pytorch/test/dynamo/test_exc.py
Sidharth aeaf6b59e2 [dynamo] Weblink generation when unimplemented_v2() is called (#156033)
This PR includes the GBID weblink whenever a user encounters a graph break. I also had to include the JSON file in setup.py, so it can be part of the files that are packaged in during CI. It also fixes the issue of the hardcoded error messages stripping away one of the '/' in 'https'.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156033
Approved by: https://github.com/williamwen42
2025-06-22 11:39:31 +00:00

363 lines
10 KiB
Python

# Owner(s): ["module: dynamo"]
import logging
import unittest
import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import (
IS_FBCODE,
munge_exc,
skipIfWindows,
TEST_Z3,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
class ExcTests(LoggingTestCase):
maxDiff = None
def test_unsupported_real_stack(self):
# exercise Unsupported constructor and augment_exc_message
def fn002(x):
torch._dynamo.graph_break()
def fn001(x):
x = x + 1
fn002(x)
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
torch.randn(1)
),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://compile-graph-break-site.vercel.app/gb/GB0025
from user code:
File "test_exc.py", line N, in fn001
fn002(x)
File "test_exc.py", line N, in fn002
torch._dynamo.graph_break()""",
)
@torch._dynamo.config.patch(verbose=True, suppress_errors=True)
@make_logging_test()
@unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
def test_internal_error_suppress_errors(self, records):
def fn001(x):
def f(ctx):
raise AssertionError
comptime(f)
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "WON'T CONVERT")
self.assertExpectedInline(
munge_exc(record.getMessage()),
"""\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
File "test_exc.py", line N, in f
raise AssertionError
AssertionError:
from user code:
File "test_exc.py", line N, in fn001
comptime(f)
========== The above exception occurred while processing the following code ==========
File "test_exc.py", line N, in test_internal_error_suppress_errors
torch.compile(fn001, backend="eager")(torch.randn(1))
File "test_exc.py", line N, in fn001
comptime(f)
==========""",
)
@make_logging_test()
def test_not_implemented_error(self, records):
def fn001(x):
def f(ctx):
raise NotImplementedError
# Ensure graph break is not possible
for _ in range(3):
comptime(f)
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "WON'T CONVERT")
self.assertExpectedInline(
munge_exc(record.getMessage()),
"""\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
File "test_exc.py", line N, in f
raise NotImplementedError
torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError:
from user code:
File "test_exc.py", line N, in fn001
comptime(f)""",
)
@torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
@make_logging_test(dynamo=logging.DEBUG)
def test_unsupported_error(self, records):
def fn001(x):
return {1, 2}
torch.compile(fn001, backend="eager")(torch.randn(1))
# TODO: There is no graph break log! This is because the graph break
# logging is not in a centralized location; unsupported
# instruction bypasses it
self.getRecord(records, "Graph break:")
@torch._dynamo.config.patch(suppress_errors=False)
def test_internal_error_no_suppress(self):
def fn001(x):
# NB: avoid decorator, as 3.11 changed the line number attributed
# in this situation
def f(ctx):
raise AssertionError
comptime(f)
# NB: OK for user code to be truncated here, because the regular
# exception backtrace has the rest of the crumbs
self.assertExpectedInlineMunged(
AssertionError,
lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
"""\
from user code:
File "test_exc.py", line N, in fn001
comptime(f)""",
)
@make_logging_test(graph_breaks=True)
def test_graph_break_log(self, records):
def fn002(x):
x = x + 1
torch._dynamo.graph_break()
x = x + 1
return x
def fn001(x):
return fn002(x)
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "Graph break in user code")
# TODO: This should also report the enclosing frames; need to plumb
# frame object to it
self.assertExpectedInline(
munge_exc(record.getMessage()),
"""\
Graph break in user code at test_exc.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://compile-graph-break-site.vercel.app/gb/GB0025
User code traceback:
File "test_exc.py", line N, in test_graph_break_log
torch.compile(fn001, backend="eager")(torch.randn(1))
File "test_exc.py", line N, in fn001
return fn002(x)
File "test_exc.py", line N, in fn002
torch._dynamo.graph_break()
""", # noqa: B950
)
@make_logging_test(graph_breaks=True)
def test_graph_break_log_generic_jump(self, records):
def fn(x):
if x.sum() > 0:
return x + 1
else:
return x - 1
torch.compile(fn, backend="eager")(torch.ones(3, 3))
# check for record existence
self.getRecord(records, "Graph break in user code")
@torch._dynamo.config.patch(suppress_errors=False)
def test_backend_suppress_line(self):
def fn001(x):
x = torch.relu(x)
return x + 1
# Do NOT let this get attributed to x + 1
self.assertExpectedInlineMunged(
torch._dynamo.exc.BackendCompilerFailed,
lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
torch.randn(1)
),
"""\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
)
@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
assume_static_by_default=False,
suppress_errors=False,
)
@torch.fx.experimental._config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
translation_validation=True,
translation_validation_no_bisect=True,
)
@skipIfWindows(
msg='AssertionError: "tran[551 chars]s1 s2 s3) s0)\n ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])' # noqa: PLR0133
!= 'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\n ==> (<= (+ s1 s2) [483 chars][0])"'
)
def test_trigger_on_error(self):
from torch.fx.experimental.validator import ValidationException
@torch.compile
def fn(x, shape):
return x.split(shape)
self.assertExpectedInlineMunged(
ValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed.
Model:
==> L['shape'][0]: 0
==> L['shape'][1]: 1
==> L['shape'][2]: 1
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s3: 1
==> s52: 1
==> s77: 3
==> s86: 0
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 1)
Target Expressions:
==> (!= (+ s3 s52 s86) s77)
==> (<= 0 s3)
==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 0)
==> (>= 0 s86)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
)
@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
assume_static_by_default=False,
suppress_errors=False,
)
@torch.fx.experimental._config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
translation_validation=True,
)
def test_trigger_bisect_on_error(self):
from torch.fx.experimental.validator import BisectValidationException
@torch.compile
def fn(x, shape):
return x.split(shape)
self.assertExpectedInlineMunged(
BisectValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed when evaluating: Eq(s3 + s52 + s86, s77)
Failure occurred while running node:
%split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
==> L['shape'][2]: 0
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s3: 0
==> s52: 1
==> s77: 3
==> s86: 1
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 1)
Target Expressions:
==> (!= (+ s3 s52 s86) s77)
==> (<= 0 s3)
==> (<= 0 s52)
==> (<= 0 s86)
==> (<= 2 s77)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s86)
==> (== L['shape'][1] s52)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s77)
==> (> s77 0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()