mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] improved graph break messages for some common graph break sites [1/N] (#146525)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146525 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e94c7aaa4
commit
16e202a38e
1
.flake8
1
.flake8
@ -38,6 +38,7 @@ per-file-ignores =
|
||||
torchgen/api/types/__init__.py: F401,F403
|
||||
torchgen/executorch/api/types/__init__.py: F401,F403
|
||||
test/dynamo/test_higher_order_ops.py: B950
|
||||
test/dynamo/test_graph_break_messages.py: B950
|
||||
torch/testing/_internal/dynamo_test_failures.py: B950
|
||||
# TOR901 is only for test, we want to ignore it for everything else.
|
||||
# It's not easy to configure this without affecting other per-file-ignores,
|
||||
|
@ -1251,7 +1251,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
x = torch.randn(4, 4).to(device)
|
||||
opt_fn = torch.compile(fn, fullgraph=True)
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported, "skip function graph_break in file"
|
||||
torch._dynamo.exc.Unsupported, "User-inserted graph break"
|
||||
):
|
||||
opt_fn(x)
|
||||
|
||||
|
@ -271,7 +271,10 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
model = CustomFuncBwdPrintModule()
|
||||
opt_model = torch.compile(model, backend="eager", fullgraph=True)
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported,
|
||||
"Dynamo does not know how to trace builtin operator `print`",
|
||||
):
|
||||
opt_model(x)
|
||||
|
||||
def test_stride_in_bwd(self):
|
||||
|
@ -330,7 +330,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
fn3(torch.randn(4, 5))
|
||||
self.assertFalse(True)
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
self.assertIn("call torch._dynamo.disable() wrapped function", str(e))
|
||||
self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e))
|
||||
|
||||
def test_disable_optimize(self):
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
@ -37,7 +37,12 @@ class ExcTests(LoggingTestCase):
|
||||
torch.randn(1)
|
||||
),
|
||||
"""\
|
||||
'skip function graph_break in file _dynamo/decorators.py'
|
||||
Call to `torch._dynamo.graph_break()`
|
||||
Explanation: User-inserted graph break. Message: None
|
||||
Hint: Remove the `torch._dynamo.graph_break()` call.
|
||||
|
||||
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_exc.py", line N, in fn001
|
||||
@ -171,7 +176,12 @@ from user code:
|
||||
munge_exc(record.getMessage()),
|
||||
"""\
|
||||
Graph break in user code at test_exc.py:N
|
||||
Reason: Unsupported: 'skip function graph_break in file _dynamo/decorators.py'
|
||||
Graph Break Reason: Call to `torch._dynamo.graph_break()`
|
||||
Explanation: User-inserted graph break. Message: None
|
||||
Hint: Remove the `torch._dynamo.graph_break()` call.
|
||||
|
||||
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
|
||||
|
||||
User code traceback:
|
||||
File "test_exc.py", line N, in fn001
|
||||
return fn002(x)
|
||||
|
505
test/dynamo/test_graph_break_messages.py
Normal file
505
test/dynamo/test_graph_break_messages.py
Normal file
@ -0,0 +1,505 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import re
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, scoped_load_inline
|
||||
|
||||
|
||||
"""
|
||||
NOTE Adding tests to this file:
|
||||
|
||||
It is good practice to add a minimal repro for each graph break site (i.e. `unimplemented()` call
|
||||
to make sure that there aren't any errors that occur when generating graph break messages.
|
||||
|
||||
If a graph break message test fails because the graph break no longer repros,
|
||||
it is good practice to find a new minimal repro that causes the graph break.
|
||||
If this is too much work, it is likely safe to skip/remove the test, assuming
|
||||
it was previously passing and the graph break message is not changed.
|
||||
However, if you add a new graph break or modify a graph break message, you should
|
||||
make sure that there is a test for it.
|
||||
"""
|
||||
|
||||
|
||||
class GraphBreakMessagesTest(torch._dynamo.test_case.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
def test_dynamic_shape_operator(self):
|
||||
def fn():
|
||||
return torch.nonzero(torch.rand([10, 10]))
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Dynamic shape operator
|
||||
Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.
|
||||
Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
|
||||
|
||||
Developer debug context: aten.nonzero.default
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return torch.nonzero(torch.rand([10, 10]))""",
|
||||
)
|
||||
|
||||
def test_dynamic_shape_operator_no_meta_kernel(self):
|
||||
def fn():
|
||||
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))
|
||||
|
||||
with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True):
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Dynamic shape operator (no meta kernel)
|
||||
Explanation: Operator `aten.linalg_lstsq.default` does not have a meta kernel that supports dynamic output shapes
|
||||
Hint: Please report an issue to PyTorch
|
||||
|
||||
Developer debug context: aten.linalg_lstsq.default
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""",
|
||||
)
|
||||
|
||||
def test_data_dependent_operator(self):
|
||||
def fn(x):
|
||||
return x.item()
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
||||
torch.Tensor([1])
|
||||
),
|
||||
"""\
|
||||
Tensor.item
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return x.item()""",
|
||||
)
|
||||
|
||||
def test_data_dependent_operator2(self):
|
||||
def fn(x):
|
||||
return torch.equal(x, x)
|
||||
|
||||
with torch._dynamo.config.patch(capture_scalar_outputs=True):
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
||||
torch.ones(3)
|
||||
),
|
||||
"""\
|
||||
Data dependent operator
|
||||
Explanation: Operator `aten.equal.default` has a non-Tensor output whose value is dependent on the data of Tensor inputs.
|
||||
Hint: Consider wrapping the operator into a PyTorch-understood custom operator (see https:/pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
|
||||
|
||||
Developer debug context: aten.equal.default
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return torch.equal(x, x)""",
|
||||
)
|
||||
|
||||
def test_super_call_method(self):
|
||||
def fn(it):
|
||||
return [x + 1 for x in it]
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
||||
zip(range(5), range(10))
|
||||
),
|
||||
"""\
|
||||
Unsupported method call
|
||||
Explanation: Dynamo does not know how to trace method `__iter__` of class `zip`
|
||||
Hint: Avoid calling `zip.__iter__` in your code.
|
||||
Hint: Please report an issue to PyTorch.
|
||||
Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope.
|
||||
|
||||
Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ () {}
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return [x + 1 for x in it]""",
|
||||
)
|
||||
|
||||
def test_super_call_function(self):
|
||||
def fn(it):
|
||||
return [x + 1 for x in it()]
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
||||
zip(range(5), range(10))
|
||||
),
|
||||
"""\
|
||||
Unsupported function call
|
||||
Explanation: Dynamo does not know how to trace the function `UserDefinedObjectVariable(zip)`
|
||||
Hint: Avoid calling `UserDefinedObjectVariable(zip)` in your code.
|
||||
Hint: Please report an issue to PyTorch.
|
||||
|
||||
Developer debug context: call_function UserDefinedObjectVariable(zip) [] {}
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return [x + 1 for x in it()]""",
|
||||
)
|
||||
|
||||
def test_unsupported_context(self):
|
||||
def fn(obj):
|
||||
with obj:
|
||||
return 1
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(3),
|
||||
"""\
|
||||
Unsupported context manager
|
||||
Explanation: Dynamo does not know how to enter a `int` context manager.
|
||||
Hint: Avoid using the unsupported context manager.
|
||||
Hint: File an issue to PyTorch. Simple context managers can potentially be supported, but note that context managers can't be supported in general
|
||||
|
||||
Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3)
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
with obj:""",
|
||||
)
|
||||
|
||||
def test_backend_fake_tensor_exc(self):
|
||||
def bad_backend(gm, ex):
|
||||
raise torch._subclasses.fake_tensor.UnsupportedFakeTensorException("test")
|
||||
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend=bad_backend, fullgraph=True)(
|
||||
torch.ones(3, 3)
|
||||
),
|
||||
"""\
|
||||
Backend compiler exception
|
||||
Explanation: Backend compiler `bad_backend` failed with test. Adding a graph break.
|
||||
Hint: Report an issue to the backend compiler repo.
|
||||
|
||||
Developer debug context: Backend: bad_backend
|
||||
Exception:test
|
||||
Traceback:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return x + 1""",
|
||||
)
|
||||
|
||||
def test_unsupported_builtin(self):
|
||||
def fn():
|
||||
print("abc")
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Failed to trace builtin operator
|
||||
Explanation: Dynamo does not know how to trace builtin operator `print` with argument types ['str'] (has_kwargs False)
|
||||
Hint: Avoid calling builtin `print` with argument types ['str']. Consider using an equivalent alternative function/method to `print`.
|
||||
Hint: If you are attempting to call a logging function (e.g. `print`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.
|
||||
Hint: Please report an issue to PyTorch.
|
||||
|
||||
Developer debug context: builtin print [<class 'torch._dynamo.variables.constant.ConstantVariable'>] False
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
print("abc")""",
|
||||
)
|
||||
|
||||
def test_skipfile_call(self):
|
||||
def fn():
|
||||
return unittest.skip("test")
|
||||
|
||||
def post_munge(s):
|
||||
return re.sub(r"file `.*case\.py`", "file `case.py`", s)
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo developers have intentionally marked that the function `skip` in file `case.py` should not be traced.
|
||||
Hint: Avoid calling the function `skip`.
|
||||
Hint: Remove the function `skip` or the file `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
|
||||
Hint: Please file an issue to PyTorch.
|
||||
|
||||
Developer debug context: module: unittest.case, qualname: skip, skip reason: <missing reason>
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return unittest.skip("test")""",
|
||||
post_munge=post_munge,
|
||||
)
|
||||
|
||||
def test_skipfile_dynamo_call(self):
|
||||
def fn():
|
||||
torch._dynamo.disable()
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo developers have intentionally marked that the function `disable` in file `_dynamo/decorators.py` should not be traced.
|
||||
Hint: Avoid calling the function `disable`.
|
||||
|
||||
Developer debug context: module: torch._dynamo.decorators, qualname: disable, skip reason: <missing reason>
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
torch._dynamo.disable()""",
|
||||
)
|
||||
|
||||
def test_skipfile_inline(self):
|
||||
class Foo:
|
||||
fn = unittest.skip
|
||||
|
||||
def fn():
|
||||
Foo().fn()
|
||||
|
||||
def post_munge(s):
|
||||
return re.sub(r"`.*case\.py`", "`case.py`", s)
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Attempted to inline function marked as skipped
|
||||
Explanation: Dynamo developers have intentionally marked that the function `skip` should not be traced.
|
||||
Hint: Avoid calling the function `skip`.
|
||||
Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function.
|
||||
Hint: Please file an issue to PyTorch.
|
||||
|
||||
Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup SKIP_DIRS
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
Foo().fn()""",
|
||||
post_munge=post_munge,
|
||||
)
|
||||
|
||||
def test_disable(self):
|
||||
@torch.compiler.disable
|
||||
def inner():
|
||||
return 1
|
||||
|
||||
def fn():
|
||||
return inner()
|
||||
|
||||
def post_munge(s):
|
||||
return re.sub(
|
||||
r"<function GraphBreakMessagesTest\.test_disable\.<locals>\.inner at 0x[0-9A-Fa-f]+>",
|
||||
"<function inner>",
|
||||
s,
|
||||
)
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Skip calling `torch.compiler.disable()`d function
|
||||
Explanation: Skip calling function `<function inner>` since it was wrapped with `torch.compiler.disable`
|
||||
Hint: Remove the `torch.compiler.disable` call
|
||||
|
||||
Developer debug context: <function inner>
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
return inner()""",
|
||||
post_munge=post_munge,
|
||||
)
|
||||
|
||||
def test_dynamo_graph_break_fn(self):
|
||||
def fn():
|
||||
torch._dynamo.graph_break()
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Call to `torch._dynamo.graph_break()`
|
||||
Explanation: User-inserted graph break. Message: None
|
||||
Hint: Remove the `torch._dynamo.graph_break()` call.
|
||||
|
||||
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()""",
|
||||
)
|
||||
|
||||
def test_dynamo_graph_break_fn_with_msg(self):
|
||||
def fn():
|
||||
torch._dynamo.graph_break(msg="test graph break")
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Call to `torch._dynamo.graph_break()`
|
||||
Explanation: User-inserted graph break. Message: test graph break
|
||||
Hint: Remove the `torch._dynamo.graph_break()` call.
|
||||
|
||||
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{'msg': ConstantVariable(str: 'test graph break')}`
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break(msg="test graph break")""",
|
||||
)
|
||||
|
||||
def test_warnings(self):
|
||||
def fn():
|
||||
warnings.warn("test")
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo does not know how to trace the Python builtin `_warnings.warn`.
|
||||
Hint: If you are attempting to call a logging function (e.g. `_warnings.warn`), you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.
|
||||
Hint: Please file an issue on GitHub so the PyTorch team can add support for it.
|
||||
|
||||
Developer debug context: module: _warnings, qualname: warn, skip reason: <missing reason>
|
||||
|
||||
|
||||
from user code:
|
||||
File "test_graph_break_messages.py", line N, in fn
|
||||
warnings.warn("test")""",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not python_pytree._cxx_pytree_exists, "missing optree package")
|
||||
def test_optree_graph_break_message(self):
|
||||
import optree
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
d = {"a": 1}
|
||||
optree.tree_flatten(d)
|
||||
return torch.sin(x)
|
||||
|
||||
fn(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
self.assertExpectedInline(
|
||||
first_graph_break,
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
|
||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||
|
||||
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
|
||||
""",
|
||||
)
|
||||
|
||||
@scoped_load_inline
|
||||
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
|
||||
@unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode")
|
||||
def test_cpp_extension_recommends_custom_ops(self, load_inline):
|
||||
cpp_source = """
|
||||
#include <torch/extension.h>
|
||||
at::Tensor foobar(const at::Tensor& x) {
|
||||
return x.clone();
|
||||
}
|
||||
"""
|
||||
module = load_inline(
|
||||
name="mylib",
|
||||
cpp_sources=cpp_source,
|
||||
functions="foobar",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
x = torch.ones(2, 2, requires_grad=True)
|
||||
counters.clear()
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def f(x):
|
||||
return module.foobar(x)
|
||||
|
||||
with self.assertWarnsOnceRegex(
|
||||
UserWarning,
|
||||
"(?s).*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
|
||||
):
|
||||
f(x)
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
|
||||
first_graph_break = re.sub(r"mylib(_v\d+)?", "mylib", first_graph_break)
|
||||
|
||||
self.assertExpectedInline(
|
||||
first_graph_break,
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo does not know how to trace the builtin `mylib.PyCapsule.foobar.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
|
||||
Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
|
||||
Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
|
||||
|
||||
Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: <missing reason>
|
||||
""",
|
||||
)
|
||||
|
||||
cpp_source = """
|
||||
#include <torch/extension.h>
|
||||
at::Tensor baz(const at::Tensor& x) {
|
||||
return x.clone();
|
||||
}
|
||||
"""
|
||||
module2 = load_inline(
|
||||
name="mylib2",
|
||||
cpp_sources=cpp_source,
|
||||
functions="baz",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Test that each warning only happens once
|
||||
@torch.compile(backend="eager")
|
||||
def f(x):
|
||||
module2.baz(x)
|
||||
module.foobar(x)
|
||||
module.foobar(x)
|
||||
module2.baz(x)
|
||||
module.foobar(x)
|
||||
module2.baz(x)
|
||||
return x.clone()
|
||||
|
||||
with warnings.catch_warnings(record=True) as ws:
|
||||
warnings.simplefilter("always")
|
||||
f(x)
|
||||
f(x)
|
||||
self.assertEqual(len(ws), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -512,7 +512,9 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
with compiled_autograd._enable(compiler_fn):
|
||||
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported, "Failed to trace builtin operator"
|
||||
):
|
||||
dynamo_out[0].backward(torch.ones(4))
|
||||
|
||||
self.assertEqual(obj.count, 2)
|
||||
|
@ -293,24 +293,6 @@ class MiscTests(torch._inductor.test_case.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
fn(torch.randn(16))
|
||||
|
||||
@unittest.skipIf(not python_pytree._cxx_pytree_exists, "missing optree package")
|
||||
def test_optree_graph_break_message(self):
|
||||
import optree
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
d = {"a": 1}
|
||||
optree.tree_flatten(d)
|
||||
return torch.sin(x)
|
||||
|
||||
fn(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = list(counters["graph_break"].keys())[0]
|
||||
self.assertExpectedInline(
|
||||
first_graph_break,
|
||||
"Graph break for an optree C/C++ function optree._C.PyCapsule.flatten. Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py",
|
||||
)
|
||||
|
||||
def test_scalar_device_movement(self):
|
||||
if not torch._dynamo.config.assume_static_by_default:
|
||||
self.skipTest("Doesn't work with symints")
|
||||
@ -324,74 +306,6 @@ class MiscTests(torch._inductor.test_case.TestCase):
|
||||
res_compiled = add_fn(2, 3, torch.tensor(0.0))
|
||||
self.assertEqual(res, res_compiled)
|
||||
|
||||
@scoped_load_inline
|
||||
@skipIfNNModuleInlined("fails internal CI")
|
||||
@unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode")
|
||||
def test_cpp_extension_recommends_custom_ops(self, load_inline):
|
||||
cpp_source = """
|
||||
#include <torch/extension.h>
|
||||
at::Tensor foobar(const at::Tensor& x) {
|
||||
return x.clone();
|
||||
}
|
||||
"""
|
||||
module = load_inline(
|
||||
name="mylib",
|
||||
cpp_sources=cpp_source,
|
||||
functions="foobar",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
x = torch.ones(2, 2, requires_grad=True)
|
||||
counters.clear()
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def f(x):
|
||||
return module.foobar(x)
|
||||
|
||||
with self.assertWarnsOnceRegex(
|
||||
UserWarning,
|
||||
".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
|
||||
):
|
||||
f(x)
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = list(counters["graph_break"].keys())[0]
|
||||
self.assertExpectedInline(
|
||||
first_graph_break,
|
||||
"""Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
|
||||
)
|
||||
|
||||
cpp_source = """
|
||||
#include <torch/extension.h>
|
||||
at::Tensor baz(const at::Tensor& x) {
|
||||
return x.clone();
|
||||
}
|
||||
"""
|
||||
module2 = load_inline(
|
||||
name="mylib2",
|
||||
cpp_sources=cpp_source,
|
||||
functions="baz",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
torch._dynamo.reset()
|
||||
|
||||
# Test that each warning only happens once
|
||||
@torch.compile(backend="eager")
|
||||
def f(x):
|
||||
module2.baz(x)
|
||||
module.foobar(x)
|
||||
module.foobar(x)
|
||||
module2.baz(x)
|
||||
module.foobar(x)
|
||||
module2.baz(x)
|
||||
return x.clone()
|
||||
|
||||
with warnings.catch_warnings(record=True) as ws:
|
||||
warnings.simplefilter("always")
|
||||
f(x)
|
||||
f(x)
|
||||
self.assertEqual(len(ws), 2)
|
||||
|
||||
def test_callpacked(self):
|
||||
def call_packed(args):
|
||||
a, b, c = args
|
||||
@ -8967,7 +8881,7 @@ def ___make_guard_fn():
|
||||
# and so the guard story for the objects passed into input just isn't there atm.
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported,
|
||||
"^call_method UserDefinedObjectVariable\\(set\\).*",
|
||||
"Unsupported method call",
|
||||
):
|
||||
foo(inp)
|
||||
|
||||
@ -10643,7 +10557,7 @@ ShapeEnv not equal: field values don't match:
|
||||
# Should only be one restart per event
|
||||
(restart_reason,) = metrics[0].restart_reasons
|
||||
self.assertTrue(
|
||||
"skip function graph_break" in restart_reason,
|
||||
"User-inserted graph break" in restart_reason,
|
||||
"Should have logged graph break reason",
|
||||
)
|
||||
self.assertTrue(
|
||||
@ -10653,7 +10567,7 @@ ShapeEnv not equal: field values don't match:
|
||||
|
||||
(restart_reason,) = metrics[1].restart_reasons
|
||||
self.assertTrue(
|
||||
"skip function graph_break" in restart_reason,
|
||||
"User-inserted graph break" in restart_reason,
|
||||
"Should have logged graph break reason",
|
||||
)
|
||||
self.assertTrue(
|
||||
|
@ -1772,8 +1772,12 @@ def forward(self, x_1):
|
||||
self.assertExpectedInline(
|
||||
next(iter(counters["graph_break"].keys())).replace(";", "\n"),
|
||||
"""\
|
||||
dynamic shape operator: _torch_testing.numpy_nonzero.default
|
||||
to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
|
||||
Dynamic shape operator
|
||||
Explanation: Operator `_torch_testing.numpy_nonzero.default`'s output shape depends on input Tensor data.
|
||||
Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
|
||||
|
||||
Developer debug context: _torch_testing.numpy_nonzero.default
|
||||
""",
|
||||
)
|
||||
|
||||
# pre-existing problem: torch.compile(dynamic=True) will, by default,
|
||||
|
@ -206,7 +206,7 @@ def disallow_in_graph(fn):
|
||||
|
||||
|
||||
@_disallow_in_graph_helper(throw_if_not_allowed=False)
|
||||
def graph_break():
|
||||
def graph_break(msg=""):
|
||||
"""Force a graph break"""
|
||||
|
||||
|
||||
|
@ -441,6 +441,69 @@ def unimplemented(
|
||||
raise Unsupported(msg, case_name=case_name)
|
||||
|
||||
|
||||
def unimplemented_v2_with_warning(
|
||||
e: Exception,
|
||||
code: types.CodeType,
|
||||
gb_type: str,
|
||||
context: str,
|
||||
explanation: str,
|
||||
hints: list[str],
|
||||
) -> NoReturn:
|
||||
# This function calls unimplemented internally and eventually graph breaks
|
||||
# or falls to eager. unimplemented itself does not print any user warnings,
|
||||
# i.e., its very silent. This helper function is intended when an error is
|
||||
# encountered in the torch.compile stack which is worth showing as warning
|
||||
# to the user. For example, if AOT Autograd backend fails with a fake tensor
|
||||
# exception, its ok to fallback to eager but not silently. Here, we can use
|
||||
# this function to log the message and the stack trace.
|
||||
graph_break_msg = format_error_msg_verbose(e, code)
|
||||
torch._logging.trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "dynamo_graph_break_reason",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: graph_break_msg,
|
||||
)
|
||||
graph_breaks_log.debug("%s", graph_break_msg)
|
||||
unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True)
|
||||
|
||||
|
||||
# TODO replace old unimplemented later
|
||||
def unimplemented_v2(
|
||||
gb_type: str,
|
||||
context: str,
|
||||
explanation: str,
|
||||
hints: list[str],
|
||||
*,
|
||||
from_exc: Any = _NOTHING,
|
||||
log_warning: bool = False,
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Called within dynamo to cause a graph break.
|
||||
Args:
|
||||
gb_type: Context-free graph break type. It should be a short string without any
|
||||
information specific to the tracing context (i.e. no dynamically-generated strings)
|
||||
context: Developer context for the graph break. It can contain tracing context/dynamic strings.
|
||||
explanation: User-facing context-dependent explanation for the graph break. Can be dynamic.
|
||||
hints: List of user-facing hints for the graph break.
|
||||
"""
|
||||
hints_str = "\n".join(hints)
|
||||
hints_str = textwrap.indent(hints_str, " Hint: ")
|
||||
msg = f"""\
|
||||
{gb_type}
|
||||
Explanation: {explanation}
|
||||
{hints_str}
|
||||
|
||||
Developer debug context: {context}
|
||||
"""
|
||||
if log_warning:
|
||||
log.warning(msg)
|
||||
if from_exc is not _NOTHING:
|
||||
raise Unsupported(msg) from from_exc
|
||||
raise Unsupported(msg)
|
||||
|
||||
|
||||
def warning(msg: str) -> None:
|
||||
counters["warnings"][msg] += 1
|
||||
assert msg != os.environ.get("BREAK", False)
|
||||
|
@ -82,7 +82,7 @@ from .exc import (
|
||||
exceptions_allowed_to_be_fallback,
|
||||
SkipFrame,
|
||||
unimplemented,
|
||||
unimplemented_with_warning,
|
||||
unimplemented_v2_with_warning,
|
||||
)
|
||||
from .graph_deduplication import apply_graph_deduplication
|
||||
from .graph_region_tracker import GraphRegionTracker
|
||||
@ -652,9 +652,11 @@ class OutputGraph:
|
||||
"""
|
||||
global_state = cast(
|
||||
dict[str, tuple[Callable[..., Any], bool]],
|
||||
(
|
||||
out
|
||||
if out is not None
|
||||
else self.tracing_context.global_context.global_state,
|
||||
else self.tracing_context.global_context.global_state
|
||||
),
|
||||
)
|
||||
|
||||
# TODO - Consider having a torch level API for torch_function_state. As
|
||||
@ -1475,12 +1477,12 @@ class OutputGraph:
|
||||
gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
|
||||
gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
name = (
|
||||
self.compiler_fn.__name__
|
||||
if hasattr(self.compiler_fn, "__name__")
|
||||
else ""
|
||||
else "<unknown compiler_fn>"
|
||||
)
|
||||
try:
|
||||
_step_logger()(logging.INFO, f"calling compiler function {name}")
|
||||
compiler_fn = self.compiler_fn
|
||||
if config.verify_correctness:
|
||||
@ -1495,12 +1497,16 @@ class OutputGraph:
|
||||
raise BackendCompilerFailed(
|
||||
self.compiler_fn, e, inspect.currentframe()
|
||||
).with_traceback(e.__traceback__) from None
|
||||
msg = (
|
||||
f"Backend compiler failed with {str(e)} at \n"
|
||||
f"{self.root_tx.format_frame_summary()}"
|
||||
"Adding a graph break."
|
||||
unimplemented_v2_with_warning(
|
||||
e,
|
||||
self.root_tx.f_code,
|
||||
gb_type="Backend compiler exception",
|
||||
context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
|
||||
explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.",
|
||||
hints=[
|
||||
"Report an issue to the backend compiler repo.",
|
||||
],
|
||||
)
|
||||
unimplemented_with_warning(e, self.root_tx.f_code, msg)
|
||||
except SkipFrame as e:
|
||||
# The backend compiler has requested that we skip the frame, instead of
|
||||
# aborting execution.
|
||||
|
@ -74,7 +74,13 @@ from .bytecode_transformation import (
|
||||
)
|
||||
from .code_context import code_context
|
||||
from .codegen import PyCodegen
|
||||
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
|
||||
from .exc import (
|
||||
ArgsMismatchError,
|
||||
BackendCompilerFailed,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
from .funcname_cache import get_funcname
|
||||
from .guards import GuardBuilder, install_guard
|
||||
from .output_graph import GraphCompileReason, OutputGraph
|
||||
@ -488,7 +494,7 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
|
||||
|
||||
user_stack_formatted = "".join(traceback.format_list(user_stack))
|
||||
user_stack_trace = (
|
||||
"Graph break in user code at %s:%s\nReason: %s\nUser code traceback:\n%s" # noqa: UP031
|
||||
"Graph break in user code at %s:%s\nGraph Break Reason: %s\nUser code traceback:\n%s" # noqa: UP031
|
||||
% (
|
||||
frame_loc[0],
|
||||
frame_loc[1],
|
||||
@ -523,7 +529,7 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
|
||||
# exercised by
|
||||
# python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
|
||||
graph_break_log.debug(
|
||||
"Graph break (details suppressed) in user code at %s:%s\nReason: %s",
|
||||
"Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s",
|
||||
frame_loc[0],
|
||||
frame_loc[1],
|
||||
reason,
|
||||
@ -762,7 +768,7 @@ def break_graph_if_unsupported(*, push):
|
||||
log_graph_break(
|
||||
self.code_options,
|
||||
exc_info=True,
|
||||
reason=f"Unsupported: {excp}",
|
||||
reason=str(excp),
|
||||
user_stack=excp.real_stack,
|
||||
)
|
||||
|
||||
@ -2546,7 +2552,16 @@ class InstructionTranslatorBase(
|
||||
if not isinstance(
|
||||
ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
|
||||
):
|
||||
unimplemented(f"{inst.opname} {ctx}")
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported context manager",
|
||||
context=f"Attempted SETUP_WITH/BEFORE_WITH on {ctx}",
|
||||
explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.",
|
||||
hints=[
|
||||
"Avoid using the unsupported context manager.",
|
||||
"File an issue to PyTorch. Simple context managers can potentially be supported, "
|
||||
"but note that context managers can't be supported in general",
|
||||
],
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(ctx, GenericContextWrappingVariable)
|
||||
@ -3289,15 +3304,36 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
False, "allowlist in dynamo known function"
|
||||
)
|
||||
fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
|
||||
unimplemented(
|
||||
f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'"
|
||||
hints = [
|
||||
f"Avoid calling the function `{fn_qualname}`.",
|
||||
]
|
||||
if "_dynamo" not in func.get_filename():
|
||||
hints += [
|
||||
f"Remove the function `{fn_qualname}` or the file `{func.get_filename()}` "
|
||||
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
|
||||
"attempting to trace into the function.",
|
||||
"Please file an issue to PyTorch.",
|
||||
# TODO suggest mark_force_inline when implemented
|
||||
]
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to inline function marked as skipped",
|
||||
context=f"qualname: {fn_qualname}, name: {func.get_name()}, "
|
||||
f"filename: `{func.get_filename()}`, skip reason: {result.reason}",
|
||||
explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` "
|
||||
"should not be traced.",
|
||||
hints=hints,
|
||||
)
|
||||
|
||||
if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
|
||||
func.get_function(), "_torchdynamo_disable", False
|
||||
):
|
||||
unimplemented(
|
||||
f"call torch._dynamo.disable() wrapped function {func.get_function()}"
|
||||
unimplemented_v2(
|
||||
gb_type="Skip inlining `torch.compiler.disable()`d function",
|
||||
context=str(func.get_function()),
|
||||
explanation=f"Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable`",
|
||||
hints=[
|
||||
"Remove the `torch.compiler.disable` call",
|
||||
],
|
||||
)
|
||||
else:
|
||||
return result
|
||||
|
@ -3004,6 +3004,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
from .exc import (
|
||||
TorchRuntimeError,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
UserError,
|
||||
UserErrorType,
|
||||
@ -3063,23 +3064,50 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
if isinstance(
|
||||
cause, torch._subclasses.fake_tensor.DataDependentOutputException
|
||||
):
|
||||
unimplemented(
|
||||
f"data dependent operator: {cause.func}; "
|
||||
"to enable, set torch._dynamo.config.capture_scalar_outputs = True"
|
||||
# capture_scalar_outputs only works for these ops right now
|
||||
# see torch/_subclasses/fake_impls.py
|
||||
if cause.func in (
|
||||
torch.ops.aten.item.default,
|
||||
torch.ops.aten._local_scalar_dense.default,
|
||||
):
|
||||
# does this actually get triggered?
|
||||
hints = [
|
||||
"Enable tracing of data-dependent output operators with "
|
||||
"`torch._dynamo.config.capture_scalar_outputs = True`",
|
||||
]
|
||||
else:
|
||||
hints = [
|
||||
"Consider wrapping the operator into a PyTorch-understood custom operator "
|
||||
"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)",
|
||||
]
|
||||
unimplemented_v2(
|
||||
gb_type="Data dependent operator",
|
||||
context=str(cause.func),
|
||||
explanation=f"Operator `{cause.func}` has a non-Tensor output "
|
||||
"whose value is dependent on the data of Tensor inputs.",
|
||||
hints=hints,
|
||||
)
|
||||
elif isinstance(
|
||||
cause, torch._subclasses.fake_tensor.DynamicOutputShapeException
|
||||
):
|
||||
if not torch._dynamo.config.capture_dynamic_output_shape_ops:
|
||||
unimplemented(
|
||||
f"dynamic shape operator: {cause.func}; "
|
||||
"to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True"
|
||||
unimplemented_v2(
|
||||
gb_type="Dynamic shape operator",
|
||||
context=str(cause.func),
|
||||
explanation=f"Operator `{cause.func}`'s output shape depends on input Tensor data.",
|
||||
hints=[
|
||||
"Enable tracing of dynamic shape operators with "
|
||||
"`torch._dynamo.config.capture_dynamic_output_shape_ops = True`",
|
||||
],
|
||||
)
|
||||
else:
|
||||
unimplemented(
|
||||
f"dynamic shape operator: {cause.func}; "
|
||||
"Operator does not have a meta kernel that supports dynamic output shapes, "
|
||||
"please report an issue to PyTorch"
|
||||
unimplemented_v2(
|
||||
gb_type="Dynamic shape operator (no meta kernel)",
|
||||
context=str(cause.func),
|
||||
explanation=f"Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes",
|
||||
hints=[
|
||||
"Please report an issue to PyTorch",
|
||||
],
|
||||
)
|
||||
elif isinstance(
|
||||
cause, torch._subclasses.fake_tensor.UnsupportedOperatorException
|
||||
|
@ -21,7 +21,7 @@ from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
from .. import variables
|
||||
from ..current_scope_id import current_scope_id
|
||||
from ..exc import unimplemented
|
||||
from ..exc import unimplemented, unimplemented_v2
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, Source
|
||||
from ..utils import cmp_name_to_op_mapping, istype
|
||||
@ -310,6 +310,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
except NotImplementedError:
|
||||
raise NotImplementedError(f"{self} has no type") from None
|
||||
|
||||
def python_type_name(self):
|
||||
try:
|
||||
return self.python_type().__name__
|
||||
except NotImplementedError:
|
||||
return "<unknown type>"
|
||||
|
||||
def as_python_constant(self):
|
||||
"""For constants"""
|
||||
raise NotImplementedError(f"{self} is not a constant")
|
||||
@ -408,7 +414,15 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
args: Sequence["VariableTracker"],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
unimplemented(f"call_function {self} {args} {kwargs}")
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported function call",
|
||||
context=f"call_function {self} {args} {kwargs}",
|
||||
explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`",
|
||||
hints=[
|
||||
f"Avoid calling `{self.debug_repr()}` in your code.",
|
||||
"Please report an issue to PyTorch.",
|
||||
],
|
||||
)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
@ -450,7 +464,27 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
||||
self.as_python_constant(), other.as_python_constant()
|
||||
)
|
||||
)
|
||||
unimplemented(f"call_method {self} {name} {args} {kwargs}")
|
||||
hints = [
|
||||
f"Avoid calling `{self.python_type_name()}.{name}` in your code.",
|
||||
"Please report an issue to PyTorch.",
|
||||
]
|
||||
# additional hint for method calls on improperly constructed iterators
|
||||
if isinstance(self, variables.UserDefinedObjectVariable) and name in (
|
||||
"__iter__",
|
||||
"__next__",
|
||||
):
|
||||
hints.append(
|
||||
"Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) "
|
||||
"passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). "
|
||||
"This can happen unintentionally if a previous graph break happens with a builtin iterator "
|
||||
"in the local scope."
|
||||
)
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported method call",
|
||||
context=f"call_method {self} {name} {args} {kwargs}",
|
||||
explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`",
|
||||
hints=hints,
|
||||
)
|
||||
|
||||
def set_name_hint(self, name):
|
||||
pass
|
||||
|
@ -24,6 +24,7 @@ from ..exc import (
|
||||
ObservedAttributeError,
|
||||
raise_observed_exception,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
UserError,
|
||||
UserErrorType,
|
||||
@ -902,9 +903,24 @@ class BuiltinVariable(VariableTracker):
|
||||
|
||||
handlers.append(constant_fold_handler)
|
||||
|
||||
error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}"
|
||||
def call_unimplemented_v2(args):
|
||||
real_arg_types = [arg.python_type_name() for arg in args]
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to trace builtin operator",
|
||||
context=f"builtin {fn.__name__} {arg_types} {has_kwargs}",
|
||||
explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` "
|
||||
f"with argument types {real_arg_types} (has_kwargs {has_kwargs})",
|
||||
hints=[
|
||||
f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. "
|
||||
f"Consider using an equivalent alternative function/method to `{fn.__name__}`.",
|
||||
"If you are attempting to call a logging function (e.g. `print`), "
|
||||
"you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
|
||||
"Please report an issue to PyTorch.",
|
||||
],
|
||||
)
|
||||
|
||||
if len(handlers) == 0:
|
||||
return lambda *args: unimplemented(error_msg)
|
||||
return lambda tx, args, kwargs: call_unimplemented_v2(args)
|
||||
elif len(handlers) == 1:
|
||||
(handler,) = handlers
|
||||
|
||||
@ -912,7 +928,7 @@ class BuiltinVariable(VariableTracker):
|
||||
rv = handler(tx, args, kwargs)
|
||||
if rv:
|
||||
return rv
|
||||
unimplemented(error_msg)
|
||||
call_unimplemented_v2(args)
|
||||
|
||||
else:
|
||||
|
||||
@ -921,7 +937,7 @@ class BuiltinVariable(VariableTracker):
|
||||
rv = fn(tx, args, kwargs)
|
||||
if rv:
|
||||
return rv
|
||||
unimplemented(error_msg)
|
||||
call_unimplemented_v2(args)
|
||||
|
||||
return builtin_dispatch
|
||||
|
||||
|
@ -49,6 +49,7 @@ from ..exc import (
|
||||
raise_observed_exception,
|
||||
SkipFrame,
|
||||
unimplemented,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
@ -1136,7 +1137,26 @@ class SkipFunctionVariable(VariableTracker):
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
|
||||
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
|
||||
unimplemented_v2(
|
||||
gb_type="Skip calling `torch.compiler.disable()`d function",
|
||||
context=str(self.value),
|
||||
explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`",
|
||||
hints=[
|
||||
"Remove the `torch.compiler.disable` call",
|
||||
],
|
||||
)
|
||||
elif self.value is torch._dynamo.graph_break:
|
||||
graph_break_msg = kwargs.get("msg", None)
|
||||
if graph_break_msg:
|
||||
graph_break_msg = graph_break_msg.as_python_constant()
|
||||
unimplemented_v2(
|
||||
gb_type="Call to `torch._dynamo.graph_break()`",
|
||||
context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
|
||||
explanation=f"User-inserted graph break. Message: {graph_break_msg}",
|
||||
hints=[
|
||||
"Remove the `torch._dynamo.graph_break()` call.",
|
||||
],
|
||||
)
|
||||
elif isinstance(self.value, types.WrapperDescriptorType):
|
||||
msg = (
|
||||
f"Graph break due to unsupported wrapper descriptor {self.value}. "
|
||||
@ -1148,49 +1168,79 @@ class SkipFunctionVariable(VariableTracker):
|
||||
else:
|
||||
try:
|
||||
path = inspect.getfile(self.value)
|
||||
msg = f"'skip function {self.value.__qualname__} in file {path}'"
|
||||
explanation = (
|
||||
f"Dynamo developers have intentionally marked that the function `{self.value.__qualname__}` "
|
||||
f"in file `{path}` should not be traced."
|
||||
)
|
||||
hints = [
|
||||
f"Avoid calling the function `{self.value.__qualname__}`.",
|
||||
]
|
||||
# TODO improve trace_rules reasoning to provide better hints.
|
||||
# How do we tell that a function/file should NOT be removed from skip files?
|
||||
# Do a very basic check for now.
|
||||
if "_dynamo" not in path:
|
||||
hints += [
|
||||
f"Remove the function `{self.value.__qualname__}` or the file `{path}` "
|
||||
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
|
||||
"attempting to trace into the function.",
|
||||
"Please file an issue to PyTorch.",
|
||||
# TODO suggest mark_force_inline when implemented
|
||||
]
|
||||
except TypeError:
|
||||
known_python_builtin_modules = {"_abc", "_warnings"}
|
||||
if self.value.__module__ in known_python_builtin_modules:
|
||||
msg = (
|
||||
f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. "
|
||||
f"Please file an issue on GitHub "
|
||||
f"so the PyTorch team can add support for it. "
|
||||
explanation = (
|
||||
f"Dynamo does not know how to trace the Python builtin "
|
||||
f"`{self.value.__module__}.{self.value.__qualname__}`."
|
||||
)
|
||||
hints = [
|
||||
"If you are attempting to call a logging function (e.g. `_warnings.warn`), "
|
||||
"you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
|
||||
"Please file an issue on GitHub "
|
||||
"so the PyTorch team can add support for it. ",
|
||||
]
|
||||
elif (
|
||||
self.value.__module__ is not None
|
||||
and self.value.__module__.startswith("optree")
|
||||
):
|
||||
msg = (
|
||||
f"Graph break for an optree C/C++ function {self.value.__module__}.{self.value.__qualname__}."
|
||||
f" Consider using torch.utils._pytree - "
|
||||
f"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
|
||||
)
|
||||
explanation = f"Dynamo cannot trace optree C/C++ function {self.value.__module__}.{self.value.__qualname__}."
|
||||
hints = [
|
||||
" Consider using torch.utils._pytree - "
|
||||
"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
|
||||
]
|
||||
# also warn on it because most users won't see the graph break message
|
||||
torch._dynamo.utils.warn_once(msg)
|
||||
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
|
||||
else:
|
||||
msg = (
|
||||
f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. "
|
||||
explanation = (
|
||||
f"Dynamo does not know how to trace the builtin `{self.value.__module__}.{self.value.__qualname__}.` "
|
||||
f"This function is either a Python builtin (e.g. _warnings.warn) "
|
||||
f"or a third-party C/C++ Python extension (perhaps created with pybind). "
|
||||
f"If it is a Python builtin, please file an issue on GitHub "
|
||||
f"so the PyTorch team can add support for it and see the next case for a workaround. "
|
||||
f"If it is a third-party C/C++ Python extension, please "
|
||||
f"either wrap it into a PyTorch-understood custom operator "
|
||||
f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
|
||||
f"for more details) or, if it is traceable, use "
|
||||
f"torch.compiler.allow_in_graph."
|
||||
f"or a third-party C/C++ Python extension (perhaps created with pybind)."
|
||||
)
|
||||
hints = [
|
||||
"If it is a Python builtin, please file an issue on GitHub "
|
||||
"so the PyTorch team can add support for it and see the next case for a workaround.",
|
||||
"If it is a third-party C/C++ Python extension, please "
|
||||
"either wrap it into a PyTorch-understood custom operator "
|
||||
"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
|
||||
"for more details) or, if it is traceable, use "
|
||||
"`torch.compiler.allow_in_graph`.",
|
||||
]
|
||||
# also warn on it because most users won't see the graph break message
|
||||
torch._dynamo.utils.warn_once(msg)
|
||||
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
|
||||
if self.value.__qualname__ == "allow_in_graph":
|
||||
msg = (
|
||||
explanation = (
|
||||
"Found an allow_in_graph decorator to a function which "
|
||||
"is created inside the parent function that is getting "
|
||||
"compiled. This is not supported for now."
|
||||
)
|
||||
msg += f"', {self.reason}'" if self.reason else ""
|
||||
unimplemented(msg)
|
||||
hints = []
|
||||
reason = self.reason if self.reason else "<missing reason>"
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to call function marked as skipped",
|
||||
context=f"module: {self.value.__module__}, qualname: {self.value.__qualname__}, skip reason: {reason}",
|
||||
explanation=explanation,
|
||||
hints=hints,
|
||||
)
|
||||
|
||||
def call_obj_hasattr(self, tx: "InstructionTranslator", name):
|
||||
return variables.ConstantVariable.create(hasattr(self.value, name))
|
||||
|
@ -3102,13 +3102,16 @@ class TestCase(expecttest.TestCase):
|
||||
|
||||
# Munges exceptions that internally contain stack traces, using munge_exc
|
||||
def assertExpectedInlineMunged(
|
||||
self, exc_type, callable, expect, *, suppress_suffix=True
|
||||
self, exc_type, callable, expect, *, suppress_suffix=True, post_munge=None,
|
||||
):
|
||||
try:
|
||||
callable()
|
||||
except exc_type as e:
|
||||
munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=1)
|
||||
if post_munge:
|
||||
munged = post_munge(munged)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(e, suppress_suffix=suppress_suffix, skip=1), expect, skip=1
|
||||
munged, expect, skip=1
|
||||
)
|
||||
return
|
||||
self.fail(msg="Did not raise when expected to")
|
||||
|
Reference in New Issue
Block a user