convert: rst to myst pr 1/2 (#155840)

Fixes #155038
Parent [PR](https://github.com/pytorch/pytorch/pull/155375) (split into two PRs to pass the sanity check).
This PR converts the following two `.rst` files:
- [torch.compiler_dynamo_overview](https://github.com/pytorch/pytorch/blob/main/docs/source/torch.compiler_dynamo_overview.rst)
- [torch.compiler_fake_tensor](https://github.com/pytorch/pytorch/blob/main/docs/source/torch.compiler_fake_tensor.rst)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155840
Approved by: https://github.com/sekyondaMeta
Author: Dhia-naouali
Date: 2025-06-13 18:02:28 +00:00
Committed by: PyTorch MergeBot
Parent: 36bf81e363
Commit: c5d00e150a
3 changed files with 403 additions and 439 deletions

# Dynamo Overview
Before you read this section, read {ref}`torch.compiler_overview`.
TorchDynamo (or simply Dynamo) is a Python-level Just-In-Time (JIT) compiler designed to make
unmodified PyTorch programs faster. Dynamo hooks into the frame evaluation
API in CPython ([PEP 523](https://peps.python.org/pep-0523/)) to
dynamically modify Python bytecode right before it is executed. It
rewrites Python bytecode to extract sequences of PyTorch
operations into an [FX Graph](https://pytorch.org/docs/stable/fx.html)
which is then compiled with a customizable backend.
It creates this FX Graph through bytecode analysis and is designed to
mix Python execution with compiled backends to get the best of both
worlds — usability and performance.
Dynamo makes it easy to experiment with different compiler
backends to make PyTorch code faster with a single-line decorator,
`torch._dynamo.optimize()`, which is wrapped for convenience by `torch.compile()`.
The following diagram demonstrates how PyTorch works with `torch.compile`
and without it:
```{image} _static/img/dynamo/TorchDynamo.png
```
`TorchInductor` is one of the backends supported by Dynamo; it compiles the
[FX Graph](https://pytorch.org/docs/stable/fx.html) produced by Dynamo
into [Triton](https://github.com/openai/triton) kernels for GPUs or
[C++/OpenMP](https://www.openmp.org/) code for CPUs. We have a
[training performance dashboard](https://github.com/pytorch/torchdynamo/issues/681#issuecomment-1233828468)
that provides performance comparison for different training backends. You can read
more in the [TorchInductor post on PyTorch
dev-discuss](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747).
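For example, here is a minimal sketch of switching backends; `torch.compile`'s `backend` argument selects among registered backends, and `"eager"` is a built-in debugging backend that runs the captured graph without code generation:
```python
import torch

def fn(x):
    return torch.relu(x) + 1

# TorchInductor is the default backend for torch.compile; other backends can be
# selected by name (for example "eager", which skips codegen entirely).
opt_inductor = torch.compile(fn)                # backend="inductor" by default
opt_eager = torch.compile(fn, backend="eager")  # handy for isolating Dynamo issues

x = torch.randn(8)
print(opt_inductor(x))
print(opt_eager(x))
```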
For an in-depth overview, read the sections below, watch the deep-dive video,
and check out the dev-discuss topics.
- [Dynamo deep-dive video](https://www.youtube.com/watch?v=egZB5Uxki0I)
- [dev-discuss topics](https://dev-discuss.pytorch.org/search?q=TorchDynamo%20order%3Alatest)
## Dynamo Internals
**Authors**: [Jason Ansel](https://github.com/jansel) and [Kaichao You](https://github.com/youkaichao)
This section will go over some of the Dynamo internals and will
demonstrate how Dynamo works under the hood.
### What is a guard?
Dynamo operates just-in-time and specializes graphs based on
dynamic properties. Below is a basic example of how to use Dynamo.
One can decorate a function or a method using `torchdynamo.optimize` to enable
Dynamo optimization:
```python
from typing import List

import torch
from torch import _dynamo as torchdynamo


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable


@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
```
For example, the first graph above has the following
guards:
```
GUARDS:
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256)
check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
```
If any of those guards fail, the graph will be recaptured and
recompiled. The interesting guard there is `check_tensor`, which
checks the following `torch.Tensor` properties:
- Python class of the tensor (tensor subclassing, etc.)
- dtype
- device
- requires_grad
- dispatch_key (with thread-local includes/excludes applied)
- ndim
- sizes\*
- strides\*
The full specialization mode allows the backend compiler to assume an
entirely static graph. Unfortunately, most backends require this.
Operators which return dynamic shapes will trigger a graph break when
not in dynamic shape mode.
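To see guards and recompilation in action, you can reuse the `toy_example` from above. This is a minimal sketch; the exact recompile behavior depends on your PyTorch version and its automatic dynamic-shapes settings.
```python
# Each call below reuses toy_example and my_compiler from the example above.
toy_example(torch.randn(10), torch.randn(10))  # first call: Dynamo compiles, my_compiler prints
toy_example(torch.randn(10), torch.randn(10))  # same size/dtype/device: guards pass, cached code runs
toy_example(torch.randn(20), torch.randn(20))  # size guard fails: recapture and recompile
toy_example(torch.randn(10, dtype=torch.float64),
            torch.randn(10, dtype=torch.float64))  # dtype guard fails: another recompile
```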
### What is Dynamo doing?
If you want to understand better what Dynamo is doing, you can run your code with:
```
TORCH_LOGS="+dynamo,guards,bytecode"
```
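If I read the logging API correctly, the same log artifacts can also be enabled from inside the program via `torch._logging.set_logs`; the environment variable above is the documented route, so treat this as an equivalent sketch:
```python
import logging
import torch

# Roughly equivalent to TORCH_LOGS="+dynamo,guards,bytecode"
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True, bytecode=True)
```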
If you are not familiar with Python bytecode, you can add a decompiler hook
to decompile the bytecode into human-readable source code. One available
tool is [depyf](https://github.com/youkaichao/depyf). If you don't have
`depyf` already installed, run `pip install depyf`. Then, add the
following code to install decompilation hooks before you run any code.
```python
import depyf
depyf.install()
```
This code triggers useful (but spammy) printouts.
For example, the printouts for the first graph in the `toy_example`
are:
```
__compiled_fn_0 <eval_with_key>.1
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f9ca082f8a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
ORIGINAL BYTECODE toy_example example.py line 12
14 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
15 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 19 (to 38)
16 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
17 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
MODIFIED BYTECODE toy_example example.py line 12
12 0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 12 (to 24)
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE
possible source code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)
If you find the decompiled code is wrong, please submit an issue at https://github.com/youkaichao/depyf/issues.
```
At the top you can see the FX graph.
Next, you see the original bytecode of the function, followed by the
modified bytecode generated by Dynamo, and the decompiled source
code for reference. Finally, you see the guards which we covered above.
In the modified bytecode, `__compiled_fn_0` is the return value of
`my_compiler()` (the compiled graph). `__resume_at_30_1` and
`__resume_at_38_2` are both generated continuation functions that pick
up execution after a graph break (at bytecode offsets 30 and 38). Each
of these functions takes the form:
```
__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...
```
By generating this `resume_at` function, we force the remainder of the
function to be executed in a new Python frame which recursively
triggers Dynamo to restart its capture once execution reaches that
point for the first time.
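A convenient way to count graph breaks without reading bytecode is `torch._dynamo.explain`. This is a sketch; the calling convention has shifted between releases, so check your version:
```python
import torch

explanation = torch._dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation.graph_count)        # number of captured graphs
print(explanation.graph_break_count)  # number of graph breaks
print(explanation.break_reasons)      # why each break happened
```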
### How to inspect artifacts generated by Dynamo?
To inspect the artifacts generated by Dynamo, there is an API `torch._dynamo.eval_frame._debug_get_cache_entry_list` that retrieves compiled code and guards out of a function's `__code__` object. A compiled function can have several cache entries, and each cache entry consists of a generated function that checks the guards and a `types.CodeType` object that holds the code to be executed if the guarding conditions are satisfied.
```python
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(guard)
dis.dis(code)
```
If you know Python bytecode, you can understand the above output.
For the guard function, there is no need to inspect the bytecode. We can directly access its guarding conditions:
```python
for code_part in guard.code_parts:
    print(code_part)
```
The output is:
```
___guarded_code.valid
___check_global_state()
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528)
___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names)
```
Only when all of these conditions are satisfied does the guard function return true, and the compiled code is executed.
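Since the guard is described above as taking the frame's local variables, you can evaluate it by hand. This is a sketch, assuming the guard callable accepts a plain locals mapping:
```python
# Hypothetical manual guard evaluation: L maps argument names to candidate values.
L = {"a": torch.randn(10), "b": torch.randn(10)}
print(guard(L))  # True means the cached compiled code would be reused for these inputs
```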
For the compiled code, we cannot directly access its source but have to decompile it.
```python
from depyf import decompile
print(decompile(code))
```
The output is:
```
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)
```
Some names referenced in the code are:
- Compiled functions, stored in the global namespace of the module containing the original function `toy_example`. These include names like `__compiled_fn_0` / `__resume_at_30_1` / `__resume_at_38_2`.
- Closure variables used for checking guards. The names can be accessed from `guard.__code__.co_freevars`, and the values are stored in `guard.__closure__` (see the sketch after this list). These include names like `___guarded_code` / `___is_grad_enabled` / `___are_deterministic_algorithms_enabled` / `___is_torch_function_enabled` / `utils_device` / `___check_tensors` / `tensor_check_names`.
- Argument `L` of the `guard` function. This is a dict mapping the names of the arguments of `toy_example` to their values. It is only available when the function is called, which is where the frame evaluation API comes into play. In short, `L` is a `dict` with the structure `{'a': value_a, 'b': value_b}`. Therefore, you can see that the code uses `L['a']` to refer to the input variable `a`.
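For example, here is a small sketch that pairs the guard's free-variable names with the values captured in its closure (plain CPython introspection, nothing Dynamo-specific):
```python
closure_values = dict(
    zip(guard.__code__.co_freevars,
        (cell.cell_contents for cell in guard.__closure__))
)
for name, value in closure_values.items():
    print(name, type(value))
```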
The graph break shows up in the code of the compiled `toy_example`, where we have to fall back to the Python interpreter to select which graph to execute next.
Note that we pass a simple `my_compiler` function as the backend compiler; therefore, the subgraph code `__resume_at_38_2`, `__resume_at_30_1`, and `__compiled_fn_0` remains Python code. This can also be inspected (ignore the function names, and look only at the function signatures and bodies):
```python
print("source code of __compiled_fn_0:")
print(innermost_fn(__compiled_fn_0).__self__.code)
print("=" * 60)
print("source code of __resume_at_30_1:")
print(decompile(__resume_at_30_1))
print("=" * 60)
print("source code of __resume_at_38_2:")
print(decompile(__resume_at_38_2))
```
```
source code of __compiled_fn_0:
def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1;  abs_1 = None
    truediv = l_a_ / add;  l_a_ = add = None
    sum_1 = l_b_.sum();  l_b_ = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)
    # To see more debug info, please use ``graph_module.print_readable()``
============================================================
source code of __resume_at_30_1:
def <resume in toy_example>(b, x):
    b = b * -1
    return x * b
============================================================
source code of __resume_at_38_2:
def <resume in toy_example>(b, x):
    return x * b
```
However, if we use other backends like the built-in `inductor`, the subgraph code will be compiled into CUDA kernels for GPUs or C++ code for CPUs.
To summarize, the compiled code is conceptually equivalent to the code below:
```python
def compiled_example(a, b):
    L = {'a': a, 'b': b}
    for guard, code in get_cache_entries():
        if guard(L):
            return code(a, b)
    recompile_and_add_another_cache_entry()
```
The following diagram demonstrates how `torch.compile` transforms and optimizes user-written code: it first extracts computation graphs from the user-written function, compiles these graphs into optimized functions, and then assembles them into a new function that is functionally equivalent to the user-written code but runs faster.
```{image} _static/img/dynamo/flowchart.jpg
```
To learn more about how all this is implemented internally, see {ref}`torch.compiler_dynamo_deepdive`.


# Fake tensor
Code: [fake_tensor.py](https://github.com/pytorch/pytorch/blob/db4572dbf18f1cf50cf662547e272d3117063747/torch/_subclasses/fake_tensor.py)
## Motivation
When doing Dynamo symbolic evaluation and compiler passes, we often want to be able to run tensor operations to understand what output sizes/dtypes/devices are, without actually running those operations (or trashing preexisting tensors), which would be slower (if you're doing a lot of compute) and take a lot of memory (it's bad if your compiler needs to use GPU memory while you are compiling the program). A fake tensor is like a real tensor in all respects, except that it doesn't actually have any data. For example, when we do Dynamo tracing, we need to trace through user Tensor code and answer questions about intermediates (e.g., if a user does a conditional on an intermediate tensor). Without fake tensor, we would not have accurate information for these queries.
Similarly, suppose you want to store metadata for a tensor, e.g., on an FX IR node (meta['val']). You can instead store a fake tensor directly on the node, which will give you all the metadata you need for the tensor, including subtle stuff that you probably wouldn't have handled (e.g., aliasing relationships).
## Related work
- A meta tensor is a tensor with device='meta'. This is actually a lot of what you want for fake tensor, but meta tensors don't model devices, and sometimes stride behavior varies depending on your device, so fake tensors really can get a lot more accurate info this way. Also, meta tensors are "global" (they exist on their own, similar to how a CPU/CUDA tensor exist on their own), whereas fake tensors are scoped to a FakeTensorMode.
- A tensor subclass lets you subclass torch.Tensor and customize its behavior. Fake tensors are implemented as a tensor subclass; that means almost all of the implementation lives in Python! For simpler examples of tensor subclasses, check out [subclass_zoo](https://github.com/albanD/subclass_zoo/).
- Dynamic shapes allow you to create tensors with symbolic sizes rather than only concrete sizes, and propagate these sizes symbolically through operations. Dynamic shapes maintain state in a ShapeEnv, which is always associated with a FakeTensorMode (so fake tensors also are responsible for managing symbolic sizes.) In general, whenever we compile a subgraph with PT2, there is a tracing context associated with this compilation, which contains, among other things, a FakeTensorMode and (possibly) a ShapeEnv.
## Overall architecture
All fake tensors are associated with a FakeTensorMode. Because fake tensor's primary use case is to do analysis on real tensors, the general workflow is: you have a bunch of real tensors, you allocate a FakeTensorMode, you use from_real_tensor to convert all those real tensors into fake tensors, and then you do things to the fake tensors. In particular, the FakeTensorMode persistently maintains a memo table mapping real tensors (and storages) to fake tensors (and fake storages). If you fakeify the same tensor multiple times, you will get the same fake tensor; if you fakeify two tensors which alias each other, you will get two fake tensors which alias the same fake storage. FakeTensors are tensor subclasses, so if you do operations on them, you'll automatically get a fake tensor, but in general you will want to do operations on fake tensors (e.g., if you're running an FX pass) with the FakeTensorMode active; if you do an operation on a fake tensor without the mode active, it will automatically turn on the fake tensor mode and try again.
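Here is a minimal sketch of that workflow and of the memo-table behavior, using only the `from_real_tensor` API that also appears in the examples later in this document (exact printed types may vary by version):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter

x = torch.randn(4)
fake_x1 = converter.from_real_tensor(fake_mode, x)
fake_x2 = converter.from_real_tensor(fake_mode, x)
assert fake_x1 is fake_x2          # memo table: same real tensor -> same fake tensor

with fake_mode:
    fake_y = fake_x1 * 2           # ops on fake tensors stay fake and do no real compute
print(type(fake_y), fake_y.shape, fake_y.device)
```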
A fake tensor is represented as a \_\_torch_dispatch\_\_ tensor subclass of a meta tensor. This means under the hood, fake tensors are meta device tensors; they then use extra extensibility hooks, specifically dispatch_device, to lie about what the actual device of the tensor is. This was one of the more error-prone parts of fake tensors in the early days: sometimes, fake tensors were too good at lying about being CPU/CUDA whatever, and you'd end up with a CPU kernel getting called with a fake tensor trying to dereference the data pointer, which obviously won't work. If you are segfaulting in fake tensor code, this is the first thing you should check: is the C++ backtrace in a CPU kernel (unexpected!) or a meta kernel (expected!) A meta kernel is like a real kernel, but all it does is allocate the outputs, it doesn't do any data compute.
A tensor subclass has to define how to implement various operations. Here is the general fake tensor recipe:
- Run the meta kernel on the input fake tensors, reinterpreting them as meta tensors. This is done via a magic context manager in_kernel_invocation_manager which instructs all of PyTorch to view fake tensors as their underlying meta tensors, rather than "unwrapping" fake tensors into meta tensors (a fake tensor is a meta tensor). Fake tensors are represented this way to avoid having to keep two sets of metadata in sync (the meta tensor's metadata, and the fake tensor's metadata); the "is a" relationship ensures there is only one canonical copy of metadata.
- If you're a factory function, you'll instead call the underlying factory function with device='meta'.
- Convert the resulting meta tensor into a fake tensor, computing what the output device of the tensor should be (this is usually trivial, but sometimes it is not, e.g., cpu scalar promotion, or device-converting operations.)
## API: the important bits
Non-PT2 usage (check out test/test_fake_tensor.py for more examples):
```python
# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)
```
Q: Why do you have real tensors as inputs?
A: In a PT2 context, this is because you typically are compiling just-in-time, s
PT2 pre-AOTAutograd usage (this is unusual, you probably don't want to do this):
```python
# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
    ...  # do stuff with the fake args, if needed ...
```
detect_fake_mode will search a number of locations to try to find "the" fake tensor mode associated with the lifecycle. Typically it will be pulled off of the tracing context.
PT2 post-AOTAutograd usage:
```python
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on
```
Other useful stuff:
```python
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
    ...  # fake mode is disabled here, you can do real tensor compute
```
When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode.
```python
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
```
## Details
Auto-convert or not?
Originally, FakeTensorMode would not automatically fakeify real tensors if you tried to do compute on them inside a FakeTensorMode region. The motivation behind this was to prevent the following footgun:
```python
with FakeTensorMode():
    real_tensor.t_()
```
What should this code do? It would be surprising if we actually modified the metadata on the real tensor. But at the same time, there isn't any obvious opportunity to create a FakeTensor. So we conservatively decided to make this raise an error: "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. Please convert all Tensors to FakeTensors first."
This error is pretty annoying in practice. For example, suppose you have a real
Eventually, we gave up and added automatic fakeification. However, this is still not yet enabled by default in many uses of FakeTensorMode.
Metadata mutation on fake tensor
If you have a fake tensor, and you t\_() it, the metadata on the fake tensor changes. This is reasonable on its face, but sometimes you want to also store fake tensors as metadata on FX nodes; mutating a fake tensor is bad because this will invalidate old metadata!
In fact, there is a fundamental tension here, which is that fake tensors maintain extremely accurate metadata about tensors, up to and including object identity. If object metadata changes over time in an FX graph, there is not actually any way to represent this change over time. Most of the time, our serious FX analyses are done on functionalized graphs, which don't have this problem, but occasionally you need to do an analysis on a non-functionalized graph. Maybe it was a mistake to put fake tensors in meta['val'].
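A small sketch of that mutation pitfall, assuming the converter API from the earlier examples; the point is that any stored reference, such as an FX node's `meta['val']`, silently sees the new metadata:
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
fake = fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, torch.randn(2, 3))
stored = fake                # imagine this was stashed on an FX node as meta['val']
with fake_mode:
    fake.t_()                # in-place transpose rewrites the fake tensor's metadata
print(stored.shape)          # torch.Size([3, 2]): the "old" metadata is gone
```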
## About the tensor subclass
Fake tensor uses both a subclass and a mode tensor subclass pattern, where FakeTensor.\_\_torch_dispatch\_\_ enables the FakeTensorMode associated with the fake tensor, and then redispatches (relying on FakeTensorMode to do the heavy lifting). If fake tensor operations get a subclass argument it doesn't recognize, it will return NotImplemented, giving the other subclass a chance to run first (hopefully desugaring into plain tensor operations), before it tries again. This can cause infinite loops.
## How is each individual operator implemented?
Unfortunately, there is a pretty complicated set of places where any given operator may be implemented. Some important cases to know about:
- Fake tensor itself has some hardcoded special cases for device-converting operations.
- If there is no meta implementation nor any decomposition, we will generate real zero-filled tensors and attempt to run the operator directly to find out what the results will be. This can cause segfaults if the operator attempts to do indexing with data, so we don't turn this on by default for custom ops.
## How does the converter work?
Because fake tensors are used in situations that are very sensitive to the exact properties of a tensor, fake tensors do conversion very carefully, preserving leaf-ness, requires_grad'ness, aliasing, and a whole host of other properties. The bulk of the heavy lifting is in MetaConverter.
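A sketch of the properties the converter is described as preserving; the assertions below reflect the claims above rather than an exhaustive list:
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter

real = torch.randn(3, 4, requires_grad=True)
fake = converter.from_real_tensor(fake_mode, real)

assert fake.requires_grad == real.requires_grad
assert fake.is_leaf == real.is_leaf
assert fake.shape == real.shape and fake.stride() == real.stride()
assert fake.dtype == real.dtype and fake.device == real.device
```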
## Performance characteristics
You would think fake tensors are fast because they don't do any tensor compute. But at small tensor sizes we are actually entirely overhead bound, and, well, fake tensor is in Python, and we often do a LOT of work to do a single tensor operation (because they are implemented as decompositions). So fake tensors are actually pretty slow in practice, especially when symbolic shapes are involved. There are two important fastpaths we currently have in fake tensor that make a big difference in practice:
- Pointwise ops don't go through PrimTorch decomps, instead we've hand-coded their propagation rule.
- If possible, we should.
## Fake tensor of fake tensor?
There is interest in sending fake tensors as user inputs into the PT2 stack, which would imply we would need to be able to create a fake tensor of a fake tensor. This isn't really supported right now, but maybe it would not be too difficult to do.
## Interaction with dynamic shapes
Every FakeTensorMode contains a ShapeEnv, which tracks all symbolic shapes information. Their lifetimes are typically tied: they live and die together.
Because FakeTensorMode has a ShapeEnv (but meta implementations do not), meta functions that are data-dependent and require allocating an unbacked SymInt live in fake tensor. Fake tensor also takes care of memoizing unbacked SymInts, so that, e.g., if you call nonzero() on the same fake tensor twice, you get the same symbolic size.
## Other resources
[Colab Tutorial On Using FakeTensor To Determine Max Batch Size](https://colab.research.google.com/drive/1zjAisRrc8R6uixKsrs1DRm3lwz5MWN68)