[dynamo] change error_on_graph_break/fullgraph semantics (#161747)

This PR implements the semantics change to `torch._dynamo.error_on_graph_break`:
- ~`torch.compile` now has a new `error_on_graph_break` kwarg that serves as a lower-priority toggle for erroring/continuing on graph breaks~
- `error_on_graph_break` is a new internal `torch.compile` setting that is lower-priority than `fullgraph`. It allows the user to toggle erroring/continuing on graph breaks.
- `error_on_graph_break` does nothing when `fullgraph=True`
- `error_on_graph_break` does NOT guarantee a single graph

Followup [DONE]: need to change the programming model docs to reflect the 3 graph break modes for compilation:
- `fullgraph=True`: enforce one graph, no graph breaks, cannot be toggled
- `fullgraph=False, error_on_graph_break=True`: errors on graph breaks, latter can be toggled during compile time
- `fullgraph=False, error_on_graph_break=False`: resumes tracing on graph breaks, latter can be toggled during compile time

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161747
Approved by: https://github.com/mlazos
ghstack dependencies: #161739
William Wen
2025-09-03 16:43:14 -07:00
committed by PyTorch MergeBot
parent ba7f546ccc
commit f36f285953
11 changed files with 390 additions and 58 deletions

View File

@ -0,0 +1,242 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(graph_breaks=True)
```
# Toggling `error_on_graph_break`
**Summary:**
- When `fullgraph=False`, we can use `torch._dynamo.error_on_graph_break()` for more flexibility in
  dealing with graph breaks.

So far, we have introduced two ways of dealing with graph breaks in `torch.compile`:
1. `fullgraph=True` errors on the first graph break and additionally guarantees that only one graph is traced from the code.
2. `fullgraph=False` continues tracing even when encountering graph breaks.

What if we want to disallow graph breaks for most of the code, but there are a few problematic functions where the graph breaks are hard to remove,
and we are okay with having those graph breaks? We can use `torch._dynamo.error_on_graph_break()` to achieve this.

`torch.compile` has an `error_on_graph_break` setting (initially set to `False`).
If a graph break or compiler error occurs in code while `error_on_graph_break` is set to `False`, then `torch.compile` will attempt to continue compilation after the graph break/error.
If `error_on_graph_break` is set to `True`, then `torch.compile` will abort compilation and propagate the error to user code.

A significant difference between `error_on_graph_break=True` and `fullgraph=True` is that the former **does not guarantee that a single graph will be captured**.

`error_on_graph_break` **can be arbitrarily toggled during compile time** by using the `torch._dynamo.error_on_graph_break()` context manager/decorator.
In comparison, once `fullgraph` is set to `True`, it cannot be set back to `False`.

Finally, `error_on_graph_break` has lower precedence than `fullgraph`: `error_on_graph_break` only takes effect when `fullgraph=False`.
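
The precedence rule above can be condensed into a few lines. The following is a toy model in plain Python, not Dynamo's implementation; the function name `errors_on_graph_break` is hypothetical and exists only for illustration.

```python
def errors_on_graph_break(fullgraph: bool, error_on_graph_break: bool) -> bool:
    """Toy model: return True if a graph break would raise an error."""
    if fullgraph:
        # fullgraph=True always errors; error_on_graph_break is ignored.
        return True
    # With fullgraph=False, error_on_graph_break decides.
    return error_on_graph_break


# Enumerate all four combinations of the two settings.
for fullgraph in (True, False):
    for eogb in (True, False):
        outcome = "error" if errors_on_graph_break(fullgraph, eogb) else "continue tracing"
        print(f"fullgraph={fullgraph}, error_on_graph_break={eogb}: {outcome}")
```

Only the combination `fullgraph=False, error_on_graph_break=False` continues tracing. The model says nothing about how many graphs are captured; that guarantee comes from `fullgraph=True` alone.
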
## `error_on_graph_break(False)` example
```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def code_with_a_difficult_graph_break(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner(x):
    return code_with_a_difficult_graph_break(x)

# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner(x)

# No error, but there is a graph break
fn(torch.randn(3))
```
Using `error_on_graph_break(False)` under `error_on_graph_break(True)` is helpful when we want to minimize graph breaks (i.e. follow the `fullgraph=True` programming model),
but there are some sections of code with non-performance-critical graph breaks that are difficult to work around.

`error_on_graph_break()` can be used as a context manager as well:
```{code-cell}
# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()  # no error
    return x + 2

# No error, but there is a graph break
fn(torch.randn(3))
```
You can use monkey patching to toggle `error_on_graph_break` for code where you cannot edit the source (e.g. framework code):
```{code-cell}
class ThirdPartyModule(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        torch._dynamo.graph_break()
        return x + 2

tp_mod = ThirdPartyModule()
tp_mod.forward = torch._dynamo.error_on_graph_break(False)(tp_mod.forward)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return tp_mod.forward(x)

# No error, but there is a graph break
fn(torch.randn(3))
```
## `error_on_graph_break(True)` example
```{code-cell}
@torch._dynamo.error_on_graph_break(True)
def inner2(x):
    x = x + 1
    torch._dynamo.graph_break()  # error
    return x + 2

def inner(x):
    return inner2(x)

# fullgraph=False, error_on_graph_break=False
@torch.compile
def fn(x):
    x = x + 4
    torch._dynamo.graph_break()  # no error
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
Using `error_on_graph_break(True)` under `error_on_graph_break(False)` is helpful when we want to use `torch.compile` flexibly (i.e. follow the `fullgraph=False` programming model),
but there are some sections of the code that are performance-critical and we want to ensure that those sections do not contain graph breaks.
## `error_on_graph_break` nesting behavior
`torch._dynamo.error_on_graph_break()` affects the `error_on_graph_break` setting of nested calls as well:
```{code-cell}
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(False):
        return inner(x)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```
`torch._dynamo.error_on_graph_break()` can be used under another `torch._dynamo.error_on_graph_break()` region:
```{code-cell}
def inner(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(True):
        return inner(x)

@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```
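
The nesting behavior in the two examples above amounts to "the innermost region wins, and the previous value comes back when the region exits". The sketch below imitates that with a plain-Python context manager. It is a toy model only: `toy_error_on_graph_break`, `toy_graph_break`, and `_setting` are hypothetical names and do not reflect how Dynamo stores the setting.

```python
from contextlib import contextmanager

# Toy stand-in for the compile-time setting (hypothetical; not Dynamo's state).
_setting = False  # matches the documented default of False


@contextmanager
def toy_error_on_graph_break(value: bool):
    """Set the toy setting for the duration of the block, then restore it."""
    global _setting
    prior = _setting
    _setting = value
    try:
        yield
    finally:
        # Restore the enclosing region's value, even if the block raised.
        _setting = prior


def toy_graph_break() -> str:
    """Report what a graph break would do under the current toy setting."""
    return "error" if _setting else "continue"


results = []
with toy_error_on_graph_break(True):
    results.append(toy_graph_break())      # outer region is in effect
    with toy_error_on_graph_break(False):
        results.append(toy_graph_break())  # inner region overrides the outer one
    results.append(toy_graph_break())      # outer value is restored
results.append(toy_graph_break())          # back to the default
print(results)
```

The list records the behavior at each nesting level, which makes it easy to see that leaving an inner region never leaks its value into the enclosing one.
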
## Interaction with `fullgraph`
`fullgraph=True` takes higher precedence than `error_on_graph_break`:
```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
`fullgraph=True` cannot be toggled back to `fullgraph=False`:
```{code-cell}
@torch.compile(fullgraph=False)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
```{code-cell}
@torch.compile(fullgraph=True)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=False)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
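
One way to read the two nested `torch.compile` examples above is that `fullgraph=True` is "sticky": a nested `fullgraph=False` never switches it off. The sketch below is a simplified plain-Python model of that reading, not Dynamo's logic; the helper name `nested_fullgraph` is hypothetical.

```python
def nested_fullgraph(flags: list[bool]) -> bool:
    """Toy model of nested torch.compile calls.

    `flags` holds the fullgraph argument of each nested call, outermost first.
    Returns the setting in effect for a graph break at the innermost level.
    """
    in_effect = False
    for flag in flags:
        # A nested fullgraph=False never turns an enclosing fullgraph=True off.
        in_effect = in_effect or flag
    return in_effect


print(nested_fullgraph([True, False]))   # outer True, inner False
print(nested_fullgraph([False, True]))   # outer False, inner True
print(nested_fullgraph([False, False]))  # neither level sets fullgraph
```

The first two calls mirror the two examples above: in both, the graph break in the innermost function is treated as an error. Only when no level sets `fullgraph=True` does `error_on_graph_break` come into play.
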
## Summary of `fullgraph=True/False` vs `error_on_graph_break`
Here is a table summarizing the differences between `fullgraph=True/False` and `error_on_graph_break`:
| | `error_on_graph_break=True` | `error_on_graph_break=False` (default) |
| --- | --- | --- |
| `fullgraph=True` | Graph breaks result in errors. Only the first graph break will be reported. **One graph guarantee.**<br><br>`fullgraph` cannot be toggled to `False`. `error_on_graph_break` has no effect.<br><br>User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).<br><br>Ideal for code sensitive to graph breaks: framework/library code or cases where getting maximum performance is required. Prevents downstream user code from inadvertently allowing graph breaks. | Same as `fullgraph=True` and `error_on_graph_break=True` as `error_on_graph_break` has no effect when `fullgraph=True`. |
| `fullgraph=False` (default) | Graph breaks result in errors. Only the first graph break will be reported. **No one graph guarantee.**<br><br>`error_on_graph_break` can be toggled to `False`.<br><br>User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).<br><br>Ideal for user code sensitive to graph breaks. `error_on_graph_break` can be toggled to `False` to deal with sections that have graph breaks that are difficult to work around. | Will continue to compile after encountering graph breaks. All graph breaks will be reported.<br><br>`error_on_graph_break` can be toggled to `True`.<br><br>Doesn't require many user code changes to work. Performance may be negatively impacted due to graph breaks.<br><br>Ideal for out-of-the-box use cases, on “non-weird” code, or where squeezing maximal performance is not necessary. |

View File

@ -19,6 +19,7 @@ The strategy for using `torch.compile(fullgraph=False)` is as follows:
```{toctree}
programming_model.where_to_apply_compile
programming_model.compiler_disable
programming_model.error_on_graph_break
programming_model.nested_graph_breaks
programming_model.skipped_functions
```

View File

@ -1067,11 +1067,10 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4)
cnts.clear()
torch._dynamo.reset()
with self.assertRaisesRegex(
Unsupported, r"Skip calling `torch.compiler.disable\(\)`d function"
):
fn3(torch.randn(4, 5))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4)
def test_disable_optimize(self):
cnt = torch._dynamo.testing.CompileCounter()
@ -1724,7 +1723,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def test_error_on_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f1(x):
x = x + 1
with torch._dynamo.error_on_graph_break(False):
@ -1745,7 +1745,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Unsupported):
f2(inp)
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f3(x):
x = x + 1
with torch._dynamo.error_on_graph_break(False):
@ -1763,7 +1764,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f4(x):
x = x + 1
with torch._dynamo.error_on_graph_break(False):
@ -1784,7 +1786,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f5(x):
x = x + 1
return inner_f5(x)
@ -1799,7 +1802,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f6(x):
x = x + 1
return inner_f6(x)
@ -1814,7 +1818,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.graph_break()
return x + 4
@torch.compile(backend=cnts, fullgraph=False)
@torch._dynamo.error_on_graph_break(False)
@torch.compile(backend=cnts)
def f7(x):
x = x + 1
return inner_f7(x)
@ -1837,7 +1842,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.skip_frame()
return inner2_f8(x)
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f8(x):
x = x + 1
return inner1_f8(x)
@ -1856,7 +1862,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner1_f9(x):
return inner2_f9(x)
@torch.compile(backend=cnts, fullgraph=False)
@torch._dynamo.error_on_graph_break(False)
@torch.compile(backend=cnts)
def f9(x):
x = x + 1
return inner1_f9(x)
@ -1898,7 +1905,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner4_f1(x):
return inner3_f1(x)
@torch.compile(backend=cnts, fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend=cnts)
def f1(x):
x = x + 4
return inner4_f1(x)
@ -1922,7 +1930,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def inner4_f2(x):
return inner3_f2(x)
@torch.compile(backend=cnts, fullgraph=False)
@torch._dynamo.error_on_graph_break(False)
@torch.compile(backend=cnts)
def f2(x):
x = x + 4
return inner4_f2(x)
@ -1953,34 +1962,88 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with self.assertRaises(Exception):
f3()
def test_nested_compile_fullgraph(self):
def test_nested_compile_error_on_graph_break(self):
inp = torch.ones(3)
@torch.compile(backend="eager", fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend="eager")
def inner_f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(backend="eager", fullgraph=False)
@torch._dynamo.error_on_graph_break(False)
@torch.compile(backend="eager")
def f1(x):
return inner_f1(x)
with self.assertRaises(Unsupported):
f1(inp)
@torch.compile(backend="eager", fullgraph=False)
@torch._dynamo.error_on_graph_break(False)
@torch.compile(backend="eager")
def inner_f2(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(backend="eager", fullgraph=True)
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend="eager")
def f2(x):
return inner_f2(x)
self.assertEqual(f2(inp), inp + 3)
def test_error_on_graph_break_fullgraph(self):
# Test that error_on_graph_break=False cannot override fullgraph=True
inp = torch.ones(3)
@torch.compile(backend="eager", fullgraph=True)
def f(x):
x = x + 1
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
return x + 2
with self.assertRaises(Unsupported):
f(inp)
def test_error_on_graph_break_empty_graph(self):
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend="eager")
def f():
return 1
self.assertEqual(f(), 1)
def test_nested_compile_fullgraph(self):
# Test that fullgraph=True cannot be toggled back by fullgraph=False
inp = torch.ones(3)
@torch.compile(backend="eager", fullgraph=True)
def inner_f1(x):
torch._dynamo.graph_break()
return x + 1
@torch.compile(backend="eager", fullgraph=False)
def outer_f1(x):
return inner_f1(x)
with self.assertRaises(Unsupported):
outer_f1(inp)
@torch.compile(backend="eager", fullgraph=False)
def inner_f2(x):
torch._dynamo.graph_break()
return x + 1
@torch.compile(backend="eager", fullgraph=True)
def outer_f2(x):
return inner_f2(x)
with self.assertRaises(Unsupported):
outer_f2(inp)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -93,7 +93,8 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
return func(*args, **kwargs)
# test e2e, with Inductor, as smoketest.
@torch.compile(fullgraph=True, backend="inductor")
@torch._dynamo.error_on_graph_break(True)
@torch.compile(backend="inductor")
def g(x):
return 2 * x.sin().cos()

View File

@ -1884,9 +1884,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
)
# Ensure no more re-compilation after the second automatic dynamic shape version.
if i == 0:
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
else:
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 4)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)

View File

@ -524,12 +524,6 @@ class ConvertFrameBox:
error_on_graph_break: Optional[bool] = None
def _is_error_on_graph_break(tx: Optional[DynamoTracerOutput]) -> bool:
if tx is None:
return _get_error_on_graph_break()
return tx.error_on_graph_break
def get_compile_id(
frame_state: dict[str, Union[int, FrameStateSizeEntry]],
) -> CompileId:
@ -1167,10 +1161,8 @@ def _compile(
package=package,
)
except exc.SkipFrame as e:
if one_graph or _is_error_on_graph_break(e._torch_dynamo_tracer_output):
log.debug(
"No graph captured with one_graph=True or error_on_graph_break=True"
)
if one_graph:
log.debug("No graph captured with export/fullgraph=True")
assert e._torch_dynamo_tracer_output is not None
return ConvertFrameReturn(), e._torch_dynamo_tracer_output
@ -1376,10 +1368,9 @@ def _compile(
raise FailOnRecompileLimitHit(
f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
)
elif one_graph or _get_error_on_graph_break():
elif one_graph:
raise FailOnRecompileLimitHit(
f"{limit_type} reached with one_graph=True or error_on_graph_break=True. "
"Excessive recompilations can degrade "
f"{limit_type} reached with fullgraph=True. Excessive recompilations can degrade "
"performance due to the compilation overhead of each recompilation. To monitor "
"recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
"increasing torch._dynamo.config.cache_size_limit to an appropriate value."

View File

@ -941,9 +941,17 @@ def error_on_graph_break(
error_on_graph_break: bool,
) -> ErrorOnGraphBreakDecoratorContextManager:
"""
Context manager/decorator to toggle error_on_graph_break (i.e. torch.compile's fullgraph) setting.
Context manager/decorator to toggle torch.compile's `error_on_graph_break` setting at compile time.
More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
If `fullgraph` is set, then `error_on_graph_break` does nothing
(i.e. `fullgraph = True` takes higher precedence). If `fullgraph` is False, then
`error_on_graph_break` determines whether `torch.compile` throws an error upon
encountering a graph break, or attempts to continue tracing.
`error_on_graph_break` can be toggled during compile time with this decorator to allow graph breaks in some
compiled regions but not others. One key difference from `fullgraph` is that `error_on_graph_break = True`
does NOT guarantee that a single graph is captured from the compiled function.
The default value of torch.compile's `error_on_graph_break` setting is False.
"""
return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)

View File

@ -598,7 +598,8 @@ class _TorchDynamoContext:
patch_fn: Callable[[], Any] = nothing,
first_ctx: bool = False,
*,
error_on_graph_break: bool = False,
fullgraph: bool = False,
error_on_graph_break: Optional[bool] = None,
export: bool = False,
dynamic: Optional[bool] = None,
compiler_config: Optional[Any] = None,
@ -611,6 +612,7 @@ class _TorchDynamoContext:
self._backend_ctx_ctor = backend_ctx_ctor
self.prior: Union[Unset, DynamoCallback] = unset
self.first_ctx = first_ctx
self.fullgraph = fullgraph
self.error_on_graph_break = error_on_graph_break
self.export = export
self._dynamic = dynamic
@ -705,7 +707,7 @@ class _TorchDynamoContext:
def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any:
from torch._dynamo.aot_compile import aot_compile_fullgraph
if not self.error_on_graph_break:
if not self.fullgraph:
raise RuntimeError(
"Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)."
)
@ -810,7 +812,7 @@ class _TorchDynamoContext:
_is_skip_guard_eval_unsafe_stance()
)
prior_error_on_graph_break = None
if self.error_on_graph_break is not None:
if not self.fullgraph and self.error_on_graph_break is not None:
prior_error_on_graph_break = _get_error_on_graph_break()
_set_error_on_graph_break(self.error_on_graph_break)
@ -857,11 +859,14 @@ class _TorchDynamoContext:
_maybe_set_eval_frame(prior)
# hooks to properly handle inlining
if self.error_on_graph_break is not None:
compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined]
external_utils.wrap_inline_with_error_on_graph_break(
fn, self.error_on_graph_break
)
)
else:
compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined]
# Save the function pointer to find the original callable while nesting
# of decorators.
@ -923,7 +928,8 @@ class OptimizeContext(_TorchDynamoContext):
backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]],
first_ctx: bool = False,
*,
error_on_graph_break: bool = False,
fullgraph: bool = False,
error_on_graph_break: Optional[bool] = None,
export: bool = False,
dynamic: Optional[bool] = None,
compiler_config: Optional[Any] = None,
@ -942,6 +948,7 @@ class OptimizeContext(_TorchDynamoContext):
backend_ctx_ctor=backend_ctx_ctor,
patch_fn=TorchPatcher.patch,
first_ctx=first_ctx,
fullgraph=fullgraph,
error_on_graph_break=error_on_graph_break,
export=export,
dynamic=dynamic,
@ -1067,7 +1074,8 @@ def _optimize_catch_errors(
backend_ctx_ctor: Callable[
[], contextlib.AbstractContextManager[Any]
] = null_context,
error_on_graph_break: bool = False,
fullgraph: bool = False,
error_on_graph_break: Optional[bool] = None,
export: bool = False,
dynamic: Optional[bool] = None,
compiler_config: Optional[Any] = None,
@ -1078,6 +1086,7 @@ def _optimize_catch_errors(
convert_frame.catch_errors_wrapper(compile_fn, hooks),
backend_ctx_ctor=backend_ctx_ctor,
first_ctx=True,
fullgraph=fullgraph,
error_on_graph_break=error_on_graph_break,
export=export,
dynamic=dynamic,
@ -1176,6 +1185,7 @@ def _optimize(
backend: Union[str, Callable[..., Any]] = "inductor",
*,
nopython: bool = False,
error_on_graph_break: Optional[bool] = None,
guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None,
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None,
@ -1198,6 +1208,11 @@ def _optimize(
- Or, a string backend name in `torch._dynamo.list_backends()`
nopython: If True, graph breaks will be errors and there will
be a single whole-program graph.
error_on_graph_break: If not None, the current `error_on_graph_break` setting is set to the given value.
See `torch._dynamo.error_on_graph_break()` for more details on what `error_on_graph_break` means.
Unlike `nopython=True` (i.e. `fullgraph=True`), there is no guarantee of a single whole-program graph.
If `nopython` is True, `error_on_graph_break` does nothing.
disable: If True, turn this decorator into a no-op
dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
disable all dynamic shapes support (always specialize). If None, automatically
@ -1228,6 +1243,15 @@ def _optimize(
):
return _NullDecorator()
if nopython and not config.debug_force_graph_break_on_leaf_return:
return optimize_assert(
backend,
dynamic=dynamic,
hooks=hooks,
rebuild_ctx=rebuild_ctx,
package=package,
)
backend = get_compiler_fn(backend)
# Find if backend has any extra context manager
@ -1252,7 +1276,8 @@ def _optimize(
),
hooks,
backend_ctx_ctor,
error_on_graph_break=nopython
fullgraph=False,
error_on_graph_break=error_on_graph_break
and not config.debug_force_graph_break_on_leaf_return,
dynamic=dynamic,
compiler_config=(
@ -2174,10 +2199,11 @@ def _optimize_assert(
package: Optional[CompilePackage] = None,
) -> OptimizeContext:
"""
The same as `torch._dynamo.optimize(backend, nopython=True)`,
but ignores symbolic_convert.error_on_graph_break setting.
Guarantees single-graph capture.
The same as `torch._dynamo.optimize(backend)` but ignores
symbolic_convert.error_on_graph_break setting.
Used for export, since we must always error on graph breaks and ignore
Used for fullgraph=True and export, since we must always error on graph breaks and ignore
symbolic_convert.error_on_graph_break. Can also be used for testing.
"""
backend = get_compiler_fn(backend)
@ -2204,6 +2230,7 @@ def _optimize_assert(
),
hooks,
backend_ctx_ctor,
fullgraph=True,
export=export,
dynamic=dynamic,
rebuild_ctx=rebuild_ctx,

View File

@ -203,7 +203,7 @@ def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[
Apply self as a ctx manager around a call to func
"""
@functools.wraps(func)
# NOTE: do not functools.wraps(func) because we don't ever want this frame to be skipped!
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with self:
return func(*args, **kwargs)
@ -234,16 +234,15 @@ def wrap_inline_with_error_on_graph_break(
) -> Callable[_P, _R]:
# NB: need multiple definitions in order to prevent `fullgraph` from
# being a freevar of wrapper
# NOTE: do not functools.wraps(fn) because we don't ever want these wrappers to be skipped!
if error_on_graph_break:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.error_on_graph_break(True):
return fn(*args, **kwargs)
else:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
with torch._dynamo.error_on_graph_break(False):
return fn(*args, **kwargs)

View File

@ -3756,8 +3756,8 @@ class InstructionTranslatorBase(
self.num_calls: dict[str, int] = {}
# Flag to indicate whether tracing is used for export.
self.export = export
# NOTE: one_graph is used for export/debugging to always force errors on graph breaks.
# To toggle fullgraph during normal compile, self.error_on_graph_break
# NOTE: one_graph is used for export/fullgraph=True to always force errors on graph breaks.
# To toggle erroring/resuming on graph breaks during fullgraph=False compile, self.error_on_graph_break
# is used instead. Every step(), its value is updated to the global tls.error_on_graph_break.
# We mirror this value since cleanup may (correctly) inadvertently change tls.error_on_graph_break.
# This assumes that we cannot both trace a change to tls.error_on_graph_break and graph break on

View File

@ -170,7 +170,7 @@ class CPythonTestCase(TestCase):
# We want to compile only the test function, excluding any setup code
# from unittest
method = getattr(self, self._testMethodName)
method = torch._dynamo.optimize(backend, nopython=nopython)(method)
method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
setattr(self, self._testMethodName, method)
return fn