mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Graph break on on user-defined class in compiled region (#161670)
Currently, user-defined classes inside of a compiled frame will cause the whole frame to be skipped by dynamo. This change defers the Unsupported exception until the __build_class__ builtin is actually called, which allows a graph break to be inserted. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161670 Approved by: https://github.com/williamwen42, https://github.com/guilhermeleobas
This commit is contained in:
committed by
PyTorch MergeBot
parent
dda071587f
commit
5ac112b569
@ -726,14 +726,14 @@ Call to `torch._dynamo.graph_break()`
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
LOAD_BUILD_CLASS bytecode not supported
|
||||
Explanation: Dynamo does not support tracing classes that are defined in the compiled region.
|
||||
Hint: Move the class definition out of the compiled region.
|
||||
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo does not know how to trace the builtin `builtins.__build_class__.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
|
||||
Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
|
||||
Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
|
||||
|
||||
Developer debug context:
|
||||
Developer debug context: module: builtins, qualname: __build_class__, skip reason: <missing reason>
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0075.html
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
|
@ -12684,6 +12684,22 @@ fn
|
||||
self.assertRaises(Unsupported, f, [])
|
||||
self.assertRaises(Unsupported, f, "1 + j")
|
||||
|
||||
def test_compiled_class_graph_break(self):
|
||||
counter = CompileCounter()
|
||||
|
||||
@torch.compile(backend=counter, fullgraph=False)
|
||||
def f(x):
|
||||
x += 1
|
||||
|
||||
class C:
|
||||
pass
|
||||
|
||||
return x.sin()
|
||||
|
||||
x = torch.randn(3)
|
||||
f(x)
|
||||
self.assertEqual(counter.frame_count, 2)
|
||||
|
||||
|
||||
class MiscTestsPyTree(torch._inductor.test_case.TestCase):
|
||||
@parametrize_pytree_module
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user