[dynamo] change error_on_graph_break/fullgraph semantics (#161747)
This PR implements the semantics change to `torch._dynamo.error_on_graph_break`:

- ~`torch.compile` now has a new `error_on_graph_break` kwarg that serves as a lower-priority toggle for erroring/continuing on graph breaks~
- `error_on_graph_break` is a new internal `torch.compile` setting that is lower-priority than `fullgraph`. It allows the user to toggle erroring/continuing on graph breaks.
- `error_on_graph_break` does nothing when `fullgraph=True`
- `error_on_graph_break` does NOT guarantee a single graph

Followup [DONE]: change the programming model docs to reflect the 3 graph break modes for compilation:

- `fullgraph=True`: enforce one graph, no graph breaks, cannot be toggled
- `fullgraph=False, error_on_graph_break=True`: errors on graph breaks, the latter can be toggled during compile time
- `fullgraph=False, error_on_graph_break=False`: resumes tracing on graph breaks, the latter can be toggled during compile time

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161747
Approved by: https://github.com/mlazos
ghstack dependencies: #161739
Committed by: PyTorch MergeBot
Parent: ba7f546ccc
Commit: f36f285953
docs/source/compile/programming_model.error_on_graph_break.md (new file, 242 lines)
@@ -0,0 +1,242 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code
torch._logging.set_logs(graph_breaks=True)
```

# Toggling `error_on_graph_break`

**Summary:**

- When `fullgraph=False`, we can use `torch._dynamo.error_on_graph_break()` for more flexibility in
  dealing with graph breaks.

So far, we have introduced two ways of dealing with graph breaks in `torch.compile`:
1. `fullgraph=True` errors on the first graph break and additionally guarantees that only one graph is traced from the code.
2. `fullgraph=False` continues tracing even when encountering graph breaks.

What if we want to disallow graph breaks for most of the code, but there are a few problematic functions where the graph breaks are hard to remove,
and we are okay with having those graph breaks? We can use `torch._dynamo.error_on_graph_break()` to achieve this.

`torch.compile` has an `error_on_graph_break` setting (initially set to `False`).
If a graph break or compiler error occurs in code while `error_on_graph_break` is set to `False`, then `torch.compile` will attempt to continue compilation after the graph break/error.
If `error_on_graph_break` is set to `True`, then `torch.compile` will abort compilation and propagate the error to user code.

A significant difference between `error_on_graph_break=True` and `fullgraph=True` is that the former **does not guarantee that a single graph will be captured**.
`error_on_graph_break` **can be arbitrarily toggled during compile time** by using the `torch._dynamo.error_on_graph_break()` context manager/decorator.
In comparison, once `fullgraph` is set to `True`, it cannot be set back to `False`.
Finally, `error_on_graph_break` has lower precedence than `fullgraph`: `error_on_graph_break` only takes effect when `fullgraph=False`.
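
To make the precedence rule concrete, here is a minimal sketch (illustrative pseudologic only, not the actual Dynamo implementation; the helper name is hypothetical) of how the two settings combine at a graph break:

```{code-cell}
# Hypothetical helper, for illustration only: how torch.compile decides
# whether a graph break should raise, per the rules described above.
def should_error_on_graph_break(fullgraph: bool, error_on_graph_break: bool) -> bool:
    if fullgraph:
        # fullgraph=True always errors and cannot be overridden
        return True
    # otherwise the toggleable error_on_graph_break setting decides
    return error_on_graph_break

print(should_error_on_graph_break(fullgraph=True, error_on_graph_break=False))   # True
print(should_error_on_graph_break(fullgraph=False, error_on_graph_break=True))   # True
print(should_error_on_graph_break(fullgraph=False, error_on_graph_break=False))  # False
```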

## `error_on_graph_break(False)` example

```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def code_with_a_difficult_graph_break(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner(x):
    return code_with_a_difficult_graph_break(x)

# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner(x)

# No error, but there is a graph break
fn(torch.randn(3))
```

Using `error_on_graph_break(False)` under `error_on_graph_break(True)` is helpful when we want to minimize graph breaks (i.e. follow the `fullgraph=True` programming model),
but there are some sections of code with non-performance-critical graph breaks that are difficult to work around.

`error_on_graph_break()` can be used as a context manager as well:

```{code-cell}
# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()  # no error
    return x + 2

# No error, but there is a graph break
fn(torch.randn(3))
```

You can use monkey patching to toggle `error_on_graph_break` for code where you cannot edit the source (e.g. framework code):

```{code-cell}
class ThirdPartyModule(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        torch._dynamo.graph_break()
        return x + 2

tp_mod = ThirdPartyModule()
tp_mod.forward = torch._dynamo.error_on_graph_break(False)(tp_mod.forward)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return tp_mod.forward(x)

# No error, but there is a graph break
fn(torch.randn(3))
```

## `error_on_graph_break(True)` example

```{code-cell}
@torch._dynamo.error_on_graph_break(True)
def inner2(x):
    x = x + 1
    torch._dynamo.graph_break()  # error
    return x + 2

def inner(x):
    return inner2(x)

# fullgraph=False, error_on_graph_break=False
@torch.compile
def fn(x):
    x = x + 4
    torch._dynamo.graph_break()  # no error
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```

Using `error_on_graph_break(True)` under `error_on_graph_break(False)` is helpful when we want to use `torch.compile` flexibly (i.e. follow the `fullgraph=False` programming model),
but there are some sections of the code that are performance-critical and we want to ensure that those sections do not contain graph breaks.
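
For instance, a short sketch of this pattern (the function names below are illustrative, not from the PR): the inner region is protected with `error_on_graph_break(True)` while the surrounding `fullgraph=False` compile tolerates graph breaks elsewhere.

```{code-cell}
@torch._dynamo.error_on_graph_break(True)
def performance_critical_section(x):
    # any graph break in here should be a loud error
    return x.sin() + x.cos()

# fullgraph=False (the default): graph breaks elsewhere are tolerated
@torch.compile
def flexible_fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # accepted performance hit
    return performance_critical_section(x)

# Runs without error: the only graph break is outside the protected region
flexible_fn(torch.randn(3))
```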

## `error_on_graph_break` nesting behavior

`torch._dynamo.error_on_graph_break()` affects the `error_on_graph_break` setting of nested calls as well:

```{code-cell}
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(False):
        return inner(x)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```

`torch._dynamo.error_on_graph_break()` can be used under another `torch._dynamo.error_on_graph_break()` region:

```{code-cell}
def inner(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(True):
        return inner(x)

@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```

## Interaction with `fullgraph`

`fullgraph=True` takes higher precedence than `error_on_graph_break`:

```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```

`fullgraph=True` cannot be toggled back to `fullgraph=False`:

```{code-cell}
@torch.compile(fullgraph=False)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```

```{code-cell}
@torch.compile(fullgraph=True)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=False)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```

## Summary of `fullgraph=True/False` vs `error_on_graph_break`

Here is a table summarizing the differences between `fullgraph=True/False` and `error_on_graph_break`:

| | `error_on_graph_break=True` | `error_on_graph_break=False` (default) |
| --- | --- | --- |
| `fullgraph=True` | Graph breaks result in errors. Only the first graph break will be reported. **One-graph guarantee.**<br><br>`fullgraph` cannot be toggled to `False`. `error_on_graph_break` has no effect.<br><br>User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).<br><br>Ideal for code sensitive to graph breaks: framework/library code or cases where getting maximum performance is required. Prevents downstream user code from inadvertently allowing graph breaks. | Same as `fullgraph=True` and `error_on_graph_break=True`, since `error_on_graph_break` has no effect when `fullgraph=True`. |
| `fullgraph=False` (default) | Graph breaks result in errors. Only the first graph break will be reported. **No one-graph guarantee.**<br><br>`error_on_graph_break` can be toggled to `False`.<br><br>User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).<br><br>Ideal for user code sensitive to graph breaks. `error_on_graph_break` can be toggled to `False` to deal with sections that have graph breaks that are difficult to work around. | Will continue to compile after encountering graph breaks. All graph breaks will be reported.<br><br>`error_on_graph_break` can be toggled to `True`.<br><br>Doesn't require many user code changes to work. Performance may be negatively impacted due to graph breaks.<br><br>Ideal for out-of-the-box use cases, on "non-weird" code, or where squeezing maximal performance is not necessary. |
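
As a quick reference, a minimal sketch of the three configurations described in this table (the variable names are illustrative):

```{code-cell}
def f(x):
    return x + 1

# fullgraph=True: one graph guaranteed, graph breaks always error, cannot be relaxed
strict = torch.compile(f, fullgraph=True)

# fullgraph=False + error_on_graph_break(True): graph breaks error, but the setting
# can be toggled back off inside selected regions; no single-graph guarantee
mostly_strict = torch._dynamo.error_on_graph_break(True)(torch.compile(f))

# fullgraph=False (default): graph breaks are allowed and tracing resumes after them
flexible = torch.compile(f)

print(strict(torch.randn(3)), mostly_strict(torch.randn(3)), flexible(torch.randn(3)))
```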
@@ -19,6 +19,7 @@ The strategy for using `torch.compile(fullgraph=False)` is as follows:
```{toctree}
programming_model.where_to_apply_compile
programming_model.compiler_disable
+programming_model.error_on_graph_break
programming_model.nested_graph_breaks
programming_model.skipped_functions
```
@@ -1067,11 +1067,10 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

-        cnts.clear()
-        torch._dynamo.reset()
-        fn3(torch.randn(4, 5))
-        self.assertEqual(cnts.frame_count, 2)
-        self.assertEqual(cnts.op_count, 4)
+        with self.assertRaisesRegex(
+            Unsupported, r"Skip calling `torch.compiler.disable\(\)`d function"
+        ):
+            fn3(torch.randn(4, 5))

    def test_disable_optimize(self):
        cnt = torch._dynamo.testing.CompileCounter()
@@ -1724,7 +1723,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
    def test_error_on_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f1(x):
            x = x + 1
            with torch._dynamo.error_on_graph_break(False):
@@ -1745,7 +1745,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        with self.assertRaises(Unsupported):
            f2(inp)

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f3(x):
            x = x + 1
            with torch._dynamo.error_on_graph_break(False):
@@ -1763,7 +1764,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
                torch._dynamo.graph_break()
            return x + 4

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f4(x):
            x = x + 1
            with torch._dynamo.error_on_graph_break(False):
@@ -1784,7 +1786,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
                torch._dynamo.graph_break()
            return x + 4

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f5(x):
            x = x + 1
            return inner_f5(x)
@@ -1799,7 +1802,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
                torch._dynamo.graph_break()
            return x + 4

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f6(x):
            x = x + 1
            return inner_f6(x)
@@ -1814,7 +1818,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
                torch._dynamo.graph_break()
            return x + 4

-        @torch.compile(backend=cnts, fullgraph=False)
+        @torch._dynamo.error_on_graph_break(False)
+        @torch.compile(backend=cnts)
        def f7(x):
            x = x + 1
            return inner_f7(x)
@@ -1837,7 +1842,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
            torch._dynamo.skip_frame()
            return inner2_f8(x)

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f8(x):
            x = x + 1
            return inner1_f8(x)
@@ -1856,7 +1862,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        def inner1_f9(x):
            return inner2_f9(x)

-        @torch.compile(backend=cnts, fullgraph=False)
+        @torch._dynamo.error_on_graph_break(False)
+        @torch.compile(backend=cnts)
        def f9(x):
            x = x + 1
            return inner1_f9(x)
@@ -1898,7 +1905,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        def inner4_f1(x):
            return inner3_f1(x)

-        @torch.compile(backend=cnts, fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend=cnts)
        def f1(x):
            x = x + 4
            return inner4_f1(x)
@@ -1922,7 +1930,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        def inner4_f2(x):
            return inner3_f2(x)

-        @torch.compile(backend=cnts, fullgraph=False)
+        @torch._dynamo.error_on_graph_break(False)
+        @torch.compile(backend=cnts)
        def f2(x):
            x = x + 4
            return inner4_f2(x)
@@ -1953,34 +1962,88 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        with self.assertRaises(Exception):
            f3()

-    def test_nested_compile_fullgraph(self):
+    def test_nested_compile_error_on_graph_break(self):
        inp = torch.ones(3)

-        @torch.compile(backend="eager", fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend="eager")
        def inner_f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

-        @torch.compile(backend="eager", fullgraph=False)
+        @torch._dynamo.error_on_graph_break(False)
+        @torch.compile(backend="eager")
        def f1(x):
            return inner_f1(x)

        with self.assertRaises(Unsupported):
            f1(inp)

-        @torch.compile(backend="eager", fullgraph=False)
+        @torch._dynamo.error_on_graph_break(False)
+        @torch.compile(backend="eager")
        def inner_f2(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

-        @torch.compile(backend="eager", fullgraph=True)
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend="eager")
        def f2(x):
            return inner_f2(x)

        self.assertEqual(f2(inp), inp + 3)

+    def test_error_on_graph_break_fullgraph(self):
+        # Test that error_on_graph_break=False cannot override fullgraph=True
+        inp = torch.ones(3)
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def f(x):
+            x = x + 1
+            with torch._dynamo.error_on_graph_break(False):
+                torch._dynamo.graph_break()
+            return x + 2
+
+        with self.assertRaises(Unsupported):
+            f(inp)
+
+    def test_error_on_graph_break_empty_graph(self):
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend="eager")
+        def f():
+            return 1
+
+        self.assertEqual(f(), 1)
+
+    def test_nested_compile_fullgraph(self):
+        # Test that fullgraph=True cannot be toggled back by fullgraph=False
+        inp = torch.ones(3)
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def inner_f1(x):
+            torch._dynamo.graph_break()
+            return x + 1
+
+        @torch.compile(backend="eager", fullgraph=False)
+        def outer_f1(x):
+            return inner_f1(x)
+
+        with self.assertRaises(Unsupported):
+            outer_f1(inp)
+
+        @torch.compile(backend="eager", fullgraph=False)
+        def inner_f2(x):
+            torch._dynamo.graph_break()
+            return x + 1
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def outer_f2(x):
+            return inner_f2(x)
+
+        with self.assertRaises(Unsupported):
+            outer_f2(inp)
+

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -93,7 +93,8 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
                return func(*args, **kwargs)

        # test e2e, with Inductor, as smoketest.
-        @torch.compile(fullgraph=True, backend="inductor")
+        @torch._dynamo.error_on_graph_break(True)
+        @torch.compile(backend="inductor")
        def g(x):
            return 2 * x.sin().cos()
@@ -1884,9 +1884,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
            )
            # Ensure no more re-compilation after the second automatic dynamic shape version.
            if i == 0:
-                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
            else:
-                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 4)
+                self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes_fast)
@@ -524,12 +524,6 @@ class ConvertFrameBox:
    error_on_graph_break: Optional[bool] = None


-def _is_error_on_graph_break(tx: Optional[DynamoTracerOutput]) -> bool:
-    if tx is None:
-        return _get_error_on_graph_break()
-    return tx.error_on_graph_break
-
-
def get_compile_id(
    frame_state: dict[str, Union[int, FrameStateSizeEntry]],
) -> CompileId:
@@ -1167,10 +1161,8 @@ def _compile(
                package=package,
            )
        except exc.SkipFrame as e:
-            if one_graph or _is_error_on_graph_break(e._torch_dynamo_tracer_output):
-                log.debug(
-                    "No graph captured with one_graph=True or error_on_graph_break=True"
-                )
+            if one_graph:
+                log.debug("No graph captured with export/fullgraph=True")
            assert e._torch_dynamo_tracer_output is not None
            return ConvertFrameReturn(), e._torch_dynamo_tracer_output
@@ -1376,10 +1368,9 @@ def _compile(
                raise FailOnRecompileLimitHit(
                    f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
                )
-            elif one_graph or _get_error_on_graph_break():
+            elif one_graph:
                raise FailOnRecompileLimitHit(
-                    f"{limit_type} reached with one_graph=True or error_on_graph_break=True. "
-                    "Excessive recompilations can degrade "
+                    f"{limit_type} reached with fullgraph=True. Excessive recompilations can degrade "
                    "performance due to the compilation overhead of each recompilation. To monitor "
                    "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider "
                    "increasing torch._dynamo.config.cache_size_limit to an appropriate value."
@@ -941,9 +941,17 @@ def error_on_graph_break(
    error_on_graph_break: bool,
) -> ErrorOnGraphBreakDecoratorContextManager:
    """
-    Context manager/decorator to toggle error_on_graph_break (i.e. torch.compile's fullgraph) setting.
+    Context manager/decorator to toggle torch.compile's `error_on_graph_break` setting at compile time.

-    More precisely, when encountering a graph break, we will decide to resume (fullgraph=False)
-    or error out (fullgraph=True) based on the fullgraph setting at the location of the graph break.
+    If `fullgraph` is set, then `error_on_graph_break` does nothing
+    (i.e. `fullgraph = True` takes higher precedence). If `fullgraph` is False, then
+    `error_on_graph_break` determines whether `torch.compile` throws an error upon
+    encountering a graph break, or attempts to continue tracing.
+
+    `error_on_graph_break` can be toggled during compile time with this decorator to allow graph breaks in some
+    compiled regions but not others. One key difference from `fullgraph` is that `error_on_graph_break = True`
+    does NOT guarantee that a single graph is captured from the compiled function.
+
+    The default value of torch.compile's `error_on_graph_break` setting is False.
    """
    return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)
@@ -598,7 +598,8 @@ class _TorchDynamoContext:
        patch_fn: Callable[[], Any] = nothing,
        first_ctx: bool = False,
        *,
-        error_on_graph_break: bool = False,
+        fullgraph: bool = False,
+        error_on_graph_break: Optional[bool] = None,
        export: bool = False,
        dynamic: Optional[bool] = None,
        compiler_config: Optional[Any] = None,
@@ -611,6 +612,7 @@ class _TorchDynamoContext:
        self._backend_ctx_ctor = backend_ctx_ctor
        self.prior: Union[Unset, DynamoCallback] = unset
        self.first_ctx = first_ctx
+        self.fullgraph = fullgraph
        self.error_on_graph_break = error_on_graph_break
        self.export = export
        self._dynamic = dynamic
@@ -705,7 +707,7 @@ class _TorchDynamoContext:
        def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any:
            from torch._dynamo.aot_compile import aot_compile_fullgraph

-            if not self.error_on_graph_break:
+            if not self.fullgraph:
                raise RuntimeError(
                    "Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)."
                )
@@ -810,7 +812,7 @@ class _TorchDynamoContext:
                _is_skip_guard_eval_unsafe_stance()
            )
            prior_error_on_graph_break = None
-            if self.error_on_graph_break is not None:
+            if not self.fullgraph and self.error_on_graph_break is not None:
                prior_error_on_graph_break = _get_error_on_graph_break()
                _set_error_on_graph_break(self.error_on_graph_break)
@@ -857,11 +859,14 @@ class _TorchDynamoContext:
            _maybe_set_eval_frame(prior)

        # hooks to properly handle inlining
-        compile_wrapper._torchdynamo_inline = (  # type: ignore[attr-defined]
-            external_utils.wrap_inline_with_error_on_graph_break(
-                fn, self.error_on_graph_break
-            )
-        )
+        if self.error_on_graph_break is not None:
+            compile_wrapper._torchdynamo_inline = (  # type: ignore[attr-defined]
+                external_utils.wrap_inline_with_error_on_graph_break(
+                    fn, self.error_on_graph_break
+                )
+            )
+        else:
+            compile_wrapper._torchdynamo_inline = fn  # type: ignore[attr-defined]

        # Save the function pointer to find the original callable while nesting
        # of decorators.
@@ -923,7 +928,8 @@ class OptimizeContext(_TorchDynamoContext):
        backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]],
        first_ctx: bool = False,
        *,
-        error_on_graph_break: bool = False,
+        fullgraph: bool = False,
+        error_on_graph_break: Optional[bool] = None,
        export: bool = False,
        dynamic: Optional[bool] = None,
        compiler_config: Optional[Any] = None,
@@ -942,6 +948,7 @@ class OptimizeContext(_TorchDynamoContext):
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
+            fullgraph=fullgraph,
            error_on_graph_break=error_on_graph_break,
            export=export,
            dynamic=dynamic,
@@ -1067,7 +1074,8 @@ def _optimize_catch_errors(
    backend_ctx_ctor: Callable[
        [], contextlib.AbstractContextManager[Any]
    ] = null_context,
-    error_on_graph_break: bool = False,
+    fullgraph: bool = False,
+    error_on_graph_break: Optional[bool] = None,
    export: bool = False,
    dynamic: Optional[bool] = None,
    compiler_config: Optional[Any] = None,
@@ -1078,6 +1086,7 @@ def _optimize_catch_errors(
        convert_frame.catch_errors_wrapper(compile_fn, hooks),
        backend_ctx_ctor=backend_ctx_ctor,
        first_ctx=True,
+        fullgraph=fullgraph,
        error_on_graph_break=error_on_graph_break,
        export=export,
        dynamic=dynamic,
@@ -1176,6 +1185,7 @@ def _optimize(
    backend: Union[str, Callable[..., Any]] = "inductor",
    *,
    nopython: bool = False,
+    error_on_graph_break: Optional[bool] = None,
    guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None,
    guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None,
@@ -1198,6 +1208,11 @@ def _optimize(
            - Or, a string backend name in `torch._dynamo.list_backends()`
        nopython: If True, graph breaks will be errors and there will
            be a single whole-program graph.
+        error_on_graph_break: If not None, the current `error_on_graph_break` setting is set to the given value.
+            See `torch._dynamo.error_on_graph_break()` for more details on what `error_on_graph_break` means.
+
+            Unlike `nopython=True` (i.e. `fullgraph=True`), there is no guarantee of a single whole-program graph.
+            If `nopython` is True, `error_on_graph_break` does nothing.
        disable: If True, turn this decorator into a no-op
        dynamic: If True, upfront compile as dynamic a kernel as possible. If False,
            disable all dynamic shapes support (always specialize). If None, automatically
@@ -1228,6 +1243,15 @@ def _optimize(
    ):
        return _NullDecorator()

+    if nopython and not config.debug_force_graph_break_on_leaf_return:
+        return optimize_assert(
+            backend,
+            dynamic=dynamic,
+            hooks=hooks,
+            rebuild_ctx=rebuild_ctx,
+            package=package,
+        )
+
    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
@@ -1252,7 +1276,8 @@ def _optimize(
        ),
        hooks,
        backend_ctx_ctor,
-        error_on_graph_break=nopython
+        fullgraph=False,
+        error_on_graph_break=error_on_graph_break
        and not config.debug_force_graph_break_on_leaf_return,
        dynamic=dynamic,
        compiler_config=(
@@ -2174,10 +2199,11 @@ def _optimize_assert(
    package: Optional[CompilePackage] = None,
) -> OptimizeContext:
    """
-    The same as `torch._dynamo.optimize(backend)` but ignores
-    symbolic_convert.error_on_graph_break setting.
+    The same as `torch._dynamo.optimize(backend, nopython=True)`,
+    but ignores symbolic_convert.error_on_graph_break setting.
+    Guarantees single-graph capture.

-    Used for export, since we must always error on graph breaks and ignore
+    Used for fullgraph=True and export, since we must always error on graph breaks and ignore
    symbolic_convert.error_on_graph_break. Can also be used for testing.
    """
    backend = get_compiler_fn(backend)
@@ -2204,6 +2230,7 @@ def _optimize_assert(
        ),
        hooks,
        backend_ctx_ctor,
+        fullgraph=True,
        export=export,
        dynamic=dynamic,
        rebuild_ctx=rebuild_ctx,
@@ -203,7 +203,7 @@ def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[
    Apply self as a ctx manager around a call to func
    """

-    @functools.wraps(func)
+    # NOTE: do not functools.wraps(func) because we don't ever want this frame to be skipped!
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with self:
            return func(*args, **kwargs)
@@ -234,16 +234,15 @@ def wrap_inline_with_error_on_graph_break(
) -> Callable[_P, _R]:
    # NB: need multiple definitions in order to prevent `fullgraph` from
    # being a freevar of wrapper
+    # NOTE: do not functools.wraps(fn) because we don't ever want these wrappers to be skipped!
    if error_on_graph_break:

-        @functools.wraps(fn)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            with torch._dynamo.error_on_graph_break(True):
                return fn(*args, **kwargs)

    else:

-        @functools.wraps(fn)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            with torch._dynamo.error_on_graph_break(False):
                return fn(*args, **kwargs)
@@ -3756,8 +3756,8 @@ class InstructionTranslatorBase(
        self.num_calls: dict[str, int] = {}
        # Flag to indicate whether tracing is used for export.
        self.export = export
-        # NOTE: one_graph is used for export/debugging to always force errors on graph breaks.
-        # To toggle fullgraph during normal compile, self.error_on_graph_break
+        # NOTE: one_graph is used for export/fullgraph=True to always force errors on graph breaks.
+        # To toggle erroring/resuming on graph breaks during fullgraph=False compile, self.error_on_graph_break
        # is used instead. Every step(), its value is updated to the global tls.error_on_graph_break.
        # We mirror this value since cleanup may (correctly) inadvertently change tls.error_on_graph_break.
        # This assumes that we cannot both trace a change to tls.error_on_graph_break and graph break on
@@ -170,7 +170,7 @@ class CPythonTestCase(TestCase):
        # We want to compile only the test function, excluding any setup code
        # from unittest
        method = getattr(self, self._testMethodName)
-        method = torch._dynamo.optimize(backend, nopython=nopython)(method)
+        method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
        setattr(self, self._testMethodName, method)
        return fn