[dynamo, docs] add fullgraph=False docs (#159050)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159050
Approved by: https://github.com/svekars, https://github.com/anijain2305
ghstack dependencies: #157985, #158055, #158531
Author: William Wen
Date: 2025-07-28 11:58:07 -07:00
Committed by: PyTorch MergeBot
Parent: f916f34739
Commit: ffccb90ff4
7 changed files with 569 additions and 0 deletions


@@ -0,0 +1,75 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(graph_breaks=True, graph_code=True)
```
# Disabling and Suppressing Errors
Some portions of a model can be particularly difficult to compile -
either they produce many graph breaks, or they crash the compiler.
You may want to explicitly disable these problematic portions so that you can apply
`torch.compile` to the parts that work. You can do this with the `@torch.compiler.disable` decorator.
When `torch.compile` attempts to call a disabled function, it breaks the graph, skips tracing the disabled function,
and resumes tracing after the call. By default, all function calls made from a disabled function are also disabled.
Use the `recursive=False` option to disable only the decorated function while still compiling its nested calls.
```{code-cell}
def inner1(x):
    torch._dynamo.graph_break()  # not traced
    return x + 1  # not traced

@torch.compiler.disable
def outer1(x):
    x = x + 2  # not traced
    torch._dynamo.graph_break()  # not traced
    return inner1(x)

@torch.compile
def f(x):
    x = outer1(x)
    return x + 4  # traced

print(f(torch.ones(3)))
```
```{code-cell}
def inner2(x):
    torch._dynamo.graph_break()  # traced
    return x + 1  # traced

@torch.compiler.disable(recursive=False)
def outer2(x):
    x = x + 2  # not traced
    torch._dynamo.graph_break()  # not traced
    return inner2(x)

@torch.compile
def g(x):
    x = outer2(x)
    return x + 4  # traced

print(g(torch.ones(3)))
```
For example, one can use `torch.compiler.disable` on the sparse portion of a
recommendation model, since the sparse arch is difficult to compile.
Preprocessing and logging functions are other examples of functions that typically cause
many graph breaks and do not benefit from being compiled.
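Here is a minimal sketch, where `log_stats` is a hypothetical logging helper; disabling it replaces what could be many graph breaks with a single one:
```{code-cell}
@torch.compiler.disable
def log_stats(x):
    # printing and .item() calls would otherwise cause graph breaks
    print("mean:", x.mean().item())

@torch.compile
def h(x):
    x = x + 1  # traced
    log_stats(x)  # not traced; single graph break at this call
    return x * 2  # traced

print(h(torch.ones(3)))
```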
If you are experiencing compiler crashes and you want to continue regardless,
you can set `torch._dynamo.config.suppress_errors = True`.
When the compiler crashes while tracing a function, we skip tracing that function and run it eagerly instead.
**This is not best practice** - it is better to eventually add `disable` annotations manually as necessary.
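A minimal sketch of enabling the flag (the compiled function here is deliberately trivial and does not actually crash the compiler; it only shows where the flag is set):
```{code-cell}
torch._dynamo.config.suppress_errors = True

@torch.compile
def k(x):
    # if compiling this function raised an internal compiler error,
    # it would be skipped and run eagerly instead of raising
    return x + 1

print(k(torch.ones(3)))
torch._dynamo.config.suppress_errors = False  # restore the default
```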


@@ -41,6 +41,8 @@ For example, for the function `f` in the above diagram, Dynamo produces:
- **guards** that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified,
the graph produced by Dynamo specializes on the shapes of input Tensors.
(programming_model.dynamo_core_concepts.graph_breaks)=
## Graph Breaks
Dynamo traces your code and attempts to capture it into a single computation graph of PyTorch
operators (an FX graph). However, this is not always possible: when Dynamo encounters code that can't be traced, a "**graph break**" occurs.


@@ -0,0 +1,24 @@
# Working with `fullgraph=False`
While `fullgraph=False` is the default `torch.compile` setting, the semantics of resuming compilation upon encountering a graph break are more complicated than those of `fullgraph=True`.
You can find details on the `fullgraph=False` semantics in the subsections.
The strategy for using `torch.compile(fullgraph=False)` is as follows:
1. [Determine the ideal location to place `torch.compile`](programming_model.where_to_apply_compile). Normally, it is the highest-level function that doesn't result in excessive graph breaks.
   Functions that do a lot of preprocessing or I/O operations result in many graph breaks and do not significantly benefit from `torch.compile`.
   a. You can isolate issues by first compiling individual functions/modules before compiling entire models.
2. [Apply `torch.compiler.disable` to functions in the compiled region that result in many graph breaks
   and do not benefit from compilation](programming_model.compiler_disable). In this case, one graph break is better than potentially tens or hundreds.
3. Use `TORCH_LOGS="graph_breaks"` or `tlparse` to investigate remaining graph breaks. <!-- TODO link -->
   Work around these graph breaks using the same approaches as under
   the `fullgraph=True` programming model. Not all graph breaks need to be removed - some
   impact performance more than others. The general rule is to focus on graph breaks that happen during model computation.
   a. We recommend using `torch.compile(backend="eager")` when debugging graph breaks, for faster debugging iteration times, as sketched below.
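For step 3, here is a minimal sketch of such a debugging setup (`fn` and `repro.py` are stand-ins for your own code):
```python
# repro.py - run with graph break logging enabled:
#   TORCH_LOGS="graph_breaks" python repro.py
import torch

@torch.compile(backend="eager")  # Dynamo tracing only; skips backend compilation
def fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # stand-in for a real graph break
    return x + 2

fn(torch.randn(3))
```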
```{toctree}
programming_model.where_to_apply_compile
programming_model.compiler_disable
programming_model.nested_graph_breaks
programming_model.skipped_functions
```


@@ -17,4 +17,5 @@ programming_model.fullgraph_true
programming_model.common_graph_breaks
programming_model.dynamo_nonstrict_trace
programming_model.custom_ops
programming_model.fullgraph_false
```


@@ -0,0 +1,191 @@
# Nested Graph Breaks
**Summary:**
- Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below.
- A nested graph break results in {math}`\mathcal O(N)` duplicate graph break messages, where {math}`N` is the nesting depth.
Recall that when `torch.compile` is applied to a function, any nested function calls are also traced.
A **nested graph break** refers to any graph break that happens in a nested function call.
```python
def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...
```
The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.
Recall that with `fullgraph=False`, [graph breaks are handled](programming_model.dynamo_core_concepts.graph_breaks) by compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
Resuming a function is actually a fairly complicated technical feat, so resuming tracing is only supported on top-level functions.
Given this restriction, resuming tracing after a nested graph break works as follows:
First, consider the example below, where `torch.compile` starts tracing from `f` and traces until the
graph break in `inner1` is encountered.
```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    return x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    return x + 32

f(torch.randn(3))
```
Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.
```python
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
`inner2` is then automatically compiled as a top-level function.
We trace all the way until the graph break in `inner1` is encountered again.
```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    return x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
Then we graph break on the `inner1` call in `inner2`.
```python
def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8
```
`inner1` is then automatically compiled as a top-level function.
The graph break occurs in `inner1` itself rather than in a nested call, so the graph break is handled normally.
```python
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))
```
`inner1` is handled normally:
```python
def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2
```
So the initial code is semantically equivalent to:
```python
def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
```
Note in particular that we traced 3 top-level functions, and that we traced the same graph break 3 times.
**This explains why you may encounter duplicate graph breaks when using `torch.compile`.**
In summary, nested graph breaks are handled by:
- Tracing from the top-level function all the way to the nested graph break
- Graph breaking on the top-level function at the call to the second-level function
- Compiling the PyTorch ops traced so far and running the compiled graph
- Calling the second-level function, which gets automatically compiled as a top-level function
- Resuming tracing after the second-level function call
Note that the runtime cost of handling this graph break is {math}`\mathcal O(NK)`, where {math}`N` is the nesting depth
and {math}`K` is the number of instructions from the start of the top-level function to the graph break.
We end up tracing {math}`\mathcal O(N^2)` frames in total, and we trace the same graph break {math}`\mathcal O(N)` times.


@@ -0,0 +1,199 @@
---
file_format: mystnb
kernelspec:
name: python3
mystnb:
execution_timeout: 30
execution_show_tb: True
merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
import logging
torch._logging.set_logs(dynamo=logging.DEBUG)
```
# Skipped Functions
**Summary:**
- Sometimes, `torch.compile` completely gives up compiling a function and runs it eagerly instead,
resulting in potentially lost optimization opportunities.
- There are ways to work around skipped functions in order to re-enable tracing around the problematic code.
Sometimes, `torch.compile` with `fullgraph=False` is unable to resume tracing when encountering a graph break
or other compiler error. In many of these cases, `torch.compile` will skip compiling the function entirely and run it eagerly.
Note that the skip is only applied to the current function and NOT any nested function calls.
`torch.compile` will still attempt to compile nested calls.
<!-- TODO: fix logging for skipped functions. -->
```{code-cell}
def inner1(x):
    return x + 1

def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    torch._dynamo.skip_frame()
    x = inner2(x)

fn(torch.randn(3))
```
In the above example, `torch.compile` traces `fn` (including `inner1`) up until the `skip_frame` call.
Then `fn` is skipped and run eagerly, while `inner1` and `inner2` are compiled when they are called.
Skipping functions may result in lost optimization opportunities,
so it is important to check whether code you want compiled is being skipped and, if so, to work around the skip. One rough way to check is sketched below.
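`torch._dynamo.explain` summarizes the graphs Dynamo produced and the reasons tracing broke; it is not a dedicated skip detector, but unexpected graph counts or break reasons can hint that a frame was skipped:
```{code-cell}
def gn(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

# explain() traces gn and reports graph counts and graph break reasons
print(torch._dynamo.explain(gn)(torch.randn(3)))
```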
## Graph Break in a Loop
`torch.compile` cannot resume tracing if a graph break occurs in a loop:
```{code-cell}
@torch.compile
def fn(x):
    for i in range(5):
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    return x

fn(torch.randn(3))
```
In this example, we can avoid skipping by unrolling the loop:
```{code-cell}
@torch.compile
def fn(x):
    def inner(i):
        nonlocal x
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    inner(0)
    inner(1)
    inner(2)
    inner(3)
    inner(4)
    return x

fn(torch.randn(3))
```
In general, resolving the graph break causing the skip will also resolve the skip.
## Graph Break in a Context Manager
Another common example of an unresumable graph break is a graph break inside most context managers:
```{code-cell}
class CustomCtxManager:
    def __enter__(self):
        pass
    def __exit__(self, exc_type, exc_value, traceback):
        pass

@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1

fn(torch.randn(3))
```
We can avoid skipping by moving the graph break outside of the context manager:
```{code-cell}
@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
    torch._dynamo.graph_break()
    with CustomCtxManager():
        return x + 1

fn(torch.randn(3))
```
There are some context managers under which Dynamo can resume after a graph break.
Some of these can be found in `supported_ctx_manager_classes` in `torch/_dynamo/variables/torch.py`.
In general, any context manager represented by a `ContextWrappingVariable` subclass in
`torch/_dynamo/variables/ctx_manager.py` supports resuming after a graph break. For example:
```{code-cell}
import contextlib

@torch.compile
def fn(x):
    with contextlib.nullcontext():
        with torch.no_grad():
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

fn(torch.randn(3))
```
## Graph Break in a Try Block
A graph break in a try block cannot be resumed:
```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1
    except Exception as e:
        pass

fn(torch.randn(3))
```
We can avoid skipping by moving the graph break outside of the try block:
```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
    except Exception as e:
        pass
    torch._dynamo.graph_break()
    try:
        return x + 1
    except Exception as e:
        pass

fn(torch.randn(3))
```
<!-- ## Hitting a Recompilation Limit
See Changing the Cache Size Limit. (TODO: link) -->
## Compiler Errors
Some compiler errors result in skipped functions.
Others result in a hard error rather than a skipped function.
## Dealing with Skipped Functions
In general, you can resolve a skipped function by fixing the underlying graph break or error that
is causing the function to be skipped.
If that graph break or error is difficult to fix,
consider isolating it in its own function so that as little code as possible is skipped.
```{code-cell}
def inner1(x):
    return x + 1

def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    def problematic_code():
        torch._dynamo.skip_frame()
    problematic_code()
    x = inner2(x)

fn(torch.randn(3))
```


@@ -0,0 +1,77 @@
# Where to apply `torch.compile`?
We recommend applying `torch.compile` to the highest-level function that doesn't cause excessive problems.
Typically, it is:
- your `train` or `eval` step with the optimizer but without the loop,
- your top-level `nn.Module`,
- or some sub-`nn.Module`s.

`torch.compile` specifically doesn't handle distributed wrapper modules like DDP or FSDP very well,
so consider applying `torch.compile` to the inner module passed to the wrapper.
```python
# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
```
```python
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
```
```python
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)
```
<!-- TODO add examples for specific model domains, compile(model) vs. model.compile()-->
## `compile(model)` vs `model.compile()`
Due to nuances in how `torch.compile` interacts with `nn.Module` instances,
we advise using the `.compile()` method of `nn.Module` instances if you wish to compile them as
top-level functions. Nested module calls will be traced correctly -
there is no need to call `.compile()` on them in that case.
```python
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)

model = MyModel()
fn(model, inp)
```