DOC: Convert to markdown: torch.compiler_best_practices_for_backends.rst, torch.compiler_cudagraph_trees.rst, torch.compiler_custom_backends.rst, torch.compiler_dynamic_shapes.rst, torch.compiler_dynamo_deepdive.rst (#155137)

Fixes #155037

[torch.compiler_best_practices_for_backends.rst](https://github.com/pytorch/pytorch/tree/main/docs/source/torch.compiler_best_practices_for_backends.rst) shows error 404

cc  @svekars @sekyondaMeta @AlannaBurke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155137
Approved by: https://github.com/svekars

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
This commit is contained in:
Alberto A. Gallegos
2025-06-10 20:51:05 +00:00
committed by PyTorch MergeBot
parent 01b8f5e685
commit 8a396c5635
5 changed files with 867 additions and 934 deletions

View File

@ -1,15 +1,12 @@
# CUDAGraph Trees

## Background
### CUDAGraph
For a longer background on CUDAGraphs, read [Accelerating PyTorch with CUDA Graphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).
[CUDA Graphs](https://developer.nvidia.com/blog/cuda-10-features-revealed/), which made their debut in CUDA 10, let a series of CUDA kernels be defined and encapsulated as a single unit, i.e., a graph of operations, rather than a sequence of individually launched operations. They provide a mechanism to launch multiple GPU operations through a single CPU operation, and hence reduce launch overhead.
CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations that stem from requiring the same kernels to be run with the same arguments, dependencies, and memory addresses.
@ -19,54 +16,49 @@ CUDA Graphs can give large speedups, especially for models with high CPU overhea
- CUDA memory addresses are fixed; however, the values of the memory at those addresses can change
- No essential CPU ops or CPU side effects
### PyTorch CUDAGraph Integration
PyTorch provides a [convenience wrapper](https://pytorch.org/docs/stable/generated/torch.cuda.CUDAGraph.html) around CUDAGraphs that handles a couple of tricky interactions with PyTorch's caching allocator.
The CachingAllocator uses a separate memory pool for all new allocations. During CUDAGraph recording, memory is accounted for, allocated, and freed exactly as during an eager run. On replay, just the kernels are invoked, and there are no changes to the allocator. Subsequent to initial recording, the allocator does not know which memory is actively being used in user programs.

Using a separate memory pool between eager allocations and cudagraph allocations may increase the memory footprint of your program if substantial memory is allocated to both.
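As a rough sketch of the wrapper in question (adapted from the `torch.cuda.CUDAGraph` / `torch.cuda.graph` API docs linked above; the shapes and values are illustrative):

```python
import torch

g = torch.cuda.CUDAGraph()
static_input = torch.empty(5, device="cuda")

# Warm up on a side stream before capture, as the API docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Capture: allocations made here come from the graph's private memory pool.
with torch.cuda.graph(g):
    static_output = static_input * 2

# Replay: copy new data into the captured (static) input address, then launch
# the whole recorded graph with a single CPU-side call.
static_input.copy_(torch.full((5,), 3.0, device="cuda"))
g.replay()
```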
### Make Graphed Callables
[Make Graphed Callables](https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html) is a PyTorch abstraction for sharing a single memory pool over a series of callables. Graphed Callables take advantage of the fact that, during CUDA Graph recording, memory is exactly accounted for by the caching allocator, which allows memory to be safely shared between separate CUDA Graph recordings. In each invocation, outputs are preserved as live memory, preventing one callable from overwriting the live memory of another. Graphed Callables can only be invoked in a single order; memory addresses from the first run are burned into the second, and so forth.
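A minimal sketch of the API linked above, closely following its documentation (the modules and tensor shapes here are illustrative):

```python
import torch

module1 = torch.nn.Linear(16, 32).cuda()
module2 = torch.nn.Linear(32, 8).cuda()

x = torch.randn(4, 16, device="cuda")
h = torch.randn(4, 32, device="cuda", requires_grad=True)

# Both callables are graphed against a single shared memory pool and must
# afterwards be invoked in the order in which they were graphed.
module1, module2 = torch.cuda.make_graphed_callables((module1, module2), ((x,), (h,)))

out = module2(module1(x))
out.sum().backward()
```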
### TorchDynamo Previous CUDA Graphs Integration
Running with `cudagraph_trees=False` does not reuse memory across separate graph captures, which can lead to large memory regressions. Even for a model that has no graph breaks, this has issues. The forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward.
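For reference, a sketch of where this switch lives (assuming the `torch._inductor.config.triton.cudagraph_trees` flag, which the text above refers to as `cudagraph_trees`):

```python
import torch

# Fall back to the older, separate-pool CUDA Graphs integration instead of
# CUDAGraph Trees when compiling with mode="reduce-overhead".
torch._inductor.config.triton.cudagraph_trees = False

@torch.compile(mode="reduce-overhead")
def f(x):
    return x @ x
```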
## CUDAGraph Trees Integration

Like Graphed Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA Graph captures. Let's take a look at an illustrative example:
```python
@torch.compile(mode="reduce-overhead")
def foo(x):
    # GRAPH 1
    y = x * x * x
    # graph break triggered here
    if y.sum() > 0:
        # GRAPH 2
        z = y ** y
    else:
        # GRAPH 3
        z = (y.abs() ** y.abs())
    torch._dynamo.graph_break()
    # GRAPH 4
    return z * torch.rand_like(z)

# the first run warms up each graph, which does things like CuBlas or Triton benchmarking
foo(torch.arange(0, 10, device="cuda"))
# The second run does a CUDA Graph recording, and replays it
foo(torch.arange(0, 10, device="cuda"))
# Finally we hit the optimized, CUDA Graph replay path
foo(torch.arange(0, 10, device="cuda"))
```
In this example, there are two separate paths that we make through the function: 1 -> 2 -> 4, or 1 -> 3 -> 4.
@ -85,27 +77,25 @@ First, we would hit the optimized, CUDAGraph.replay() path that we have already
The second time we hit graph 3 we are warmed up and ready to record. We record graph 3 and then record graph 4 again since the input memory addresses have changed. This creates a tree of CUDA Graph recordings. A CUDA Graph Tree!
```
  1
 / \
2   3
 \   \
  4   4
```
### Input Mutation Support
An input mutation function is a function that performs in-place writes to an input tensor,
as illustrated below:
```python
def foo(x, y):
    # mutates input x
    x.add_(1)
    return x + y
```
Input mutation functions generally lead to challenges for CUDAGraph Trees. Due to the static
CUDA memory address requirement from CUDAGraph, for each input tensor x, CUDAGraph Trees may
@ -116,13 +106,13 @@ x and x' reside on different CUDA memory addresses.
A closer look at input mutation functions reveals that there are three types of inputs:
- **Inputs from eager**: We assume that the addresses of these tensors vary from execution to
execution. Because cudagraphs freeze memory addresses, we need to copy these
inputs to a static address tensor prior to graph recording and execution.
- **Parameters and buffers**: These tensors we assume (and runtime-check) have the same tensor
addresses on every execution. We do not need to copy over their contents because the recorded
memory address will be the same as the executed memory address.
- **Tensors which are prior outputs from CUDAGraph Trees**: Because the output tensor addresses
of a cudagraph are fixed, if we run CUDAGraph1, then run CUDAGraph2, the inputs which came from
CUDAGraph1 into CUDAGraph2 will have a fixed memory address. These inputs, like parameters and
buffers, do not require copying over to a static address tensor. We check to make sure that
@ -133,55 +123,54 @@ outputs from CUDAGraph Trees. For mutation on inputs from eager, CUDAGraph Trees
function without CUDAGraph and emit *skipping due to mutated inputs* log. The following example
shows CUDAGraph Trees' support for tensors which are prior outputs from CUDAGraph Trees.
```python
import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1

@torch.compile(mode="reduce-overhead")
def mut(x):
    return x.add_(2)

# Enable input mutation support
torch._inductor.config.triton.cudagraph_support_input_mutation = True

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")

    # CUDAGraph is applied since `foo` does not mutate `inp`
    tmp = foo(inp)

    # `mut` mutates `tmp`, but `tmp` is the output of a CUDAGraph Tree managed
    # function, so CUDAGraph is still applied.
    mut(tmp)

torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
tmp = foo(inp)

# While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()`
# is not. So CUDAGraph is not applied to `mut` and there is a log
# `skipping cudagraphs due to mutated inputs`
mut(tmp.clone())
```
To enable CUDAGraph Trees for a function mutating inputs from eager, please rewrite
the function to avoid input mutation.
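For instance, the earlier `foo(x, y)` example can be rewritten out-of-place so that no eager input is mutated (a sketch):

```python
def foo(x, y):
    # out-of-place update: returns a new tensor instead of mutating the input x,
    # so CUDAGraph Trees no longer need to skip this function
    x = x + 1
    return x + y
```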
> **Note**\
> Enable input mutation support by setting
> [torch.\_inductor.config.triton.cudagraph_support_input_mutation = True](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L662) for "reduce-overhead" mode.
### Dynamic Shape Support
[Dynamic shape](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html)
means that an input tensor has different shapes across function calls. Since CUDAGraph
requires fixed tensor addresses, CUDAGraph Trees re-record CUDAGraph for every unique
shape of an input tensor. This leads to multiple CUDAGraphs for a single inductor graph.
@ -193,128 +182,106 @@ This memory cost can be significant with many CUDAGraph re-recordings.
For functions with frequently changing input tensor shapes, we suggest padding input
tensors to a few fixed tensor shapes to still enjoy benefits from CUDAGraph. In addition,
setting [torch.\_inductor.config.triton.cudagraph_skip_dynamic_graphs=True](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L653)
allows skipping cudagraphs for functions with dynamic shape inputs, and only applying cudagraphs to
functions with static input tensor shapes.
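A small sketch of that opt-out (the flag name comes from the config link above):

```python
import torch

# Only apply CUDAGraph to graphs whose input shapes are static; graphs with
# dynamic-shape inputs run without CUDAGraph instead of being re-recorded.
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True

@torch.compile(mode="reduce-overhead", dynamic=True)
def f(x):
    return torch.nn.functional.relu(x) + 1
```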
### NCCL Support
CUDAGraph Trees support functions with NCCL operators. While CUDAGraph Trees record
CUDAGraphs on a per-device basis, NCCL support allows cross-device communication.
```python
@torch.compile(mode="reduce-overhead")
def func(x):
    y = x * x
    y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM)
    x = torch.nn.functional.silu(x)
    return x * y
```
### Reasons for Skipping CUDAGraph
Since CUDAGraph has requirements such as static input tensor addresses and not supporting
CPU operators, CUDAGraph Trees check whether a function satisfies these requirements and
may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUDAGraph.
- **Input mutation**: CUDAGraph Trees skip functions that in-place mutate eager inputs.
In-place mutation of parameters and buffers, or of output tensors from CUDAGraph Tree managed
functions, is still supported. Please see the *Input Mutation Support* section for more details.
- **CPU operators**: Functions containing CPU operators are skipped. Please split the
function into multiple functions and apply CUDAGraph Trees to the functions that contain only GPU operators.
- **Multi-device operators**: A function is skipped if it contains operators on multiple
devices. Currently, CUDAGraph is applied on a per-device basis. Please use supported
libraries such as NCCL for cross-device communication. Please see *NCCL Support*
section for more details.
- **Free unbacked symbols**: Free unbacked symbols usually happen during
[dynamic shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html).
CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shape.
Please see *Dynamic Shape Support* for more details.
- **Incompatible operators**: CUDAGraph Trees skip a function if it contains incompatible
operators. Please replace these operators in a function with supported operators. We
show an exhaustive list of incompatible operators:
```python
aten._fused_moving_avg_obs_fq_helper.default
aten._fused_moving_avg_obs_fq_helper_functional.default
aten.multinomial.default
fbgemm.dense_to_jagged.default
fbgemm.jagged_to_padded_dense.default
run_and_save_rng_state
run_with_rng_state
aten._local_scalar_dense
aten._assert_scalar
```
The following operators are incompatible when [torch.are_deterministic_algorithms_enabled()](https://pytorch.org/docs/stable/generated/torch.are_deterministic_algorithms_enabled.html).
```python
aten._fused_moving_avg_obs_fq_helper.default
aten._fused_moving_avg_obs_fq_helper_functional.default
aten.multinomial.default
fbgemm.dense_to_jagged.default
fbgemm.jagged_to_padded_dense.default
run_and_save_rng_state
run_with_rng_state
aten._local_scalar_dense
aten._assert_scalar
```
### Limitations
Because CUDA Graph fixes memory addresses, CUDA Graphs do not have a great way of handling live tensors from a previous invocation.
Let's say we are benchmarking inference with the following code:
```python
import torch

@torch.compile(mode="reduce-overhead")
def my_model(x):
    y = torch.matmul(x, x)
    return y

x = torch.randn(10, 10, device="cuda")
y1 = my_model(x)
y2 = my_model(x)
print(y1)
# RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
```
In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDAGraph
Trees, we don't want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want
to prematurely free memory from a prior invocation. Our heuristics are: in inference, we start a new iteration on each invocation of
torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics
are wrong, you can mark the start of a new iteration with
[torch.compiler.cudagraph_mark_step_begin()](https://pytorch.org/docs/stable/generated/torch.compiler.cudagraph_mark_step_begin.html), or clone
tensors of a prior iteration (outside of torch.compile) before you begin the next run.
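As a minimal illustration with the `my_model` snippet above, cloning the first output (eagerly, outside the compiled function) keeps it valid across the next replay:

```python
x = torch.randn(10, 10, device="cuda")
y1 = my_model(x).clone()  # copy the result out of the CUDAGraph-managed buffer
y2 = my_model(x)          # this replay may overwrite the original output storage
print(y1)                 # safe: y1 owns its own memory
```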
### Comparisons
| Footguns | Separate CudaGraph | CUDAGraph Trees |
|---------------|------------------------------------------------------------|------------------------------------------------------------------------|
| Memory Can Increase | On each graph compilation (new sizes, etc.) | If you are also running non-cudagraph memory |
| Recordings | On any new invocation of a graph | Will re-record on any new, unique path you take through your program |
| Footguns | Invocation of one graph will overwrite prior invocation | Cannot persist memory between separate runs through your model - one training loop training, or one run of inference |

View File

@ -0,0 +1,280 @@
# Custom Backends
## Overview
`torch.compile` provides a straightforward method to enable users
to define custom backends.
A backend function has the contract
`(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable`.
Backend functions can be called by TorchDynamo, the graph tracing component of `torch.compile`,
after tracing an FX graph and are
expected to return a compiled function that is equivalent to the traced FX graph.
The returned callable should have the same contract as the `forward` function of the original `torch.fx.GraphModule`
passed into the backend:
`(*args: torch.Tensor) -> List[torch.Tensor]`.
In order for TorchDynamo to call your backend, pass your backend function as the `backend` kwarg in
`torch.compile`. For example,
```python
import torch

def my_custom_backend(gm, example_inputs):
    return gm.forward

def f(...):
    ...

f_opt = torch.compile(f, backend=my_custom_backend)

@torch.compile(backend=my_custom_backend)
def g(...):
    ...
```
See below for more examples.
## Registering Custom Backends
You can register your backend using the `register_backend` decorator, for example,
```python
from torch._dynamo import register_backend

@register_backend
def my_compiler(gm, example_inputs):
    ...
```
Besides the `register_backend` decorator, if your backend is in another Python package, you can also register your
backend through the entry points of a Python package, which provide a way for one package to register a plugin for another.
:::{hint}
You can learn more about `entry_points` in the
[python packaging documentation](https://setuptools.pypa.io/en/latest/userguide/entry_point.html).
:::
To register your backend through `entry_points`, you could add your backend function to the `torch_dynamo_backends` entry point group in the
`setup.py` file of your package like:
```python
...
setup(
    ...
    entry_points={
        'torch_dynamo_backends': [
            'my_compiler = your_module.submodule:my_compiler',
        ]
    },
    ...
)
```
Please replace `my_compiler` before the `=` with the name of your backend and the part after the `=` with
the module and function name of your backend function.
The entry point will be added to your Python environment after the installation of the package.
When you call `torch.compile(model, backend="my_compiler")`, PyTorch will first search for a backend named `my_compiler`
that has been registered with `register_backend`. If not found, it will continue to search all backends registered
via `entry_points`.
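For example, once a backend named `my_compiler` is registered (via either mechanism) and importable, it can be selected by name; a minimal sketch:

```python
import torch

model = torch.nn.Linear(8, 8)
# "my_compiler" is looked up among register_backend registrations first, then
# among backends exposed through the torch_dynamo_backends entry point group.
model_opt = torch.compile(model, backend="my_compiler")
```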
Registration serves two purposes:
- You can pass a string containing your backend function's name to `torch.compile` instead of the function itself,
for example, `torch.compile(model, backend="my_compiler")`.
- It is required for use with the [minifier](https://pytorch.org/docs/main/torch.compiler_troubleshooting_old.html#minifier). Any generated
code from the minifier must call your code that registers your backend function, typically through an `import` statement.
## Custom Backends after AOTAutograd
It is possible to define custom backends that are called by AOTAutograd rather than TorchDynamo.
This is useful for 2 main reasons:
- Users can define backends that support model training, as AOTAutograd can generate the backward graph for compilation.
- AOTAutograd produces FX graphs consisting of [core Aten ops](https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir). As a result,
custom backends only need to support the core Aten opset, which is a significantly smaller opset than the entire torch/Aten opset.
Wrap your backend with
`torch._dynamo.backends.common.aot_autograd` and use `torch.compile` with the `backend` kwarg as before.
Backend functions wrapped by `aot_autograd` should have the same contract as before.
Backend functions are passed to `aot_autograd` through the `fw_compiler` (forward compiler)
or `bw_compiler` (backward compiler) kwargs. If `bw_compiler` is not specified, the backward compile function
defaults to the forward compile function.
One caveat is that AOTAutograd requires compiled functions returned by backends to be "boxed". This can be done by wrapping
the compiled function with `functorch.compile.make_boxed_func`.
For example,
```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

def my_compiler(gm, example_inputs):
    return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler

# model is any torch.nn.Module you want to compile
model_opt = torch.compile(model, backend=my_backend)
```
## Examples
### Debugging Backend
If you want to better understand what is going on during a
compilation, you can create a custom compiler, which is referred to as a
backend in this section, that will pretty-print the FX
`GraphModule` extracted from Dynamo's bytecode analysis
and return a `forward()` callable.
For example:
```python
from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torch.compile(backend=my_compiler)
def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    return a + b

fn(torch.randn(10), torch.randn(10))
```
Running the above example produces the following output:
```
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ---------- --------
placeholder x x () {}
placeholder y y () {}
call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {}
call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {}
call_function add <built-in function add> (cos, sin) {}
output output output ((add,),) {}
```
This works for `torch.nn.Module` as well, as shown below:
```python
from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(torch.cos(x))

mod = MockModule()
optimized_mod = torch.compile(mod, backend=my_compiler)
optimized_mod(torch.randn(10))
```
Let's take a look at one more example with control flow:
```python
from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torch.compile(backend=my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
```
Running this example produces the following output:
```
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
```

The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.
### Speedy Backend
Integrating a custom backend that offers superior performance is also
easy, and we'll integrate a real one
with [optimize_for_inference](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html):
```python
from typing import List

import torch

def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    scripted = torch.jit.script(gm)
    return torch.jit.optimize_for_inference(scripted)
```
And then you should be able to optimize any existing code with:
```python
@torch.compile(backend=optimize_for_inference_compiler)
def code_to_accelerate():
    ...
```
### Composable Backends
TorchDynamo includes many backends, which can be listed with
`torch._dynamo.list_backends()`. You can combine these backends
together with the following code:
```python
from typing import List

import torch
from torch._dynamo import lookup_backend

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    try:
        trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
        if trt_compiled is not None:
            return trt_compiled
    except Exception:
        pass
    # first backend failed, try something else...
    try:
        inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
        if inductor_compiled is not None:
            return inductor_compiled
    except Exception:
        pass
    return gm.forward
```

View File

@ -1,288 +0,0 @@

View File

@ -1,12 +1,10 @@
# Dynamic Shapes
Code: [symbolic_shapes.py](https://github.com/pytorch/pytorch/blob/db4572dbf18f1cf50cf662547e272d3117063747/torch/fx/experimental/symbolic_shapes.py)
See also: [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng)
## Motivation
Deep learning compilers commonly only work for static shapes, that is to say, they produce compiled programs that only work for a single specific configuration of input shapes and must recompile if any input shape changes. This assumption works great for the majority of commonly run deep learning models today, but there are a few situations where it is insufficient:
@ -16,67 +14,58 @@ Deep learning compilers commonly only work for static shapes, that is to say, th
In supporting dynamic shapes, we chose not to support dynamic rank programs, e.g., programs whose input tensors change in dimensionality, as this pattern rarely occurs in real-world deep learning programs, and it avoids the need to reason inductively over symbolic lists of shapes.
## Abridged public API
The default dynamic behavior in PyTorch 2.1 is:
- PT2 assumes everything is static by default
- If we recompile because a size changed, we will instead attempt to recompile
that size as being dynamic (sizes that have changed are likely to change in
the future). This generalization may fail (e.g., because user code does a
conditional branch on the size in question or missing dynamic shapes support
in PT2). If you are trying to understand why PT2 has overspecialized some
code, run with `TORCH_LOGS=dynamic` and look for "eval" entries that say
when guards are added and why.
- If you know ahead of time something will be dynamic, you can skip the first
recompile with `torch._dynamo.mark_dynamic(tensor, dim)` (see the sketch after this list). If you know ahead of time
the `min` and `max` values this dimension can take, you can specify `torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)`.
- If you say `torch.compile(dynamic=False)`, we will turn off automatic
dynamic shapes on recompiles and always recompile for each distinct size.
Conversely, if you say `torch.compile(dynamic=True)`, we will try to make
everything as dynamic as possible. This is mostly useful for small
operators; if you try it on a big model it will (1) probably crash PT2 and (2) run slow for no good reason.
- You can whitelist specific sources to be marked as dynamic using the
`TORCH_COMPILE_DYNAMIC_SOURCES` environment variable or by setting
`torch.compiler.config.dynamic_sources`. This is particularly useful for large
models with graph breaks, as you can maintain dynamism across graph breaks since
source names stay consistent. You can also use this to mark integers as dynamic.
The format is a comma-delimited list of source names, e.g., `"L['x'], L['y']"`.
You can also use regexes, e.g., `"L\['x.*'\], L\['y.*'\]"`.
This whitelist takes precedence over other flags like `dynamic=False`,
`force_nn_module_property_static_shapes`, and `force_parameter_static_shapes`.
- Sometimes it can be cumbersome to find the right inputs to mark as dynamic. If
you're willing to take a performance hit for the first batch, one other affordable
option we have are the eager_then_compile stances which derive dynamism for you.
See [torch.compiler.set_stance](https://docs.pytorch.org/docs/stable/generated/torch.compiler.set_stance.html) for more details.
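Here is the sketch referenced above: a hedged example of marking a dimension dynamic up front via `torch._dynamo.mark_dynamic` (the shapes are illustrative):

```python
import torch

@torch.compile
def f(x):
    return x * 2

x = torch.randn(8, 16)
# Compile dim 0 as dynamic from the start instead of waiting for the
# automatic static-then-dynamic recompile.
torch._dynamo.mark_dynamic(x, 0)
f(x)
f(torch.randn(4, 16))  # should reuse the dynamic-dim-0 graph, no recompile expected
```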
## The Guard Model
When considering how to add support for dynamic shapes to TorchDynamo and TorchInductor, we made a major design decision: in order to reuse decompositions and other preexisting code written in Python/C++ targeting the PyTorch API, we must be able to trace through dynamic shapes. Unlike a fully symbolic system which might capture both branches of a conditional, we always pick one branch and specialize our trace under the assumption that we only use this trace when we would have made the same choice for that branch in the future. To do this, we maintain a "hint" for every symbolic size saying what its concrete value is at compile time (as TorchDynamo is a just-in-time compiler, it always knows what the actual input sizes are.) When we perform a condition on a tensor, we simply consult the hint to find out which branch to take.
This greatly simplifies the symbolic shape formulas we produce, but means we have a much more involved system for managing guards. Consider, for example, the following program:
```python
def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)
```
The final IR we will compile with TorchInductor will either be `torch.cat([x, y]).add(2)` or `torch.cat([x, y]).mul(2)` (with the condition flattened away), but to determine which branch we are in, we would need to know the size of `z`, an intermediate. Because TorchDynamo must know upfront if a compiled trace is valid (we do not support bailouts, like some JIT compilers), we must be able to reduce `z.size(0)` as an expression in terms of the inputs, `x.size(0) + y.size(0)`. This is done by writing meta functions for all operators in PyTorch which can propagate size information to the output of a tensor without actually performing computation on the node.
## Overall architecture
Symbolic shapes workflow:
@ -84,24 +73,23 @@ Symbolic shapes workflow:
2. We allocate symbolic sizes for tensors on entry (what is static or dynamic is a policy decision, with some knobs).
3. We propagate the symbolic sizes through operators, maintaining both (1) FX IR so that we can faithfully export symbolic compute, and (2) Sympy expressions representing the size vars, so we can reason about them.
4. When we condition on symbolic sizes, either in Dynamo tracing or in Inductor optimization, we add guards based on the conditional. These can be induced from both Python and C++.
5. These guards can induce further simplifications on symbolic variables. For example, if you assert `s0 == 4`, we can now replace all occurrences of `s0` with `4`.
6. When we're done tracing and optimizing, we install all of these guards with the compiled code; the compiled code is only reusable if all the guards evaluate true.
Important files:
- C++ SymInt API: `c10/core/SymInt.h`, `SymFloat.h`, `SymBool.h`
- Python SymInt API: `torch/__init__.py` (look for `SymInt/SymFloat/SymBool`)
- C++ plumbing: `c10/core/SymNodeImpl.h`, `torch/csrc/utils/python_symnode.h`, `torch/csrc/jit/python/init.cpp`
- Python infrastructure: `torch/fx/experimental/symbolic_shapes.py`
- Other important files: `torch/_subclasses/fake_tensor.py`, `torch/_meta_registrations.py`, decomps, PrimTorch refs
## Abridged internal API
Understanding the Python class hierarchy:
- SymInt/SymFloat/SymBool: these are user-visible classes that simulate their int/float/bool counterparts. If you add two SymInts, we give you a new SymInt that symbolically tracks that the integer addition had occurred.
- SymNode: this is the internal structure (accessible via e.g., `symint.node`) which holds the actual symbolic tracking info. SymNode is type erased; this makes it more convenient to represent mixed-type operations. Note that technically you don't have to call into Python SymNode from SymInt; for example, XLA's C++ `SymNodeImpl` would take the place of SymNode.
- ShapeEnv: per-compile context state which keeps track of all the free symbols and guards we have accumulated so far. Every SymNode records its ShapeEnv (but not vice versa; SymNodes only get used if they participate in a guard).
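A small sketch of how these classes surface in user code (with `dynamic=True`, sizes inside the traced region are SymInts backed by SymNodes in a ShapeEnv; the printed shapes simply show the generalized graph being reused):

```python
import torch

@torch.compile(dynamic=True)
def f(x):
    # During tracing, x.size(0) is a SymInt; arithmetic on it builds a symbolic
    # expression (via SymNode/ShapeEnv) rather than a plain Python int.
    return x.new_ones(x.size(0) * 2)

print(f(torch.randn(3)).shape)  # torch.Size([6])
print(f(torch.randn(5)).shape)  # torch.Size([10])
```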
C++ is fairly similar:
@ -110,10 +98,9 @@ C++ is fairly similar:
- c10::SymNode/SymNodeImpl: analogous to SymNode
- There is no ShapeEnv in C++; for ease of debugging, the entire symbolic reasoning apparatus is in Python.
When you write code that is traceable with `make_fx`, it must be able to deal with SymInt/SymFloat/SymBool flowing through it. [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.fh8zzonyw8ng) gives some guidance for how to do this.
## DimDynamic policy
Symbolic reasoning:
@ -122,22 +109,21 @@ Symbolic reasoning:
- Constraints
- DimDynamic/Constraint
## Unbacked SymInts
To resolve control flow, we check the hint, i.e., the actual value, of a symbolic integer to determine which branch to take. However, in some cases we may not have a hint: so-called unbacked symbolic integers arise when a size variable emerges from a data-dependent operation like `.nonzero()` or `.item()`. It is illegal to perform control flow on these symbolic integers, so we must graph break on these operations.
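A small sketch of the kind of data-dependent value this refers to (by default, `torch.compile` graph breaks at `.item()` rather than producing an unbacked SymInt to branch on):

```python
import torch

@torch.compile
def f(x):
    n = int(x.sum().item())   # data-dependent scalar: no compile-time hint exists
    if n > 0:                 # control flow on it -> graph break, branch runs eagerly
        return x + 1
    return x - 1

f(torch.randn(4))
```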
Naively implemented, this is too restrictive: most PyTorch programs will immediately fail if you try to do anything with unbacked symbolic integers. Here are the most important enhancements to make this actually work:
- On tensor creation, PyTorch precomputes a lot of data about a tensor; for example, if you use `empty_strided` to create a tensor, we will eagerly sort the strides and determine if the tensor is non-overlapping and dense. Sorts produce a lot of guards. However, it is more common to produce a tensor directly with a higher-level API like `empty`, which is guaranteed to produce a non-overlapping and dense tensor. We modified PyTorch to avoid needlessly recomputing these properties.
- Even if nontrivial compute is needed, sometimes a property is never actually queried at all. Making these precomputed properties lazy allows us to avoid guarding on an unbacked symbolic integer unless it is actually needed.
- The data in an integer tensor is generally not known to be non-negative. However, we provide an API `constrain_range` whereby a user can specify that a size is bounded above and below by known limits.
Similar to the dynamic APIs, there are corresponding unbacked APIs: namely, you can use `mark_unbacked` instead of `mark_dynamic` and `TORCH_COMPILE_UNBACKED_SOURCES` instead of `TORCH_COMPILE_DYNAMIC_SOURCES` to tell the compiler to mark an input as unbacked.
In future versions of PT2 (beyond PT2.1), we will extend our reasoning system
to infer that an unbacked symbolic integer is size-like based on usage. For
example, if you pass the result of an `.item()` call to a factory function
like `torch.empty`, we will automatically infer that the result is a size
(because if it was not, it would fail.) This assumption would get validated
at runtime, raising an error if it was not fulfilled.