[dynamo] improved graph break messages for some common graph break sites [1/N] (#146525)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146525
Approved by: https://github.com/jansel
William Wen
2025-02-19 11:46:13 -08:00
committed by PyTorch MergeBot
parent 1e94c7aaa4
commit 16e202a38e
18 changed files with 841 additions and 166 deletions

View File

@@ -38,6 +38,7 @@ per-file-ignores =
torchgen/api/types/__init__.py: F401,F403
torchgen/executorch/api/types/__init__.py: F401,F403
test/dynamo/test_higher_order_ops.py: B950
test/dynamo/test_graph_break_messages.py: B950
torch/testing/_internal/dynamo_test_failures.py: B950
# TOR901 is only for test, we want to ignore it for everything else.
# It's not easy to configure this without affecting other per-file-ignores,

View File

@@ -1251,7 +1251,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
x = torch.randn(4, 4).to(device)
opt_fn = torch.compile(fn, fullgraph=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "skip function graph_break in file"
torch._dynamo.exc.Unsupported, "User-inserted graph break"
):
opt_fn(x)

View File

@@ -271,7 +271,10 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
model = CustomFuncBwdPrintModule()
opt_model = torch.compile(model, backend="eager", fullgraph=True)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Dynamo does not know how to trace builtin operator `print`",
):
opt_model(x)
def test_stride_in_bwd(self):

View File

@@ -330,7 +330,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
fn3(torch.randn(4, 5))
self.assertFalse(True)
except torch._dynamo.exc.Unsupported as e:
self.assertIn("call torch._dynamo.disable() wrapped function", str(e))
self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e))
def test_disable_optimize(self):
cnt = torch._dynamo.testing.CompileCounter()

View File

@@ -37,7 +37,12 @@ class ExcTests(LoggingTestCase):
torch.randn(1)
),
"""\
'skip function graph_break in file _dynamo/decorators.py'
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
from user code:
File "test_exc.py", line N, in fn001
@@ -171,7 +176,12 @@ from user code:
munge_exc(record.getMessage()),
"""\
Graph break in user code at test_exc.py:N
Reason: Unsupported: 'skip function graph_break in file _dynamo/decorators.py'
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
User code traceback:
File "test_exc.py", line N, in fn001
return fn002(x)

View File

@@ -0,0 +1,505 @@
# Owner(s): ["module: dynamo"]
import re
import unittest
import warnings
import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch.utils._pytree as python_pytree
from torch._dynamo.exc import Unsupported
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_FBCODE, scoped_load_inline
"""
NOTE: Adding tests to this file:
It is good practice to add a minimal repro for each graph break site (i.e. each `unimplemented()` call)
to make sure that no errors occur when generating graph break messages.
If a graph break message test fails because the graph break no longer repros,
it is good practice to find a new minimal repro that causes the graph break.
If this is too much work, it is likely safe to skip/remove the test, assuming
it was previously passing and the graph break message has not changed.
However, if you add a new graph break or modify a graph break message, you should
make sure that there is a test for it.
"""
class GraphBreakMessagesTest(torch._dynamo.test_case.TestCase):
maxDiff = None
def test_dynamic_shape_operator(self):
def fn():
return torch.nonzero(torch.rand([10, 10]))
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Dynamic shape operator
Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.
Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
Developer debug context: aten.nonzero.default
from user code:
File "test_graph_break_messages.py", line N, in fn
return torch.nonzero(torch.rand([10, 10]))""",
)
def test_dynamic_shape_operator_no_meta_kernel(self):
def fn():
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))
with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Dynamic shape operator (no meta kernel)
Explanation: Operator `aten.linalg_lstsq.default` does not have a meta kernel that supports dynamic output shapes
Hint: Please report an issue to PyTorch
Developer debug context: aten.linalg_lstsq.default
from user code:
File "test_graph_break_messages.py", line N, in fn
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""",
)
def test_data_dependent_operator(self):
def fn(x):
return x.item()
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
torch.Tensor([1])
),
"""\
Tensor.item
from user code:
File "test_graph_break_messages.py", line N, in fn
return x.item()""",
)
def test_data_dependent_operator2(self):
def fn(x):
return torch.equal(x, x)
with torch._dynamo.config.patch(capture_scalar_outputs=True):
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
torch.ones(3)
),
"""\
Data dependent operator
Explanation: Operator `aten.equal.default` has a non-Tensor output whose value is dependent on the data of Tensor inputs.
Hint: Consider wrapping the operator into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
Developer debug context: aten.equal.default
from user code:
File "test_graph_break_messages.py", line N, in fn
return torch.equal(x, x)""",
)
def test_super_call_method(self):
def fn(it):
return [x + 1 for x in it]
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
zip(range(5), range(10))
),
"""\
Unsupported method call
Explanation: Dynamo does not know how to trace method `__iter__` of class `zip`
Hint: Avoid calling `zip.__iter__` in your code.
Hint: Please report an issue to PyTorch.
Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope.
Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ () {}
from user code:
File "test_graph_break_messages.py", line N, in fn
return [x + 1 for x in it]""",
)
def test_super_call_function(self):
def fn(it):
return [x + 1 for x in it()]
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
zip(range(5), range(10))
),
"""\
Unsupported function call
Explanation: Dynamo does not know how to trace the function `UserDefinedObjectVariable(zip)`
Hint: Avoid calling `UserDefinedObjectVariable(zip)` in your code.
Hint: Please report an issue to PyTorch.
Developer debug context: call_function UserDefinedObjectVariable(zip) [] {}
from user code:
File "test_graph_break_messages.py", line N, in fn
return [x + 1 for x in it()]""",
)
def test_unsupported_context(self):
def fn(obj):
with obj:
return 1
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(3),
"""\
Unsupported context manager
Explanation: Dynamo does not know how to enter a `int` context manager.
Hint: Avoid using the unsupported context manager.
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3)
from user code:
File "test_graph_break_messages.py", line N, in fn
with obj:""",
)
def test_backend_fake_tensor_exc(self):
def bad_backend(gm, ex):
raise torch._subclasses.fake_tensor.UnsupportedFakeTensorException("test")
def fn(x):
return x + 1
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend=bad_backend, fullgraph=True)(
torch.ones(3, 3)
),
"""\
Backend compiler exception
Explanation: Backend compiler `bad_backend` failed with test. Adding a graph break.
Hint: Report an issue to the backend compiler repo.
Developer debug context: Backend: bad_backend
Exception: test
Traceback:
File "test_graph_break_messages.py", line N, in fn
return x + 1""",
)
def test_unsupported_builtin(self):
def fn():
print("abc")
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Failed to trace builtin operator
Explanation: Dynamo does not know how to trace builtin operator `print` with argument types ['str'] (has_kwargs False)
Hint: Avoid calling builtin `print` with argument types ['str']. Consider using an equivalent alternative function/method to `print`.
Hint: If you are attempting to call a logging function (e.g. `print`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.
Hint: Please report an issue to PyTorch.
Developer debug context: builtin print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
from user code:
File "test_graph_break_messages.py", line N, in fn
print("abc")""",
)
def test_skipfile_call(self):
def fn():
return unittest.skip("test")
def post_munge(s):
return re.sub(r"file `.*case\.py`", "file `case.py`", s)
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `skip` in file `case.py` should not be traced.
Hint: Avoid calling the function `skip`.
Hint: Remove the function `skip` or the file `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: module: unittest.case, qualname: skip, skip reason: <missing reason>
from user code:
File "test_graph_break_messages.py", line N, in fn
return unittest.skip("test")""",
post_munge=post_munge,
)
def test_skipfile_dynamo_call(self):
def fn():
torch._dynamo.disable()
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `disable` in file `_dynamo/decorators.py` should not be traced.
Hint: Avoid calling the function `disable`.
Developer debug context: module: torch._dynamo.decorators, qualname: disable, skip reason: <missing reason>
from user code:
File "test_graph_break_messages.py", line N, in fn
torch._dynamo.disable()""",
)
def test_skipfile_inline(self):
class Foo:
fn = unittest.skip
def fn():
Foo().fn()
def post_munge(s):
return re.sub(r"`.*case\.py`", "`case.py`", s)
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Attempted to inline function marked as skipped
Explanation: Dynamo developers have intentionally marked that the function `skip` should not be traced.
Hint: Avoid calling the function `skip`.
Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
Hint: Please file an issue to PyTorch.
Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup SKIP_DIRS
from user code:
File "test_graph_break_messages.py", line N, in fn
Foo().fn()""",
post_munge=post_munge,
)
def test_disable(self):
@torch.compiler.disable
def inner():
return 1
def fn():
return inner()
def post_munge(s):
return re.sub(
r"<function GraphBreakMessagesTest\.test_disable\.<locals>\.inner at 0x[0-9A-Fa-f]+>",
"<function inner>",
s,
)
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Skip calling `torch.compiler.disable()`d function
Explanation: Skip calling function `<function inner>` since it was wrapped with `torch.compiler.disable`
Hint: Remove the `torch.compiler.disable` call
Developer debug context: <function inner>
from user code:
File "test_graph_break_messages.py", line N, in fn
return inner()""",
post_munge=post_munge,
)
def test_dynamo_graph_break_fn(self):
def fn():
torch._dynamo.graph_break()
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
from user code:
File "test_graph_break_messages.py", line N, in fn
torch._dynamo.graph_break()""",
)
def test_dynamo_graph_break_fn_with_msg(self):
def fn():
torch._dynamo.graph_break(msg="test graph break")
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: test graph break
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{'msg': ConstantVariable(str: 'test graph break')}`
from user code:
File "test_graph_break_messages.py", line N, in fn
torch._dynamo.graph_break(msg="test graph break")""",
)
def test_warnings(self):
def fn():
warnings.warn("test")
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo does not know how to trace the Python builtin `_warnings.warn`.
Hint: If you are attempting to call a logging function (e.g. `_warnings.warn`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.
Hint: Please file an issue on GitHub so the PyTorch team can add support for it.
Developer debug context: module: _warnings, qualname: warn, skip reason: <missing reason>
from user code:
File "test_graph_break_messages.py", line N, in fn
warnings.warn("test")""",
)
@unittest.skipIf(not python_pytree._cxx_pytree_exists, "missing optree package")
def test_optree_graph_break_message(self):
import optree
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
return torch.sin(x)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
""",
)
@scoped_load_inline
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
@unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode")
def test_cpp_extension_recommends_custom_ops(self, load_inline):
cpp_source = """
#include <torch/extension.h>
at::Tensor foobar(const at::Tensor& x) {
return x.clone();
}
"""
module = load_inline(
name="mylib",
cpp_sources=cpp_source,
functions="foobar",
verbose=True,
)
x = torch.ones(2, 2, requires_grad=True)
counters.clear()
@torch.compile(backend="eager")
def f(x):
return module.foobar(x)
with self.assertWarnsOnceRegex(
UserWarning,
"(?s).*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
):
f(x)
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
first_graph_break = re.sub(r"mylib(_v\d+)?", "mylib", first_graph_break)
self.assertExpectedInline(
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo does not know how to trace the builtin `mylib.PyCapsule.foobar`. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: <missing reason>
""",
)
cpp_source = """
#include <torch/extension.h>
at::Tensor baz(const at::Tensor& x) {
return x.clone();
}
"""
module2 = load_inline(
name="mylib2",
cpp_sources=cpp_source,
functions="baz",
verbose=True,
)
torch._dynamo.reset()
# Test that each warning only happens once
@torch.compile(backend="eager")
def f(x):
module2.baz(x)
module.foobar(x)
module.foobar(x)
module2.baz(x)
module.foobar(x)
module2.baz(x)
return x.clone()
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
f(x)
f(x)
self.assertEqual(len(ws), 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@@ -512,7 +512,9 @@ class HooksTests(torch._dynamo.test_case.TestCase):
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd._enable(compiler_fn):
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "Failed to trace builtin operator"
):
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(obj.count, 2)

View File

@@ -293,24 +293,6 @@ class MiscTests(torch._inductor.test_case.TestCase):
with self.assertRaises(TypeError):
fn(torch.randn(16))
@unittest.skipIf(not python_pytree._cxx_pytree_exists, "missing optree package")
def test_optree_graph_break_message(self):
import optree
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
return torch.sin(x)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = list(counters["graph_break"].keys())[0]
self.assertExpectedInline(
first_graph_break,
"Graph break for an optree C/C++ function optree._C.PyCapsule.flatten. Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py",
)
def test_scalar_device_movement(self):
if not torch._dynamo.config.assume_static_by_default:
self.skipTest("Doesn't work with symints")
@@ -324,74 +306,6 @@ class MiscTests(torch._inductor.test_case.TestCase):
res_compiled = add_fn(2, 3, torch.tensor(0.0))
self.assertEqual(res, res_compiled)
@scoped_load_inline
@skipIfNNModuleInlined("fails internal CI")
@unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode")
def test_cpp_extension_recommends_custom_ops(self, load_inline):
cpp_source = """
#include <torch/extension.h>
at::Tensor foobar(const at::Tensor& x) {
return x.clone();
}
"""
module = load_inline(
name="mylib",
cpp_sources=cpp_source,
functions="foobar",
verbose=True,
)
x = torch.ones(2, 2, requires_grad=True)
counters.clear()
@torch.compile(backend="eager")
def f(x):
return module.foobar(x)
with self.assertWarnsOnceRegex(
UserWarning,
".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
):
f(x)
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = list(counters["graph_break"].keys())[0]
self.assertExpectedInline(
first_graph_break,
"""Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
)
cpp_source = """
#include <torch/extension.h>
at::Tensor baz(const at::Tensor& x) {
return x.clone();
}
"""
module2 = load_inline(
name="mylib2",
cpp_sources=cpp_source,
functions="baz",
verbose=True,
)
torch._dynamo.reset()
# Test that each warning only happens once
@torch.compile(backend="eager")
def f(x):
module2.baz(x)
module.foobar(x)
module.foobar(x)
module2.baz(x)
module.foobar(x)
module2.baz(x)
return x.clone()
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
f(x)
f(x)
self.assertEqual(len(ws), 2)
def test_callpacked(self):
def call_packed(args):
a, b, c = args
@@ -8967,7 +8881,7 @@ def ___make_guard_fn():
# and so the guard story for the objects passed into input just isn't there atm.
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"^call_method UserDefinedObjectVariable\\(set\\).*",
"Unsupported method call",
):
foo(inp)
@@ -10643,7 +10557,7 @@ ShapeEnv not equal: field values don't match:
# Should only be one restart per event
(restart_reason,) = metrics[0].restart_reasons
self.assertTrue(
"skip function graph_break" in restart_reason,
"User-inserted graph break" in restart_reason,
"Should have logged graph break reason",
)
self.assertTrue(
@@ -10653,7 +10567,7 @@ ShapeEnv not equal: field values don't match:
(restart_reason,) = metrics[1].restart_reasons
self.assertTrue(
"skip function graph_break" in restart_reason,
"User-inserted graph break" in restart_reason,
"Should have logged graph break reason",
)
self.assertTrue(

View File

@@ -1772,8 +1772,12 @@ def forward(self, x_1):
self.assertExpectedInline(
next(iter(counters["graph_break"].keys())).replace(";", "\n"),
"""\
dynamic shape operator: _torch_testing.numpy_nonzero.default
to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
Dynamic shape operator
Explanation: Operator `_torch_testing.numpy_nonzero.default`'s output shape depends on input Tensor data.
Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
Developer debug context: _torch_testing.numpy_nonzero.default
""",
)
# pre-existing problem: torch.compile(dynamic=True) will, by default,

View File

@@ -206,7 +206,7 @@ def disallow_in_graph(fn):
@_disallow_in_graph_helper(throw_if_not_allowed=False)
def graph_break():
def graph_break(msg=""):
"""Force a graph break"""

View File

@@ -441,6 +441,69 @@ def unimplemented(
raise Unsupported(msg, case_name=case_name)
def unimplemented_v2_with_warning(
e: Exception,
code: types.CodeType,
gb_type: str,
context: str,
explanation: str,
hints: list[str],
) -> NoReturn:
# This function calls unimplemented_v2 internally and eventually graph breaks
# or falls back to eager. unimplemented_v2 itself does not print any user warnings,
# i.e., it is very silent. This helper function is intended for errors
# encountered in the torch.compile stack that are worth surfacing as a warning
# to the user. For example, if the AOT Autograd backend fails with a fake tensor
# exception, it is OK to fall back to eager, but not silently. Here, we can use
# this function to log the message and the stack trace.
graph_break_msg = format_error_msg_verbose(e, code)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "dynamo_graph_break_reason",
"encoding": "string",
},
payload_fn=lambda: graph_break_msg,
)
graph_breaks_log.debug("%s", graph_break_msg)
unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True)
# TODO replace old unimplemented later
def unimplemented_v2(
gb_type: str,
context: str,
explanation: str,
hints: list[str],
*,
from_exc: Any = _NOTHING,
log_warning: bool = False,
) -> NoReturn:
"""
Called within dynamo to cause a graph break.
Args:
gb_type: Context-free graph break type. It should be a short string without any
information specific to the tracing context (i.e. no dynamically-generated strings).
context: Developer context for the graph break. It can contain tracing context/dynamic strings.
explanation: User-facing context-dependent explanation for the graph break. Can be dynamic.
hints: List of user-facing hints for the graph break.
"""
hints_str = "\n".join(hints)
hints_str = textwrap.indent(hints_str, " Hint: ")
msg = f"""\
{gb_type}
Explanation: {explanation}
{hints_str}
Developer debug context: {context}
"""
if log_warning:
log.warning(msg)
if from_exc is not _NOTHING:
raise Unsupported(msg) from from_exc
raise Unsupported(msg)
def warning(msg: str) -> None:
counters["warnings"][msg] += 1
assert msg != os.environ.get("BREAK", False)
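As a usage sketch (the call-site values below are invented, and this assumes a PyTorch build containing this change), a graph break site raises through `unimplemented_v2` and the resulting `Unsupported` message carries the four labeled parts in order:

from torch._dynamo.exc import unimplemented_v2, Unsupported

try:
    unimplemented_v2(
        gb_type="Example graph break",  # short, context-free type
        context="call_function foo [] {}",  # developer-facing, may contain dynamic strings
        explanation="Dynamo does not know how to trace `foo`.",
        hints=["Avoid calling `foo`."],
    )
except Unsupported as e:
    # str(e) contains the gb_type line, the Explanation line, one indented
    # "Hint:" line per hint, and the "Developer debug context:" line.
    print(e)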

View File

@@ -82,7 +82,7 @@ from .exc import (
exceptions_allowed_to_be_fallback,
SkipFrame,
unimplemented,
unimplemented_with_warning,
unimplemented_v2_with_warning,
)
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
@@ -652,9 +652,11 @@ class OutputGraph:
"""
global_state = cast(
dict[str, tuple[Callable[..., Any], bool]],
out
if out is not None
else self.tracing_context.global_context.global_state,
(
out
if out is not None
else self.tracing_context.global_context.global_state
),
)
# TODO - Consider having a torch level API for torch_function_state. As
@@ -1475,12 +1477,12 @@ class OutputGraph:
gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment]
name = (
self.compiler_fn.__name__
if hasattr(self.compiler_fn, "__name__")
else "<unknown compiler_fn>"
)
try:
name = (
self.compiler_fn.__name__
if hasattr(self.compiler_fn, "__name__")
else ""
)
_step_logger()(logging.INFO, f"calling compiler function {name}")
compiler_fn = self.compiler_fn
if config.verify_correctness:
@@ -1495,12 +1497,16 @@ class OutputGraph:
raise BackendCompilerFailed(
self.compiler_fn, e, inspect.currentframe()
).with_traceback(e.__traceback__) from None
msg = (
f"Backend compiler failed with {str(e)} at \n"
f"{self.root_tx.format_frame_summary()}"
"Adding a graph break."
unimplemented_v2_with_warning(
e,
self.root_tx.f_code,
gb_type="Backend compiler exception",
context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.",
hints=[
"Report an issue to the backend compiler repo.",
],
)
unimplemented_with_warning(e, self.root_tx.f_code, msg)
except SkipFrame as e:
# The backend compiler has requested that we skip the frame, instead of
# aborting execution.

View File

@@ -74,7 +74,13 @@ from .bytecode_transformation import (
)
from .code_context import code_context
from .codegen import PyCodegen
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .exc import (
ArgsMismatchError,
BackendCompilerFailed,
unimplemented,
unimplemented_v2,
Unsupported,
)
from .funcname_cache import get_funcname
from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph
@@ -488,7 +494,7 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
user_stack_formatted = "".join(traceback.format_list(user_stack))
user_stack_trace = (
"Graph break in user code at %s:%s\nReason: %s\nUser code traceback:\n%s" # noqa: UP031
"Graph break in user code at %s:%s\nGraph Break Reason: %s\nUser code traceback:\n%s" # noqa: UP031
% (
frame_loc[0],
frame_loc[1],
@@ -523,7 +529,7 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
# exercised by
# python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
graph_break_log.debug(
"Graph break (details suppressed) in user code at %s:%s\nReason: %s",
"Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s",
frame_loc[0],
frame_loc[1],
reason,
@@ -762,7 +768,7 @@ def break_graph_if_unsupported(*, push):
log_graph_break(
self.code_options,
exc_info=True,
reason=f"Unsupported: {excp}",
reason=str(excp),
user_stack=excp.real_stack,
)
@@ -2546,7 +2552,16 @@ class InstructionTranslatorBase(
if not isinstance(
ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
):
unimplemented(f"{inst.opname} {ctx}")
unimplemented_v2(
gb_type="Unsupported context manager",
context=f"Attempted SETUP_WITH/BEFORE_WITH on {ctx}",
explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.",
hints=[
"Avoid using the unsupported context manager.",
"File an issue to PyTorch. Simple context managers can potentially be supported, "
"but note that context managers can't be supported in general",
],
)
if (
isinstance(ctx, GenericContextWrappingVariable)
@@ -3289,15 +3304,36 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
False, "allowlist in dynamo known function"
)
fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
unimplemented(
f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'"
hints = [
f"Avoid calling the function `{fn_qualname}`.",
]
if "_dynamo" not in func.get_filename():
hints += [
f"Remove the function `{fn_qualname}` or the file `{func.get_filename()}` "
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
"attempting to trace into the function.",
"Please file an issue to PyTorch.",
# TODO suggest mark_force_inline when implemented
]
unimplemented_v2(
gb_type="Attempted to inline function marked as skipped",
context=f"qualname: {fn_qualname}, name: {func.get_name()}, "
f"filename: `{func.get_filename()}`, skip reason: {result.reason}",
explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` "
"should not be traced.",
hints=hints,
)
if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
func.get_function(), "_torchdynamo_disable", False
):
unimplemented(
f"call torch._dynamo.disable() wrapped function {func.get_function()}"
unimplemented_v2(
gb_type="Skip inlining `torch.compiler.disable()`d function",
context=str(func.get_function()),
explanation=f"Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable`",
hints=[
"Remove the `torch.compiler.disable` call",
],
)
else:
return result

View File

@@ -3004,6 +3004,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
from .exc import (
TorchRuntimeError,
unimplemented,
unimplemented_v2,
Unsupported,
UserError,
UserErrorType,
@@ -3063,23 +3064,50 @@
if isinstance(
cause, torch._subclasses.fake_tensor.DataDependentOutputException
):
unimplemented(
f"data dependent operator: {cause.func}; "
"to enable, set torch._dynamo.config.capture_scalar_outputs = True"
# capture_scalar_outputs only works for these ops right now
# see torch/_subclasses/fake_impls.py
if cause.func in (
torch.ops.aten.item.default,
torch.ops.aten._local_scalar_dense.default,
):
# does this actually get triggered?
hints = [
"Enable tracing of data-dependent output operators with "
"`torch._dynamo.config.capture_scalar_outputs = True`",
]
else:
hints = [
"Consider wrapping the operator into a PyTorch-understood custom operator "
"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)",
]
unimplemented_v2(
gb_type="Data dependent operator",
context=str(cause.func),
explanation=f"Operator `{cause.func}` has a non-Tensor output "
"whose value is dependent on the data of Tensor inputs.",
hints=hints,
)
elif isinstance(
cause, torch._subclasses.fake_tensor.DynamicOutputShapeException
):
if not torch._dynamo.config.capture_dynamic_output_shape_ops:
unimplemented(
f"dynamic shape operator: {cause.func}; "
"to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True"
unimplemented_v2(
gb_type="Dynamic shape operator",
context=str(cause.func),
explanation=f"Operator `{cause.func}`'s output shape depends on input Tensor data.",
hints=[
"Enable tracing of dynamic shape operators with "
"`torch._dynamo.config.capture_dynamic_output_shape_ops = True`",
],
)
else:
unimplemented(
f"dynamic shape operator: {cause.func}; "
"Operator does not have a meta kernel that supports dynamic output shapes, "
"please report an issue to PyTorch"
unimplemented_v2(
gb_type="Dynamic shape operator (no meta kernel)",
context=str(cause.func),
explanation=f"Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes",
hints=[
"Please report an issue to PyTorch",
],
)
elif isinstance(
cause, torch._subclasses.fake_tensor.UnsupportedOperatorException

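Both config escape hatches named in these hints can be exercised directly; a minimal sketch (eager backend; shapes and values are arbitrary):

import torch

def fn(x):
    return torch.nonzero(x)  # output shape depends on tensor data

def g(x):
    return x.item()  # non-Tensor output depends on tensor data

with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
    torch.compile(fn, backend="eager", fullgraph=True)(torch.rand(10, 10))

with torch._dynamo.config.patch(capture_scalar_outputs=True):
    torch.compile(g, backend="eager", fullgraph=True)(torch.tensor(1.0))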
View File

@@ -21,7 +21,7 @@ from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
from .. import variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented
from ..exc import unimplemented, unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, Source
from ..utils import cmp_name_to_op_mapping, istype
@@ -310,6 +310,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
except NotImplementedError:
raise NotImplementedError(f"{self} has no type") from None
def python_type_name(self):
try:
return self.python_type().__name__
except NotImplementedError:
return "<unknown type>"
def as_python_constant(self):
"""For constants"""
raise NotImplementedError(f"{self} is not a constant")
@@ -408,7 +414,15 @@
args: Sequence["VariableTracker"],
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
unimplemented(f"call_function {self} {args} {kwargs}")
unimplemented_v2(
gb_type="Unsupported function call",
context=f"call_function {self} {args} {kwargs}",
explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`",
hints=[
f"Avoid calling `{self.debug_repr()}` in your code.",
"Please report an issue to PyTorch.",
],
)
def call_method(
self,
@@ -450,7 +464,27 @@
self.as_python_constant(), other.as_python_constant()
)
)
unimplemented(f"call_method {self} {name} {args} {kwargs}")
hints = [
f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
"Please report an issue to PyTorch.",
]
# additional hint for method calls on improperly constructed iterators
if isinstance(self, variables.UserDefinedObjectVariable) and name in (
"__iter__",
"__next__",
):
hints.append(
"Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) "
"passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). "
"This can happen unintentionally if a previous graph break happens with a builtin iterator "
"in the local scope."
)
unimplemented_v2(
gb_type="Unsupported method call",
context=f"call_method {self} {name} {args} {kwargs}",
explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`",
hints=hints,
)
def set_name_hint(self, name):
pass

View File

@@ -24,6 +24,7 @@ from ..exc import (
ObservedAttributeError,
raise_observed_exception,
unimplemented,
unimplemented_v2,
Unsupported,
UserError,
UserErrorType,
@@ -902,9 +903,24 @@ class BuiltinVariable(VariableTracker):
handlers.append(constant_fold_handler)
error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}"
def call_unimplemented_v2(args):
real_arg_types = [arg.python_type_name() for arg in args]
unimplemented_v2(
gb_type="Failed to trace builtin operator",
context=f"builtin {fn.__name__} {arg_types} {has_kwargs}",
explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` "
f"with argument types {real_arg_types} (has_kwargs {has_kwargs})",
hints=[
f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. "
f"Consider using an equivalent alternative function/method to `{fn.__name__}`.",
"If you are attempting to call a logging function (e.g. `print`), "
"you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
"Please report an issue to PyTorch.",
],
)
if len(handlers) == 0:
return lambda *args: unimplemented(error_msg)
return lambda tx, args, kwargs: call_unimplemented_v2(args)
elif len(handlers) == 1:
(handler,) = handlers
@@ -912,7 +928,7 @@
rv = handler(tx, args, kwargs)
if rv:
return rv
unimplemented(error_msg)
call_unimplemented_v2(args)
else:
@@ -921,7 +937,7 @@
rv = fn(tx, args, kwargs)
if rv:
return rv
unimplemented(error_msg)
call_unimplemented_v2(args)
return builtin_dispatch
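A sketch of the `reorderable_logging_functions` workaround named in the hint above (the compiled function is arbitrary; registering `print` defers the call rather than graph breaking):

import torch

torch._dynamo.config.reorderable_logging_functions.add(print)

@torch.compile(backend="eager")
def fn(x):
    x = x + 1
    print("finished step")  # reordered to run after the compiled graph
    return x.sin()

fn(torch.randn(3))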

View File

@@ -49,6 +49,7 @@ from ..exc import (
raise_observed_exception,
SkipFrame,
unimplemented,
unimplemented_v2,
Unsupported,
)
from ..guards import GuardBuilder, install_guard
@@ -1136,7 +1137,26 @@ class SkipFunctionVariable(VariableTracker):
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
unimplemented_v2(
gb_type="Skip calling `torch.compiler.disable()`d function",
context=str(self.value),
explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`",
hints=[
"Remove the `torch.compiler.disable` call",
],
)
elif self.value is torch._dynamo.graph_break:
graph_break_msg = kwargs.get("msg", None)
if graph_break_msg:
graph_break_msg = graph_break_msg.as_python_constant()
unimplemented_v2(
gb_type="Call to `torch._dynamo.graph_break()`",
context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
explanation=f"User-inserted graph break. Message: {graph_break_msg}",
hints=[
"Remove the `torch._dynamo.graph_break()` call.",
],
)
elif isinstance(self.value, types.WrapperDescriptorType):
msg = (
f"Graph break due to unsupported wrapper descriptor {self.value}. "
@@ -1148,49 +1168,79 @@
else:
try:
path = inspect.getfile(self.value)
msg = f"'skip function {self.value.__qualname__} in file {path}'"
explanation = (
f"Dynamo developers have intentionally marked that the function `{self.value.__qualname__}` "
f"in file `{path}` should not be traced."
)
hints = [
f"Avoid calling the function `{self.value.__qualname__}`.",
]
# TODO improve trace_rules reasoning to provide better hints.
# How do we tell that a function/file should NOT be removed from skip files?
# Do a very basic check for now.
if "_dynamo" not in path:
hints += [
f"Remove the function `{self.value.__qualname__}` or the file `{path}` "
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
"attempting to trace into the function.",
"Please file an issue to PyTorch.",
# TODO suggest mark_force_inline when implemented
]
except TypeError:
known_python_builtin_modules = {"_abc", "_warnings"}
if self.value.__module__ in known_python_builtin_modules:
msg = (
f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. "
f"Please file an issue on GitHub "
f"so the PyTorch team can add support for it. "
explanation = (
f"Dynamo does not know how to trace the Python builtin "
f"`{self.value.__module__}.{self.value.__qualname__}`."
)
hints = [
"If you are attempting to call a logging function (e.g. `_warnings.warn`), "
"you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
"Please file an issue on GitHub "
"so the PyTorch team can add support for it. ",
]
elif (
self.value.__module__ is not None
and self.value.__module__.startswith("optree")
):
msg = (
f"Graph break for an optree C/C++ function {self.value.__module__}.{self.value.__qualname__}."
f" Consider using torch.utils._pytree - "
f"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
)
explanation = f"Dynamo cannot trace optree C/C++ function {self.value.__module__}.{self.value.__qualname__}."
hints = [
" Consider using torch.utils._pytree - "
"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
]
# also warn on it because most users won't see the graph break message
torch._dynamo.utils.warn_once(msg)
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
else:
msg = (
f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. "
explanation = (
f"Dynamo does not know how to trace the builtin `{self.value.__module__}.{self.value.__qualname__}.` "
f"This function is either a Python builtin (e.g. _warnings.warn) "
f"or a third-party C/C++ Python extension (perhaps created with pybind). "
f"If it is a Python builtin, please file an issue on GitHub "
f"so the PyTorch team can add support for it and see the next case for a workaround. "
f"If it is a third-party C/C++ Python extension, please "
f"either wrap it into a PyTorch-understood custom operator "
f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
f"for more details) or, if it is traceable, use "
f"torch.compiler.allow_in_graph."
f"or a third-party C/C++ Python extension (perhaps created with pybind)."
)
hints = [
"If it is a Python builtin, please file an issue on GitHub "
"so the PyTorch team can add support for it and see the next case for a workaround.",
"If it is a third-party C/C++ Python extension, please "
"either wrap it into a PyTorch-understood custom operator "
"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
"for more details) or, if it is traceable, use "
"`torch.compiler.allow_in_graph`.",
]
# also warn on it because most users won't see the graph break message
torch._dynamo.utils.warn_once(msg)
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
if self.value.__qualname__ == "allow_in_graph":
msg = (
explanation = (
"Found an allow_in_graph decorator to a function which "
"is created inside the parent function that is getting "
"compiled. This is not supported for now."
)
msg += f"', {self.reason}'" if self.reason else ""
unimplemented(msg)
hints = []
reason = self.reason if self.reason else "<missing reason>"
unimplemented_v2(
gb_type="Attempted to call function marked as skipped",
context=f"module: {self.value.__module__}, qualname: {self.value.__qualname__}, skip reason: {reason}",
explanation=explanation,
hints=hints,
)
def call_obj_hasattr(self, tx: "InstructionTranslator", name):
return variables.ConstantVariable.create(hasattr(self.value, name))
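The custom-operator workaround referenced in these hints can look like the following sketch (the `mylib::foobar` name is illustrative; uses the `torch.library.custom_op` API):

import torch

@torch.library.custom_op("mylib::foobar", mutates_args=())
def foobar(x: torch.Tensor) -> torch.Tensor:
    return x.clone()  # opaque C/C++ logic would live here

@foobar.register_fake
def _(x):
    # Fake/meta implementation so Dynamo can trace through the op.
    return torch.empty_like(x)

torch.compile(lambda x: foobar(x), backend="eager", fullgraph=True)(torch.randn(2, 2))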

View File

@@ -3102,13 +3102,16 @@ class TestCase(expecttest.TestCase):
# Munges exceptions that internally contain stack traces, using munge_exc
def assertExpectedInlineMunged(
self, exc_type, callable, expect, *, suppress_suffix=True
self, exc_type, callable, expect, *, suppress_suffix=True, post_munge=None,
):
try:
callable()
except exc_type as e:
munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=1)
if post_munge:
munged = post_munge(munged)
self.assertExpectedInline(
munge_exc(e, suppress_suffix=suppress_suffix, skip=1), expect, skip=1
munged, expect, skip=1
)
return
self.fail(msg="Did not raise when expected to")
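A usage sketch for the new `post_munge` hook (hypothetical test body; `fn` and the expected string are placeholders, mirroring the pattern in test_graph_break_messages.py above):

def post_munge(s):
    # Normalize memory addresses before comparing against the expected string.
    return re.sub(r"0x[0-9A-Fa-f]+", "0xN", s)

self.assertExpectedInlineMunged(
    torch._dynamo.exc.Unsupported,
    lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
    """<expected graph break message>""",
    post_munge=post_munge,
)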