Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", is a library of
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

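As a quick illustration of what "composable" means here (a minimal sketch, not taken
from the functorch documentation), the transforms can be freely nested, for example to
compute per-sample gradients::

    import torch
    from torch.func import grad, vmap

    def loss(weight, example):
        return (example @ weight).sin().sum()

    weight = torch.randn(3, 3)
    examples = torch.randn(64, 3)

    # grad differentiates with respect to weight; vmap maps the computation over
    # the batch dimension of examples, yielding one gradient per sample.
    per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weight, examples)
    assert per_sample_grads.shape == (64, 3, 3)
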
functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstreaming, we've decided to migrate from being a top-level package
(``functorch``) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the newest APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The following APIs are drop-in replacements for the corresponding
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_.
They are fully backwards compatible.

============================== =======================================
functorch API                  PyTorch API (as of PyTorch 2.0)
============================== =======================================
functorch.vmap                 :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                 :func:`torch.func.grad`
functorch.vjp                  :func:`torch.func.vjp`
functorch.jvp                  :func:`torch.func.jvp`
functorch.jacrev               :func:`torch.func.jacrev`
functorch.jacfwd               :func:`torch.func.jacfwd`
functorch.hessian              :func:`torch.func.hessian`
functorch.functionalize        :func:`torch.func.functionalize`
============================== =======================================

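For instance, migrating a ``vmap`` call might look as follows (a minimal sketch; the
other transforms in the table are replaced the same way)::

    import torch

    def f(x):
        return x.sin() + x ** 2

    batched_x = torch.randn(64, 3)

    # Before: functorch.vmap(f)(batched_x)
    # After (as of PyTorch 2.0):
    y = torch.vmap(f)(batched_x)  # or torch.func.vmap(f)(batched_x)
    assert y.shape == (64, 3)
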
Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================

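As a rough sketch of how the table above translates into code, here is one way a
``torch.autograd.functional.jacobian`` call could be migrated to :func:`torch.func.jacrev`
(an assumed example, not from the original documentation)::

    import torch

    def f(x):
        return x.sin()

    x = torch.randn(5)

    # torch.autograd.functional
    jacobian_old = torch.autograd.functional.jacobian(f, x)

    # torch.func (as of PyTorch 2.0)
    jacobian_new = torch.func.jacrev(f)(x)

    # Both compute the same 5 x 5 jacobian.
    assert torch.allclose(jacobian_old, jacobian_new)
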
NN module utilities
-------------------

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts the parameters and the inputs to the model as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of the parameters of a model using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians with respect to argnums=0 by default.
    # We set it to 1 to compute jacobians with respect to params.
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, for memory consumption reasons, it is important that you only carry
around a single copy of your parameters. ``model.named_parameters()`` does not copy
the parameters. If, in your model training, you update the parameters of the model
in-place, then the ``nn.Module`` that is your model holds the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of the parameters: the one in the
dictionary and the one in the ``model``. In this case, you should change
``model`` to not hold memory by converting it to the meta device via
``model.to('meta')``.

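For example, here is a minimal sketch (an illustrative example, not from the original
documentation) of keeping the only copy of the parameters in a dictionary, updating them
out-of-place, and using a meta-device ``model`` that holds no storage::

    import torch
    from torch.func import functional_call, grad

    model = torch.nn.Linear(3, 3)
    # Keep the single "real" copy of the parameters in a dictionary ...
    params = {k: v.detach().clone() for k, v in model.named_parameters()}
    # ... and free the module's own storage by moving it to the meta device.
    model = model.to('meta')

    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)

    def compute_loss(params, inputs, targets):
        prediction = functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = grad(compute_loss)(params, inputs, targets)
    # An out-of-place SGD step: the dictionary now holds the updated parameters,
    # and there is still only a single copy of them in memory.
    params = {k: v - 1e-2 * grads[k] for k, v in params.items()}
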
functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.

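For example, a minimal sketch of compiling a model with :func:`torch.compile`::

    import torch

    model = torch.nn.Linear(3, 3)
    compiled_model = torch.compile(model)

    x = torch.randn(64, 3)
    # The first call triggers compilation; subsequent calls reuse the compiled code.
    y = compiled_model(x)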