From 16e202a38e19abc1e01860a5d644be33a512f8bf Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 19 Feb 2025 11:46:13 -0800 Subject: [PATCH] [dynamo] improved graph break messages for some common graph break sites [1/N] (#146525) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146525 Approved by: https://github.com/jansel --- .flake8 | 1 + test/dynamo/test_activation_checkpointing.py | 2 +- test/dynamo/test_autograd_function.py | 5 +- test/dynamo/test_decorators.py | 2 +- test/dynamo/test_exc.py | 14 +- test/dynamo/test_graph_break_messages.py | 505 +++++++++++++++++++ test/dynamo/test_hooks.py | 4 +- test/dynamo/test_misc.py | 92 +--- test/test_custom_ops.py | 8 +- torch/_dynamo/decorators.py | 2 +- torch/_dynamo/exc.py | 63 +++ torch/_dynamo/output_graph.py | 34 +- torch/_dynamo/symbolic_convert.py | 54 +- torch/_dynamo/utils.py | 48 +- torch/_dynamo/variables/base.py | 40 +- torch/_dynamo/variables/builtin.py | 24 +- torch/_dynamo/variables/functions.py | 102 +++- torch/testing/_internal/common_utils.py | 7 +- 18 files changed, 841 insertions(+), 166 deletions(-) create mode 100644 test/dynamo/test_graph_break_messages.py diff --git a/.flake8 b/.flake8 index 4e1cb4642d41..3a426d48173e 100644 --- a/.flake8 +++ b/.flake8 @@ -38,6 +38,7 @@ per-file-ignores = torchgen/api/types/__init__.py: F401,F403 torchgen/executorch/api/types/__init__.py: F401,F403 test/dynamo/test_higher_order_ops.py: B950 + test/dynamo/test_graph_break_messages.py: B950 torch/testing/_internal/dynamo_test_failures.py: B950 # TOR901 is only for test, we want to ignore it for everything else. # It's not easy to configure this without affecting other per-file-ignores, diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 7d90d3909278..baac1724a9d7 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1251,7 +1251,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no x = torch.randn(4, 4).to(device) opt_fn = torch.compile(fn, fullgraph=True) with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, "skip function graph_break in file" + torch._dynamo.exc.Unsupported, "User-inserted graph break" ): opt_fn(x) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index a9ac8d97976e..a7405cf7baca 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -271,7 +271,10 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): model = CustomFuncBwdPrintModule() opt_model = torch.compile(model, backend="eager", fullgraph=True) x = torch.randn(2, 2, dtype=torch.double, requires_grad=True) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"): + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "Dynamo does not know how to trace builtin operator `print`", + ): opt_model(x) def test_stride_in_bwd(self): diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index bdf506416c0c..88b84cd92650 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -330,7 +330,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): fn3(torch.randn(4, 5)) self.assertFalse(True) except torch._dynamo.exc.Unsupported as e: - self.assertIn("call torch._dynamo.disable() wrapped function", str(e)) + self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e)) def test_disable_optimize(self): cnt = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index 4ecce0028b82..28ec57dc43a7 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -37,7 +37,12 @@ class ExcTests(LoggingTestCase): torch.randn(1) ), """\ -'skip function graph_break in file _dynamo/decorators.py' +Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + from user code: File "test_exc.py", line N, in fn001 @@ -171,7 +176,12 @@ from user code: munge_exc(record.getMessage()), """\ Graph break in user code at test_exc.py:N -Reason: Unsupported: 'skip function graph_break in file _dynamo/decorators.py' +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + User code traceback: File "test_exc.py", line N, in fn001 return fn002(x) diff --git a/test/dynamo/test_graph_break_messages.py b/test/dynamo/test_graph_break_messages.py new file mode 100644 index 000000000000..4e59bc4d1d09 --- /dev/null +++ b/test/dynamo/test_graph_break_messages.py @@ -0,0 +1,505 @@ +# Owner(s): ["module: dynamo"] + +import re +import unittest +import warnings + +import torch +import torch._dynamo +import torch._dynamo.config +import torch._dynamo.test_case +import torch.utils._pytree as python_pytree +from torch._dynamo.exc import Unsupported +from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import IS_FBCODE, scoped_load_inline + + +""" +NOTE Adding tests to this file: + +It is good practice to add a minimal repro for each graph break site (i.e. `unimplemented()` call +to make sure that there aren't any errors that occur when generating graph break messages. + +If a graph break message test fails because the graph break no longer repros, +it is good practice to find a new minimal repro that causes the graph break. +If this is too much work, it is likely safe to skip/remove the test, assuming +it was previously passing and the graph break message is not changed. +However, if you add a new graph break or modify a graph break message, you should +make sure that there is a test for it. +""" + + +class GraphBreakMessagesTest(torch._dynamo.test_case.TestCase): + maxDiff = None + + def test_dynamic_shape_operator(self): + def fn(): + return torch.nonzero(torch.rand([10, 10])) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Dynamic shape operator + Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data. + Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True` + + Developer debug context: aten.nonzero.default + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return torch.nonzero(torch.rand([10, 10]))""", + ) + + def test_dynamic_shape_operator_no_meta_kernel(self): + def fn(): + return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10)) + + with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True): + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Dynamic shape operator (no meta kernel) + Explanation: Operator `aten.linalg_lstsq.default` does not have a meta kernel that supports dynamic output shapes + Hint: Please report an issue to PyTorch + + Developer debug context: aten.linalg_lstsq.default + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""", + ) + + def test_data_dependent_operator(self): + def fn(x): + return x.item() + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)( + torch.Tensor([1]) + ), + """\ +Tensor.item + +from user code: + File "test_graph_break_messages.py", line N, in fn + return x.item()""", + ) + + def test_data_dependent_operator2(self): + def fn(x): + return torch.equal(x, x) + + with torch._dynamo.config.patch(capture_scalar_outputs=True): + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(3) + ), + """\ +Data dependent operator + Explanation: Operator `aten.equal.default` has a non-Tensor output whose value is dependent on the data of Tensor inputs. + Hint: Consider wrapping the operator into a PyTorch-understood custom operator (see https:/pytorch.org/tutorials/advanced/custom_ops_landing_page.html) + + Developer debug context: aten.equal.default + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return torch.equal(x, x)""", + ) + + def test_super_call_method(self): + def fn(it): + return [x + 1 for x in it] + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)( + zip(range(5), range(10)) + ), + """\ +Unsupported method call + Explanation: Dynamo does not know how to trace method `__iter__` of class `zip` + Hint: Avoid calling `zip.__iter__` in your code. + Hint: Please report an issue to PyTorch. + Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope. + + Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ () {} + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return [x + 1 for x in it]""", + ) + + def test_super_call_function(self): + def fn(it): + return [x + 1 for x in it()] + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)( + zip(range(5), range(10)) + ), + """\ +Unsupported function call + Explanation: Dynamo does not know how to trace the function `UserDefinedObjectVariable(zip)` + Hint: Avoid calling `UserDefinedObjectVariable(zip)` in your code. + Hint: Please report an issue to PyTorch. + + Developer debug context: call_function UserDefinedObjectVariable(zip) [] {} + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return [x + 1 for x in it()]""", + ) + + def test_unsupported_context(self): + def fn(obj): + with obj: + return 1 + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(3), + """\ +Unsupported context manager + Explanation: Dynamo does not know how to enter a `int` context manager. + Hint: Avoid using the unsupported context manager. + Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general + + Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3) + + +from user code: + File "test_graph_break_messages.py", line N, in fn + with obj:""", + ) + + def test_backend_fake_tensor_exc(self): + def bad_backend(gm, ex): + raise torch._subclasses.fake_tensor.UnsupportedFakeTensorException("test") + + def fn(x): + return x + 1 + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend=bad_backend, fullgraph=True)( + torch.ones(3, 3) + ), + """\ +Backend compiler exception + Explanation: Backend compiler `bad_backend` failed with test. Adding a graph break. + Hint: Report an issue to the backend compiler repo. + + Developer debug context: Backend: bad_backend +Exception:test +Traceback: + File "test_graph_break_messages.py", line N, in fn + return x + 1""", + ) + + def test_unsupported_builtin(self): + def fn(): + print("abc") + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Failed to trace builtin operator + Explanation: Dynamo does not know how to trace builtin operator `print` with argument types ['str'] (has_kwargs False) + Hint: Avoid calling builtin `print` with argument types ['str']. Consider using an equivalent alternative function/method to `print`. + Hint: If you are attempting to call a logging function (e.g. `print`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`. + Hint: Please report an issue to PyTorch. + + Developer debug context: builtin print [] False + + +from user code: + File "test_graph_break_messages.py", line N, in fn + print("abc")""", + ) + + def test_skipfile_call(self): + def fn(): + return unittest.skip("test") + + def post_munge(s): + return re.sub(r"file `.*case\.py`", "file `case.py`", s) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Attempted to call function marked as skipped + Explanation: Dynamo developers have intentionally marked that the function `skip` in file `case.py` should not be traced. + Hint: Avoid calling the function `skip`. + Hint: Remove the function `skip` or the file `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function. + Hint: Please file an issue to PyTorch. + + Developer debug context: module: unittest.case, qualname: skip, skip reason: + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return unittest.skip("test")""", + post_munge=post_munge, + ) + + def test_skipfile_dynamo_call(self): + def fn(): + torch._dynamo.disable() + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Attempted to call function marked as skipped + Explanation: Dynamo developers have intentionally marked that the function `disable` in file `_dynamo/decorators.py` should not be traced. + Hint: Avoid calling the function `disable`. + + Developer debug context: module: torch._dynamo.decorators, qualname: disable, skip reason: + + +from user code: + File "test_graph_break_messages.py", line N, in fn + torch._dynamo.disable()""", + ) + + def test_skipfile_inline(self): + class Foo: + fn = unittest.skip + + def fn(): + Foo().fn() + + def post_munge(s): + return re.sub(r"`.*case\.py`", "`case.py`", s) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Attempted to inline function marked as skipped + Explanation: Dynamo developers have intentionally marked that the function `skip` should not be traced. + Hint: Avoid calling the function `skip`. + Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function. + Hint: Please file an issue to PyTorch. + + Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup SKIP_DIRS + + +from user code: + File "test_graph_break_messages.py", line N, in fn + Foo().fn()""", + post_munge=post_munge, + ) + + def test_disable(self): + @torch.compiler.disable + def inner(): + return 1 + + def fn(): + return inner() + + def post_munge(s): + return re.sub( + r"\.inner at 0x[0-9A-Fa-f]+>", + "", + s, + ) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Skip calling `torch.compiler.disable()`d function + Explanation: Skip calling function `` since it was wrapped with `torch.compiler.disable` + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: + + +from user code: + File "test_graph_break_messages.py", line N, in fn + return inner()""", + post_munge=post_munge, + ) + + def test_dynamo_graph_break_fn(self): + def fn(): + torch._dynamo.graph_break() + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + +from user code: + File "test_graph_break_messages.py", line N, in fn + torch._dynamo.graph_break()""", + ) + + def test_dynamo_graph_break_fn_with_msg(self): + def fn(): + torch._dynamo.graph_break(msg="test graph break") + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: test graph break + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{'msg': ConstantVariable(str: 'test graph break')}` + + +from user code: + File "test_graph_break_messages.py", line N, in fn + torch._dynamo.graph_break(msg="test graph break")""", + ) + + def test_warnings(self): + def fn(): + warnings.warn("test") + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +Attempted to call function marked as skipped + Explanation: Dynamo does not know how to trace the Python builtin `_warnings.warn`. + Hint: If you are attempting to call a logging function (e.g. `_warnings.warn`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`. + Hint: Please file an issue on GitHub so the PyTorch team can add support for it. + + Developer debug context: module: _warnings, qualname: warn, skip reason: + + +from user code: + File "test_graph_break_messages.py", line N, in fn + warnings.warn("test")""", + ) + + @unittest.skipIf(not python_pytree._cxx_pytree_exists, "missing optree package") + def test_optree_graph_break_message(self): + import optree + + @torch.compile(backend="eager") + def fn(x): + d = {"a": 1} + optree.tree_flatten(d) + return torch.sin(x) + + fn(torch.randn(4)) + self.assertEqual(len(counters["graph_break"]), 1) + first_graph_break = next(iter(counters["graph_break"].keys())) + self.assertExpectedInline( + first_graph_break, + """\ +Attempted to call function marked as skipped + Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten. + Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py + + Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: +""", + ) + + @scoped_load_inline + @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) + @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode") + def test_cpp_extension_recommends_custom_ops(self, load_inline): + cpp_source = """ + #include + at::Tensor foobar(const at::Tensor& x) { + return x.clone(); + } + """ + module = load_inline( + name="mylib", + cpp_sources=cpp_source, + functions="foobar", + verbose=True, + ) + + x = torch.ones(2, 2, requires_grad=True) + counters.clear() + + @torch.compile(backend="eager") + def f(x): + return module.foobar(x) + + with self.assertWarnsOnceRegex( + UserWarning, + "(?s).*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*", + ): + f(x) + self.assertEqual(len(counters["graph_break"]), 1) + first_graph_break = next(iter(counters["graph_break"].keys())) + + first_graph_break = re.sub(r"mylib(_v\d+)?", "mylib", first_graph_break) + + self.assertExpectedInline( + first_graph_break, + """\ +Attempted to call function marked as skipped + Explanation: Dynamo does not know how to trace the builtin `mylib.PyCapsule.foobar.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). + Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. + Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`. + + Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: +""", + ) + + cpp_source = """ + #include + at::Tensor baz(const at::Tensor& x) { + return x.clone(); + } + """ + module2 = load_inline( + name="mylib2", + cpp_sources=cpp_source, + functions="baz", + verbose=True, + ) + + torch._dynamo.reset() + + # Test that each warning only happens once + @torch.compile(backend="eager") + def f(x): + module2.baz(x) + module.foobar(x) + module.foobar(x) + module2.baz(x) + module.foobar(x) + module2.baz(x) + return x.clone() + + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always") + f(x) + f(x) + self.assertEqual(len(ws), 2) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 29ff1ddf93fe..a6b65c4d9a34 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -512,7 +512,9 @@ class HooksTests(torch._dynamo.test_case.TestCase): x2 = torch.ones(4, requires_grad=True) with compiled_autograd._enable(compiler_fn): dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"): + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "Failed to trace builtin operator" + ): dynamo_out[0].backward(torch.ones(4)) self.assertEqual(obj.count, 2) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 10aa9aba40a1..3289636e72f2 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -293,24 +293,6 @@ class MiscTests(torch._inductor.test_case.TestCase): with self.assertRaises(TypeError): fn(torch.randn(16)) - @unittest.skipIf(not python_pytree._cxx_pytree_exists, "missing optree package") - def test_optree_graph_break_message(self): - import optree - - @torch.compile(backend="eager") - def fn(x): - d = {"a": 1} - optree.tree_flatten(d) - return torch.sin(x) - - fn(torch.randn(4)) - self.assertEqual(len(counters["graph_break"]), 1) - first_graph_break = list(counters["graph_break"].keys())[0] - self.assertExpectedInline( - first_graph_break, - "Graph break for an optree C/C++ function optree._C.PyCapsule.flatten. Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py", - ) - def test_scalar_device_movement(self): if not torch._dynamo.config.assume_static_by_default: self.skipTest("Doesn't work with symints") @@ -324,74 +306,6 @@ class MiscTests(torch._inductor.test_case.TestCase): res_compiled = add_fn(2, 3, torch.tensor(0.0)) self.assertEqual(res, res_compiled) - @scoped_load_inline - @skipIfNNModuleInlined("fails internal CI") - @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode") - def test_cpp_extension_recommends_custom_ops(self, load_inline): - cpp_source = """ - #include - at::Tensor foobar(const at::Tensor& x) { - return x.clone(); - } - """ - module = load_inline( - name="mylib", - cpp_sources=cpp_source, - functions="foobar", - verbose=True, - ) - - x = torch.ones(2, 2, requires_grad=True) - counters.clear() - - @torch.compile(backend="eager") - def f(x): - return module.foobar(x) - - with self.assertWarnsOnceRegex( - UserWarning, - ".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*", - ): - f(x) - self.assertEqual(len(counters["graph_break"]), 1) - first_graph_break = list(counters["graph_break"].keys())[0] - self.assertExpectedInline( - first_graph_break, - """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""", - ) - - cpp_source = """ - #include - at::Tensor baz(const at::Tensor& x) { - return x.clone(); - } - """ - module2 = load_inline( - name="mylib2", - cpp_sources=cpp_source, - functions="baz", - verbose=True, - ) - - torch._dynamo.reset() - - # Test that each warning only happens once - @torch.compile(backend="eager") - def f(x): - module2.baz(x) - module.foobar(x) - module.foobar(x) - module2.baz(x) - module.foobar(x) - module2.baz(x) - return x.clone() - - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter("always") - f(x) - f(x) - self.assertEqual(len(ws), 2) - def test_callpacked(self): def call_packed(args): a, b, c = args @@ -8967,7 +8881,7 @@ def ___make_guard_fn(): # and so the guard story for the objects passed into input just isn't there atm. with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, - "^call_method UserDefinedObjectVariable\\(set\\).*", + "Unsupported method call", ): foo(inp) @@ -10643,7 +10557,7 @@ ShapeEnv not equal: field values don't match: # Should only be one restart per event (restart_reason,) = metrics[0].restart_reasons self.assertTrue( - "skip function graph_break" in restart_reason, + "User-inserted graph break" in restart_reason, "Should have logged graph break reason", ) self.assertTrue( @@ -10653,7 +10567,7 @@ ShapeEnv not equal: field values don't match: (restart_reason,) = metrics[1].restart_reasons self.assertTrue( - "skip function graph_break" in restart_reason, + "User-inserted graph break" in restart_reason, "Should have logged graph break reason", ) self.assertTrue( diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index c73f4a850ace..c92edc279f55 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1772,8 +1772,12 @@ def forward(self, x_1): self.assertExpectedInline( next(iter(counters["graph_break"].keys())).replace(";", "\n"), """\ -dynamic shape operator: _torch_testing.numpy_nonzero.default - to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""", +Dynamic shape operator + Explanation: Operator `_torch_testing.numpy_nonzero.default`'s output shape depends on input Tensor data. + Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True` + + Developer debug context: _torch_testing.numpy_nonzero.default +""", ) # pre-existing problem: torch.compile(dynamic=True) will, by default, diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index ff26a1d8a37b..727b78a61c19 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -206,7 +206,7 @@ def disallow_in_graph(fn): @_disallow_in_graph_helper(throw_if_not_allowed=False) -def graph_break(): +def graph_break(msg=""): """Force a graph break""" diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 459012f2629d..2b17b136f869 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -441,6 +441,69 @@ def unimplemented( raise Unsupported(msg, case_name=case_name) +def unimplemented_v2_with_warning( + e: Exception, + code: types.CodeType, + gb_type: str, + context: str, + explanation: str, + hints: list[str], +) -> NoReturn: + # This function calls unimplemented internally and eventually graph breaks + # or falls to eager. unimplemented itself does not print any user warnings, + # i.e., its very silent. This helper function is intended when an error is + # encountered in the torch.compile stack which is worth showing as warning + # to the user. For example, if AOT Autograd backend fails with a fake tensor + # exception, its ok to fallback to eager but not silently. Here, we can use + # this function to log the message and the stack trace. + graph_break_msg = format_error_msg_verbose(e, code) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: graph_break_msg, + ) + graph_breaks_log.debug("%s", graph_break_msg) + unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True) + + +# TODO replace old unimplemented later +def unimplemented_v2( + gb_type: str, + context: str, + explanation: str, + hints: list[str], + *, + from_exc: Any = _NOTHING, + log_warning: bool = False, +) -> NoReturn: + """ + Called within dynamo to cause a graph break. + Args: + gb_type: Context-free graph break type. It should be a short string without any + information specific to the tracing context (i.e. no dynamically-generated strings) + context: Developer context for the graph break. It can contain tracing context/dynamic strings. + explanation: User-facing context-dependent explanation for the graph break. Can be dynamic. + hints: List of user-facing hints for the graph break. + """ + hints_str = "\n".join(hints) + hints_str = textwrap.indent(hints_str, " Hint: ") + msg = f"""\ +{gb_type} + Explanation: {explanation} +{hints_str} + + Developer debug context: {context} +""" + if log_warning: + log.warning(msg) + if from_exc is not _NOTHING: + raise Unsupported(msg) from from_exc + raise Unsupported(msg) + + def warning(msg: str) -> None: counters["warnings"][msg] += 1 assert msg != os.environ.get("BREAK", False) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index d90d0cf99291..95d36ababd0d 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -82,7 +82,7 @@ from .exc import ( exceptions_allowed_to_be_fallback, SkipFrame, unimplemented, - unimplemented_with_warning, + unimplemented_v2_with_warning, ) from .graph_deduplication import apply_graph_deduplication from .graph_region_tracker import GraphRegionTracker @@ -652,9 +652,11 @@ class OutputGraph: """ global_state = cast( dict[str, tuple[Callable[..., Any], bool]], - out - if out is not None - else self.tracing_context.global_context.global_state, + ( + out + if out is not None + else self.tracing_context.global_context.global_state + ), ) # TODO - Consider having a torch level API for torch_function_state. As @@ -1475,12 +1477,12 @@ class OutputGraph: gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment] + name = ( + self.compiler_fn.__name__ + if hasattr(self.compiler_fn, "__name__") + else "" + ) try: - name = ( - self.compiler_fn.__name__ - if hasattr(self.compiler_fn, "__name__") - else "" - ) _step_logger()(logging.INFO, f"calling compiler function {name}") compiler_fn = self.compiler_fn if config.verify_correctness: @@ -1495,12 +1497,16 @@ class OutputGraph: raise BackendCompilerFailed( self.compiler_fn, e, inspect.currentframe() ).with_traceback(e.__traceback__) from None - msg = ( - f"Backend compiler failed with {str(e)} at \n" - f"{self.root_tx.format_frame_summary()}" - "Adding a graph break." + unimplemented_v2_with_warning( + e, + self.root_tx.f_code, + gb_type="Backend compiler exception", + context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}", + explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.", + hints=[ + "Report an issue to the backend compiler repo.", + ], ) - unimplemented_with_warning(e, self.root_tx.f_code, msg) except SkipFrame as e: # The backend compiler has requested that we skip the frame, instead of # aborting execution. diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 09675fbf961c..fdb1ee04c860 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -74,7 +74,13 @@ from .bytecode_transformation import ( ) from .code_context import code_context from .codegen import PyCodegen -from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported +from .exc import ( + ArgsMismatchError, + BackendCompilerFailed, + unimplemented, + unimplemented_v2, + Unsupported, +) from .funcname_cache import get_funcname from .guards import GuardBuilder, install_guard from .output_graph import GraphCompileReason, OutputGraph @@ -488,7 +494,7 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None): user_stack_formatted = "".join(traceback.format_list(user_stack)) user_stack_trace = ( - "Graph break in user code at %s:%s\nReason: %s\nUser code traceback:\n%s" # noqa: UP031 + "Graph break in user code at %s:%s\nGraph Break Reason: %s\nUser code traceback:\n%s" # noqa: UP031 % ( frame_loc[0], frame_loc[1], @@ -523,7 +529,7 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None): # exercised by # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log graph_break_log.debug( - "Graph break (details suppressed) in user code at %s:%s\nReason: %s", + "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s", frame_loc[0], frame_loc[1], reason, @@ -762,7 +768,7 @@ def break_graph_if_unsupported(*, push): log_graph_break( self.code_options, exc_info=True, - reason=f"Unsupported: {excp}", + reason=str(excp), user_stack=excp.real_stack, ) @@ -2546,7 +2552,16 @@ class InstructionTranslatorBase( if not isinstance( ctx, (ContextWrappingVariable, GenericContextWrappingVariable) ): - unimplemented(f"{inst.opname} {ctx}") + unimplemented_v2( + gb_type="Unsupported context manager", + context=f"Attempted SETUP_WITH/BEFORE_WITH on {ctx}", + explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.", + hints=[ + "Avoid using the unsupported context manager.", + "File an issue to PyTorch. Simple context managers can potentially be supported, " + "but note that context managers can't be supported in general", + ], + ) if ( isinstance(ctx, GenericContextWrappingVariable) @@ -3289,15 +3304,36 @@ class InliningInstructionTranslator(InstructionTranslatorBase): False, "allowlist in dynamo known function" ) fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else "" - unimplemented( - f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'" + hints = [ + f"Avoid calling the function `{fn_qualname}`.", + ] + if "_dynamo" not in func.get_filename(): + hints += [ + f"Remove the function `{fn_qualname}` or the file `{func.get_filename()}` " + "from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of " + "attempting to trace into the function.", + "Please file an issue to PyTorch.", + # TODO suggest mark_force_inline when implemented + ] + unimplemented_v2( + gb_type="Attempted to inline function marked as skipped", + context=f"qualname: {fn_qualname}, name: {func.get_name()}, " + f"filename: `{func.get_filename()}`, skip reason: {result.reason}", + explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` " + "should not be traced.", + hints=hints, ) if isinstance(func, UserFunctionVariable) and inspect.getattr_static( func.get_function(), "_torchdynamo_disable", False ): - unimplemented( - f"call torch._dynamo.disable() wrapped function {func.get_function()}" + unimplemented_v2( + gb_type="Skip inlining `torch.compiler.disable()`d function", + context=str(func.get_function()), + explanation=f"Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable`", + hints=[ + "Remove the `torch.compiler.disable` call", + ], ) else: return result diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index a7e5cb035dbc..347b05f5383d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3004,6 +3004,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): from .exc import ( TorchRuntimeError, unimplemented, + unimplemented_v2, Unsupported, UserError, UserErrorType, @@ -3063,23 +3064,50 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): if isinstance( cause, torch._subclasses.fake_tensor.DataDependentOutputException ): - unimplemented( - f"data dependent operator: {cause.func}; " - "to enable, set torch._dynamo.config.capture_scalar_outputs = True" + # capture_scalar_outputs only works for these ops right now + # see torch/_subclasses/fake_impls.py + if cause.func in ( + torch.ops.aten.item.default, + torch.ops.aten._local_scalar_dense.default, + ): + # does this actually get triggered? + hints = [ + "Enable tracing of data-dependent output operators with " + "`torch._dynamo.config.capture_scalar_outputs = True`", + ] + else: + hints = [ + "Consider wrapping the operator into a PyTorch-understood custom operator " + "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)", + ] + unimplemented_v2( + gb_type="Data dependent operator", + context=str(cause.func), + explanation=f"Operator `{cause.func}` has a non-Tensor output " + "whose value is dependent on the data of Tensor inputs.", + hints=hints, ) elif isinstance( cause, torch._subclasses.fake_tensor.DynamicOutputShapeException ): if not torch._dynamo.config.capture_dynamic_output_shape_ops: - unimplemented( - f"dynamic shape operator: {cause.func}; " - "to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True" + unimplemented_v2( + gb_type="Dynamic shape operator", + context=str(cause.func), + explanation=f"Operator `{cause.func}`'s output shape depends on input Tensor data.", + hints=[ + "Enable tracing of dynamic shape operators with " + "`torch._dynamo.config.capture_dynamic_output_shape_ops = True`", + ], ) else: - unimplemented( - f"dynamic shape operator: {cause.func}; " - "Operator does not have a meta kernel that supports dynamic output shapes, " - "please report an issue to PyTorch" + unimplemented_v2( + gb_type="Dynamic shape operator (no meta kernel)", + context=str(cause.func), + explanation=f"Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes", + hints=[ + "Please report an issue to PyTorch", + ], ) elif isinstance( cause, torch._subclasses.fake_tensor.UnsupportedOperatorException diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 74b09f0656b0..eed78c93593b 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -21,7 +21,7 @@ from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING from .. import variables from ..current_scope_id import current_scope_id -from ..exc import unimplemented +from ..exc import unimplemented, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import AttrSource, Source from ..utils import cmp_name_to_op_mapping, istype @@ -310,6 +310,12 @@ class VariableTracker(metaclass=VariableTrackerMeta): except NotImplementedError: raise NotImplementedError(f"{self} has no type") from None + def python_type_name(self): + try: + return self.python_type().__name__ + except NotImplementedError: + return "" + def as_python_constant(self): """For constants""" raise NotImplementedError(f"{self} is not a constant") @@ -408,7 +414,15 @@ class VariableTracker(metaclass=VariableTrackerMeta): args: Sequence["VariableTracker"], kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - unimplemented(f"call_function {self} {args} {kwargs}") + unimplemented_v2( + gb_type="Unsupported function call", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `{self.debug_repr()}` in your code.", + "Please report an issue to PyTorch.", + ], + ) def call_method( self, @@ -450,7 +464,27 @@ class VariableTracker(metaclass=VariableTrackerMeta): self.as_python_constant(), other.as_python_constant() ) ) - unimplemented(f"call_method {self} {name} {args} {kwargs}") + hints = [ + f"Avoid calling `{self.python_type_name()}.{name}` in your code.", + "Please report an issue to PyTorch.", + ] + # additional hint for method calls on improperly constructed iterators + if isinstance(self, variables.UserDefinedObjectVariable) and name in ( + "__iter__", + "__next__", + ): + hints.append( + "Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) " + "passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). " + "This can happen unintentionally if a previous graph break happens with a builtin iterator " + "in the local scope." + ) + unimplemented_v2( + gb_type="Unsupported method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", + hints=hints, + ) def set_name_hint(self, name): pass diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 076073cb847e..e00bc431c4f1 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -24,6 +24,7 @@ from ..exc import ( ObservedAttributeError, raise_observed_exception, unimplemented, + unimplemented_v2, Unsupported, UserError, UserErrorType, @@ -902,9 +903,24 @@ class BuiltinVariable(VariableTracker): handlers.append(constant_fold_handler) - error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}" + def call_unimplemented_v2(args): + real_arg_types = [arg.python_type_name() for arg in args] + unimplemented_v2( + gb_type="Failed to trace builtin operator", + context=f"builtin {fn.__name__} {arg_types} {has_kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` " + f"with argument types {real_arg_types} (has_kwargs {has_kwargs})", + hints=[ + f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. " + f"Consider using an equivalent alternative function/method to `{fn.__name__}`.", + "If you are attempting to call a logging function (e.g. `print`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please report an issue to PyTorch.", + ], + ) + if len(handlers) == 0: - return lambda *args: unimplemented(error_msg) + return lambda tx, args, kwargs: call_unimplemented_v2(args) elif len(handlers) == 1: (handler,) = handlers @@ -912,7 +928,7 @@ class BuiltinVariable(VariableTracker): rv = handler(tx, args, kwargs) if rv: return rv - unimplemented(error_msg) + call_unimplemented_v2(args) else: @@ -921,7 +937,7 @@ class BuiltinVariable(VariableTracker): rv = fn(tx, args, kwargs) if rv: return rv - unimplemented(error_msg) + call_unimplemented_v2(args) return builtin_dispatch diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 5038ffbd1d3f..8a2a31ae4aca 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -49,6 +49,7 @@ from ..exc import ( raise_observed_exception, SkipFrame, unimplemented, + unimplemented_v2, Unsupported, ) from ..guards import GuardBuilder, install_guard @@ -1136,7 +1137,26 @@ class SkipFunctionVariable(VariableTracker): kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if inspect.getattr_static(self.value, "_torchdynamo_disable", False): - unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}") + unimplemented_v2( + gb_type="Skip calling `torch.compiler.disable()`d function", + context=str(self.value), + explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`", + hints=[ + "Remove the `torch.compiler.disable` call", + ], + ) + elif self.value is torch._dynamo.graph_break: + graph_break_msg = kwargs.get("msg", None) + if graph_break_msg: + graph_break_msg = graph_break_msg.as_python_constant() + unimplemented_v2( + gb_type="Call to `torch._dynamo.graph_break()`", + context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", + explanation=f"User-inserted graph break. Message: {graph_break_msg}", + hints=[ + "Remove the `torch._dynamo.graph_break()` call.", + ], + ) elif isinstance(self.value, types.WrapperDescriptorType): msg = ( f"Graph break due to unsupported wrapper descriptor {self.value}. " @@ -1148,49 +1168,79 @@ class SkipFunctionVariable(VariableTracker): else: try: path = inspect.getfile(self.value) - msg = f"'skip function {self.value.__qualname__} in file {path}'" + explanation = ( + f"Dynamo developers have intentionally marked that the function `{self.value.__qualname__}` " + f"in file `{path}` should not be traced." + ) + hints = [ + f"Avoid calling the function `{self.value.__qualname__}`.", + ] + # TODO improve trace_rules reasoning to provide better hints. + # How do we tell that a function/file should NOT be removed from skip files? + # Do a very basic check for now. + if "_dynamo" not in path: + hints += [ + f"Remove the function `{self.value.__qualname__}` or the file `{path}` " + "from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of " + "attempting to trace into the function.", + "Please file an issue to PyTorch.", + # TODO suggest mark_force_inline when implemented + ] except TypeError: known_python_builtin_modules = {"_abc", "_warnings"} if self.value.__module__ in known_python_builtin_modules: - msg = ( - f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. " - f"Please file an issue on GitHub " - f"so the PyTorch team can add support for it. " + explanation = ( + f"Dynamo does not know how to trace the Python builtin " + f"`{self.value.__module__}.{self.value.__qualname__}`." ) + hints = [ + "If you are attempting to call a logging function (e.g. `_warnings.warn`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please file an issue on GitHub " + "so the PyTorch team can add support for it. ", + ] elif ( self.value.__module__ is not None and self.value.__module__.startswith("optree") ): - msg = ( - f"Graph break for an optree C/C++ function {self.value.__module__}.{self.value.__qualname__}." - f" Consider using torch.utils._pytree - " - f"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py" - ) + explanation = f"Dynamo cannot trace optree C/C++ function {self.value.__module__}.{self.value.__qualname__}." + hints = [ + " Consider using torch.utils._pytree - " + "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py" + ] # also warn on it because most users won't see the graph break message - torch._dynamo.utils.warn_once(msg) + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) else: - msg = ( - f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. " + explanation = ( + f"Dynamo does not know how to trace the builtin `{self.value.__module__}.{self.value.__qualname__}.` " f"This function is either a Python builtin (e.g. _warnings.warn) " - f"or a third-party C/C++ Python extension (perhaps created with pybind). " - f"If it is a Python builtin, please file an issue on GitHub " - f"so the PyTorch team can add support for it and see the next case for a workaround. " - f"If it is a third-party C/C++ Python extension, please " - f"either wrap it into a PyTorch-understood custom operator " - f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html " - f"for more details) or, if it is traceable, use " - f"torch.compiler.allow_in_graph." + f"or a third-party C/C++ Python extension (perhaps created with pybind)." ) + hints = [ + "If it is a Python builtin, please file an issue on GitHub " + "so the PyTorch team can add support for it and see the next case for a workaround.", + "If it is a third-party C/C++ Python extension, please " + "either wrap it into a PyTorch-understood custom operator " + "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html " + "for more details) or, if it is traceable, use " + "`torch.compiler.allow_in_graph`.", + ] # also warn on it because most users won't see the graph break message - torch._dynamo.utils.warn_once(msg) + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) if self.value.__qualname__ == "allow_in_graph": - msg = ( + explanation = ( "Found an allow_in_graph decorator to a function which " "is created inside the parent function that is getting " "compiled. This is not supported for now." ) - msg += f"', {self.reason}'" if self.reason else "" - unimplemented(msg) + hints = [] + reason = self.reason if self.reason else "" + unimplemented_v2( + gb_type="Attempted to call function marked as skipped", + context=f"module: {self.value.__module__}, qualname: {self.value.__qualname__}, skip reason: {reason}", + explanation=explanation, + hints=hints, + ) def call_obj_hasattr(self, tx: "InstructionTranslator", name): return variables.ConstantVariable.create(hasattr(self.value, name)) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2ae71cbe1451..ab4921f194cf 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -3102,13 +3102,16 @@ class TestCase(expecttest.TestCase): # Munges exceptions that internally contain stack traces, using munge_exc def assertExpectedInlineMunged( - self, exc_type, callable, expect, *, suppress_suffix=True + self, exc_type, callable, expect, *, suppress_suffix=True, post_munge=None, ): try: callable() except exc_type as e: + munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=1) + if post_munge: + munged = post_munge(munged) self.assertExpectedInline( - munge_exc(e, suppress_suffix=suppress_suffix, skip=1), expect, skip=1 + munged, expect, skip=1 ) return self.fail(msg="Did not raise when expected to")