Update custom backend docs (#92721)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92721
Approved by: https://github.com/jansel

Custom Backends
===============

Overview
--------

``torch.compile`` provides a straightforward way for users to define
custom backends.

A backend function has the contract
``(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable``.

Backend functions can be called by TorchDynamo, the graph tracing component of ``torch.compile``,
after tracing an FX graph, and are
expected to return a compiled function that is equivalent to the traced FX graph.
The returned callable should have the same contract as the ``forward`` function of the original ``torch.fx.GraphModule``
passed into the backend:
``(*args: torch.Tensor) -> List[torch.Tensor]``.

In order for TorchDynamo to call your backend, pass your backend function as the ``backend`` kwarg in
``torch.compile``. For example,

.. code-block:: python

    import torch

    def my_custom_backend(gm, example_inputs):
        return gm.forward

    def f(...):
        ...

    f_opt = torch.compile(f, backend=my_custom_backend)

    @torch.compile(backend=my_custom_backend)
    def g(...):
        ...

See below for more examples.

Registering Custom Backends
---------------------------

You can register your backend using the ``register_backend`` decorator, for example,

.. code-block:: python

    from torch._dynamo.optimizations import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
        ...

Registration serves two purposes:

* You can pass a string containing your backend function's name to ``torch.compile`` instead of the function itself,
  for example, ``torch.compile(model, backend="my_compiler")``, as in the sketch below.
* It is required for use with the `minifier <https://pytorch.org/docs/master/dynamo/troubleshooting.html>`__. Any generated
  code from the minifier must call your code that registers your backend function, typically through an ``import`` statement.
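
Putting these together, here is a minimal sketch of registering a no-op backend and then selecting it by its string name (the backend body is purely illustrative):

.. code-block:: python

    import torch
    from torch._dynamo.optimizations import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
        # Illustrative no-op backend: run the traced graph as-is.
        return gm.forward

    # After registration, the string name can stand in for the function itself.
    opt_model = torch.compile(torch.nn.Linear(4, 4), backend="my_compiler")
    opt_model(torch.randn(2, 4))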

Custom Backends after AOTAutograd
---------------------------------

It is possible to define custom backends that are called by AOTAutograd rather than TorchDynamo.
This is useful for two main reasons:

* Users can define backends that support model training, as AOTAutograd can generate the backward graph for compilation.
* AOTAutograd produces FX graphs consisting of `canonical Aten ops <https://pytorch.org/docs/master/ir.html#canonical-aten-ir>`__. As a result,
  custom backends only need to support the canonical Aten opset, which is significantly smaller than the entire torch/Aten opset.

Wrap your backend with
``torch._dynamo.optimizations.training.aot_autograd`` and use ``torch.compile`` with the ``backend`` kwarg as before.
Backend functions wrapped by ``aot_autograd`` should have the same contract as before.

Backend functions are passed to ``aot_autograd`` through the ``fw_compiler`` (forward compiler)
or ``bw_compiler`` (backward compiler) kwargs. If ``bw_compiler`` is not specified, the backward compile function
defaults to the forward compile function.

One caveat is that AOTAutograd requires compiled functions returned by backends to be "boxed". This can be done by wrapping
the compiled function with ``functorch.compile.make_boxed_func``.

For example,

.. code-block:: python

    from torch._dynamo.optimizations.training import aot_autograd
    from functorch.compile import make_boxed_func

    def my_compiler(gm, example_inputs):
        return make_boxed_func(gm.forward)

    my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler

    model_opt = torch.compile(model, backend=my_backend)
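
Because this backend goes through AOTAutograd, it can also compile the backward pass. As a rough sketch (the model and training step below are assumptions for illustration, not part of the original example):

.. code-block:: python

    import torch

    model = torch.nn.Linear(4, 4)  # hypothetical model
    model_opt = torch.compile(model, backend=my_backend)

    out = model_opt(torch.randn(2, 4))
    # Backpropagating through the compiled module causes the backward graph to be
    # compiled as well, using bw_compiler (here it defaults to my_compiler).
    out.sum().backward()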

Examples
--------

Debugging Backend
^^^^^^^^^^^^^^^^^

If you want to better understand what is going on during a
compilation, you can create a custom compiler, referred to as a custom
backend in the rest of this section, that prints the FX ``GraphModule``
extracted from TorchDynamo's bytecode analysis and returns a callable.

For example:

.. code-block:: python

    from typing import List
    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def fn(x, y):
        a = torch.cos(x)
        b = torch.sin(y)
        return a + b

    fn(torch.randn(10), torch.randn(10))

This works for ``torch.nn.Module`` as well, as shown below:

.. code-block:: python

    from typing import List
    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.cos(x))

    mod = MockModule()
    optimized_mod = torch.compile(mod, backend=my_compiler)
    optimized_mod(torch.randn(10))

Let's take a look at one more example with control flow:

.. code-block:: python

    from typing import List
    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

Because of the data-dependent ``if`` statement, TorchDynamo breaks the graph, so ``my_compiler``
is called separately for the code before the branch and for each branch that is taken.
The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.

Speedy Backend
^^^^^^^^^^^^^^

Integrating a custom backend that offers superior performance is also
easy, and we'll integrate a real one
with `optimize_for_inference <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__:

.. code-block:: python

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        scripted = torch.jit.script(gm)
        return torch.jit.optimize_for_inference(scripted)

And then you should be able to optimize any existing code with:

.. code-block:: python

    @torch.compile(backend=optimize_for_inference_compiler)
    def code_to_accelerate():
        ...

Composable Backends
^^^^^^^^^^^^^^^^^^^

TorchDynamo includes many backends, which can be found in
`backends.py <https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py>`__
or ``torch._dynamo.list_backends()``. You can combine these backends
together with the following code:

.. code-block:: python

    from torch._dynamo.optimizations import BACKENDS

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        try:
            trt_compiled = BACKENDS["tensorrt"](gm, example_inputs)
            if trt_compiled is not None:
                return trt_compiled
        except Exception:
            pass
        # first backend failed, try something else...
        try:
            inductor_compiled = BACKENDS["inductor"](gm, example_inputs)
            if inductor_compiled is not None:
                return inductor_compiled
        except Exception:
            pass
        # if both backends fail, fall back to eager execution of the traced graph
        return gm.forward
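
As before, the combined backend can then be passed to ``torch.compile``; a brief usage sketch (the decorated function is illustrative only):

.. code-block:: python

    @torch.compile(backend=my_compiler)
    def some_fn(x):
        # Any traceable workload works here.
        return torch.relu(x) + 1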