[dynamo, docs] add fullgraph=False docs (#159050)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159050
Approved by: https://github.com/svekars, https://github.com/anijain2305
ghstack dependencies: #157985, #158055, #158531
Committed by PyTorch MergeBot (commit ffccb90ff4, parent f916f34739)

docs/source/compile/programming_model.compiler_disable.md (new file)
@@ -0,0 +1,75 @@
---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code

torch._logging.set_logs(graph_breaks=True, graph_code=True)
```

# Disabling and Suppressing Errors

For some model architectures, portions of the model are particularly difficult to compile -
they either produce many graph breaks or crash the compiler.
You may want to explicitly disable these problematic portions of the model so that you can apply
`torch.compile` to the parts that work. You can do this by using the `@torch.compiler.disable` decorator.
When `torch.compile` attempts to call a disabled function, it breaks the graph, skips tracing the disabled function,
and resumes tracing after the call. By default, all recursive calls made from a disabled function are also disabled.
Use the `recursive=False` option to allow compilation of recursive calls.

```{code-cell}
def inner1(x):
    torch._dynamo.graph_break()  # not traced
    return x + 1  # not traced

@torch.compiler.disable
def outer1(x):
    x = x + 2  # not traced
    torch._dynamo.graph_break()  # not traced
    return inner1(x)

@torch.compile
def f(x):
    x = outer1(x)
    return x + 4  # traced

print(f(torch.ones(3)))
```

```{code-cell}
def inner2(x):
    torch._dynamo.graph_break()  # traced
    return x + 1  # traced

@torch.compiler.disable(recursive=False)
def outer2(x):
    x = x + 2  # not traced
    torch._dynamo.graph_break()  # not traced
    return inner2(x)

@torch.compile
def g(x):
    x = outer2(x)
    return x + 4  # traced

print(g(torch.ones(3)))
```

For example, one can use `torch.compiler.disable` to disable `torch.compile` on the sparse architecture of
recommendation models, as the sparse arch is difficult to compile.
Preprocessing and logging functions are other examples of functions that typically cause
a lot of graph breaks and do not benefit from being compiled.

If you are experiencing compiler crashes and you want to continue regardless,
you can set `torch._dynamo.config.suppress_errors = True`.
When the compiler crashes, we will just skip tracing the function and try again later.
**This is not best practice** - it is better to eventually manually add `disable` annotations as necessary.
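
As a rough sketch, opting into error suppression just means flipping the config flag before running the
compiled function (the function `fn` below is only a stand-in):

```python
import torch

# Not best practice: skip tracing functions whose compilation crashes instead of raising.
torch._dynamo.config.suppress_errors = True

@torch.compile
def fn(x):
    return x * 2

fn(torch.randn(3))
```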

@@ -41,6 +41,8 @@ For example, for the function `f` in the above diagram, Dynamo produces:

- **guards** that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified,
  the graph produced by Dynamo specializes on the shapes of input Tensors.

(programming_model.dynamo_core_concepts.graph_breaks)=

## Graph Breaks
Dynamo traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch
operators (FX graph). However, this is not always possible. When encountering code that can't be traced, a "**graph break**" occurs.

docs/source/compile/programming_model.fullgraph_false.md (new file)
@@ -0,0 +1,24 @@

# Working with `fullgraph=False`
While `fullgraph=False` is the default `torch.compile` setting, the semantics of resuming compilation upon encountering a graph break are more complicated.
You can find details on the `fullgraph=False` semantics in the subsections.

The strategy for using `torch.compile(fullgraph=False)` is as follows:

1. [Determine the ideal location to place `torch.compile`](programming_model.where_to_apply_compile). Normally, it is the highest-level function that doesn't result in excessive graph breaks.
   Functions that do a lot of preprocessing or I/O operations are examples of functions that result in many graph breaks and do not significantly benefit from `torch.compile`.
   a. You can isolate issues by first compiling individual functions/modules before compiling entire models.
2. [Apply `torch.compiler.disable` to functions in the compiled region that result in a lot of graph breaks
   and do not benefit from compilation](programming_model.compiler_disable). In this case, one graph break is better than potentially tens or hundreds.
3. Use `TORCH_LOGS="graph_breaks"` or tlparse to investigate remaining graph breaks. <!-- TODO link -->
   Work around these graph breaks using the same approaches as working around graph breaks under
   the `fullgraph=True` programming model. Not all graph breaks need to be removed - some may
   impact performance more than others. The general rule is to focus on graph breaks that happen during model computation.
   a. We recommend using `torch.compile(backend='eager')` when debugging graph breaks, for faster iteration while debugging (see the sketch after this list).
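
As a minimal sketch of step 3 (the script name and `train_step` are placeholders), you might run with
graph-break logging enabled and the `eager` backend while iterating on the remaining graph breaks:

```python
# Run with: TORCH_LOGS="graph_breaks" python train.py
import torch

@torch.compile(backend="eager")  # eager backend skips backend compilation, for faster debug iteration
def train_step(x):
    return x + 1

train_step(torch.randn(3))
```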

```{toctree}
programming_model.where_to_apply_compile
programming_model.compiler_disable
programming_model.nested_graph_breaks
programming_model.skipped_functions
```

@@ -17,4 +17,5 @@ programming_model.fullgraph_true
programming_model.common_graph_breaks
programming_model.dynamo_nonstrict_trace
programming_model.custom_ops
programming_model.fullgraph_false
```

docs/source/compile/programming_model.nested_graph_breaks.md (new file)
@@ -0,0 +1,191 @@

# Nested Graph Breaks

Summary:
- Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below.
- A nested graph break results in {math}`\mathcal O(N)` duplicate graph break behavior.

Recall that when `torch.compile` is applied to a function, any nested function calls are also traced.
A **nested graph break** refers to any graph break that happens in a nested function call.

```python
def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...
```

The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.

Recall that in `fullgraph=False`, [graph breaks are handled](programming_model.dynamo_core_concepts.graph_breaks) by compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
Resuming a function is actually a fairly complicated technical feat, so resuming tracing is only supported for top-level functions.

Given this restriction, we can resume tracing after a nested graph break in the following way:

First, consider the example below, where `torch.compile` traces from `f` all the way until the
graph break in `inner1` is encountered.

```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    x = x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    x = x + 32

f(torch.randn(3))
```

Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.

```python
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```

`inner2` is then automatically compiled as a top-level function.
We trace all the way until the graph break in `inner1` is encountered again.

```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    x = x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```

Then we graph break on the `inner1` call in `inner2`.

```python
def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8
```

`inner1` is then automatically compiled as a top-level function.
The graph break now occurs directly in the top-level function `inner1`, so it is handled normally.

```python
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))
```

`inner1` is handled normally:

```python
def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2
```

So the initial code is semantically equivalent to:

```python
def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
```

Note in particular that we traced 3 top-level functions and that we traced the same graph break 3 times.
**This explains why you may encounter duplicate graph breaks when using `torch.compile`.**

In summary, nested graph breaks are handled by:
- Tracing from the top-level function all the way to the nested graph break
- Graph breaking on the top-level function at the call to the second-level function
- Compiling the PyTorch ops traced so far and running the compiled graph
- Calling the second-level function, which gets automatically compiled as a top-level function
- Resuming tracing after the second-level function call

Note that the runtime of handling this graph break is {math}`\mathcal O(NK)`, where {math}`N` is the nesting depth
and {math}`K` is the number of instructions from the top-level function to the graph break.
We end up tracing {math}`\mathcal O(N^2)` frames, and we trace the same graph break {math}`\mathcal O(N)` times.

docs/source/compile/programming_model.skipped_functions.md (new file)
@@ -0,0 +1,199 @@

---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code
import logging
torch._logging.set_logs(dynamo=logging.DEBUG)
```

# Skipped Functions

**Summary:**
- Sometimes, `torch.compile` completely gives up compiling a function and runs it eagerly instead,
  resulting in potentially lost optimization opportunities.
- There are ways to work around skipped functions in order to re-enable tracing around the problematic code.

Sometimes, `torch.compile` with `fullgraph=False` is unable to resume tracing when encountering a graph break
or other compiler error. In many of these cases, `torch.compile` will skip compiling the function entirely and run it eagerly.

Note that the skip is applied only to the current function and NOT to any nested function calls.
`torch.compile` will still attempt to compile nested calls.

<!-- TODO: fix logging for skipped functions. -->

```{code-cell}
def inner1(x):
    return x + 1

def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    torch._dynamo.skip_frame()
    x = inner2(x)

fn(torch.randn(3))
```

In the above example, `torch.compile` will trace `fn` (including `inner1`) up until the `skip_frame`.
Then `fn` is skipped and run eagerly - `inner1` and `inner2` are compiled when they are called.

Skipping functions may result in lost optimization opportunities,
so it is important to check whether code you want compiled is being skipped, and if so, to work around the skip.
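
As a rough sketch of how to check (the exact log messages may change between releases, and logging for
skipped functions is still being improved), you can enable verbose Dynamo logging and look for messages
about skipped frames:

```python
import logging
import torch

# Verbose Dynamo logs (equivalently, run with TORCH_LOGS="dynamo") typically
# include messages when a frame is skipped.
torch._logging.set_logs(dynamo=logging.DEBUG)

@torch.compile
def fn(x):
    torch._dynamo.skip_frame()  # forces fn to be skipped, purely for demonstration
    return x + 1

fn(torch.randn(3))
```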

## Graph Break in a Loop

`torch.compile` cannot resume tracing if a graph break occurs in a loop:

```{code-cell}
@torch.compile
def fn(x):
    for i in range(5):
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    return x

fn(torch.randn(3))
```

In this example, we can avoid skipping by unrolling the loop:

```{code-cell}
@torch.compile
def fn(x):
    def inner(i):
        nonlocal x
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    inner(0)
    inner(1)
    inner(2)
    inner(3)
    inner(4)
    return x

fn(torch.randn(3))
```

In general, resolving the graph break that causes the skip will also resolve the skip.

## Graph Break in a Context Manager

Another common example of an unresumable graph break is a graph break inside most context managers:

```{code-cell}
class CustomCtxManager:
    def __enter__(self):
        pass
    def __exit__(self, exc_type, exc_value, traceback):
        pass

@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1

fn(torch.randn(3))
```

We can avoid skipping by moving the graph break outside of the context manager:

```{code-cell}
@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
    torch._dynamo.graph_break()
    with CustomCtxManager():
        return x + 1

fn(torch.randn(3))
```

There are some context managers in which Dynamo can resume after a graph break.
Some of these can be found in `supported_ctx_manager_classes` in `torch/_dynamo/variables/torch.py`.
In general, any context manager represented by a `ContextWrappingVariable` subclass in
`torch/_dynamo/variables/ctx_manager.py` supports resuming after a graph break. For example:

```{code-cell}
import contextlib

@torch.compile
def fn(x):
    with contextlib.nullcontext():
        with torch.no_grad():
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

fn(torch.randn(3))
```

## Graph Break in a Try Block

A graph break in a try block cannot be resumed:

```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1
    except Exception as e:
        pass

fn(torch.randn(3))
```

We can avoid skipping by moving the graph break outside of the try block:

```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
    except Exception as e:
        pass
    torch._dynamo.graph_break()
    try:
        return x + 1
    except Exception as e:
        pass

fn(torch.randn(3))
```

<!-- ## Hitting a Recompilation Limit
See Changing the Cache Size Limit. (TODO: link) -->

## Compiler Errors
Some compiler errors result in skipped functions.
Other compiler errors result in a hard error rather than a skipped function.

## Dealing with Skipped Functions
In general, you can resolve a skipped function by fixing the underlying graph break or error that
is causing the function to be skipped.

If the graph break or error causing the skipped function is difficult to fix,
consider isolating it in its own function so that as little as possible is skipped.

```{code-cell}
def inner1(x):
    return x + 1

def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)

    def problematic_code():
        torch._dynamo.skip_frame()

    problematic_code()
    x = inner2(x)

fn(torch.randn(3))
```
docs/source/compile/programming_model.where_to_apply_compile.md (new file)
@@ -0,0 +1,77 @@

# Where to apply torch.compile?

We recommend applying `torch.compile` to the highest-level function that doesn't cause excessive problems.
Typically, it is:
- your `train` or `eval` step with the optimizer but without the loop,
- your top-level `nn.Module`,
- or some sub-`nn.Module`s.

`torch.compile` specifically doesn't handle distributed wrapper modules like DDP or FSDP very well,
so consider applying `torch.compile` to the inner module passed to the wrapper.

```python
# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
```

```python
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
```

```python
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)
```

<!-- TODO add examples for specific model domains, compile(model) vs. model.compile() -->

## `compile(model)` vs `model.compile()`

Due to nuances in how `torch.compile` interacts with `nn.Module` instances,
we advise using the `.compile()` method of `nn.Module` instances if you wish to compile them as
top-level functions. Nested module calls will be traced correctly -
there is no need to call `.compile()` in that case.

```python
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)

model = MyModel()
fn(model, inp)
```