Convert rst files to md (#155369)

Fixes #155021
Fixes #155158

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155369
Approved by: https://github.com/svekars, https://github.com/malfet
jafraustro
2025-06-11 23:00:52 +00:00
committed by PyTorch MergeBot
parent 48921721d8
commit 1b032384b1
9 changed files with 417 additions and 402 deletions


@@ -1,46 +1,75 @@
# FullyShardedDataParallel

```{eval-rst}
.. automodule:: torch.distributed.fsdp
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.BackwardPrefetch
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardingStrategy
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.MixedPrecision
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.CPUOffload
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.StateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.StateDictSettings
    :members:
```
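For orientation, a minimal usage sketch of {class}`torch.distributed.fsdp.FullyShardedDataParallel`, assuming a multi-process launch (e.g. via `torchrun`) with one CUDA device per rank:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes this script is launched with one process per GPU (e.g. torchrun).
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(8, 8).cuda()
fsdp_model = FSDP(model)  # parameters are sharded across ranks

out = fsdp_model(torch.randn(4, 8, device="cuda"))
out.sum().backward()
```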

docs/source/func.api.md (new file)

@@ -0,0 +1,88 @@
# torch.func API Reference
```{eval-rst}
.. currentmodule:: torch.func
```
```{eval-rst}
.. automodule:: torch.func
```
## Function Transforms
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize
```
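For example, `grad` turns a scalar-returning function into one that computes its gradient; a minimal sketch:

```python
import torch
from torch.func import grad

# grad(f) returns a function computing df/dx (f must return a scalar)
dsin = grad(torch.sin)
x = torch.tensor(1.0)
assert torch.allclose(dsin(x), torch.cos(x))
```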
## Utilities for working with torch.nn.Modules
In general, you can transform over a function that calls a `torch.nn.Module`.
For example, here is how to compute the Jacobian of a function
that takes three values and returns three values:
```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```
However, if you want to do something like compute a jacobian over the parameters of the model, then there needs to be a way to construct a function where the parameters are the inputs to the function. That's what {func}`functional_call` is for: it accepts an nn.Module, the transformed `parameters`, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters.
Here's how we would compute the Jacobian over the parameters:
```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
```
If you're looking for information on fixing Batch Norm modules, please follow the guidance here:
```{eval-rst}
.. toctree::
    :maxdepth: 1

    func.batch_norm
```
## Debug utilities
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap
```
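As a rough sketch, assuming `debug_unwrap` simply exposes the plain tensor underneath a transform's wrapper (see its generated page for the exact semantics):

```python
import torch
from torch.func import vmap, debug_unwrap

def f(x):
    # Under vmap, x is a functorch-wrapped tensor; debug_unwrap is
    # intended to reveal the underlying plain tensor for inspection.
    print(type(debug_unwrap(x)))
    return x * 2

vmap(f)(torch.randn(3))
```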


@@ -1,87 +0,0 @@
(Deleted: the former reStructuredText source of the torch.func API reference, now converted to docs/source/func.api.md above.)


@@ -1,83 +1,75 @@
# Patching Batch Norm

## What's happening?

Batch Norm requires in-place updates to `running_mean` and `running_var` of the same size as the input.
functorch does not support an in-place update to a regular tensor that takes in a batched tensor (i.e.
`regular.add_(batched)` is not allowed). So when vmapping over a batch of inputs to a single module,
we end up with this error, as the sketch below demonstrates.
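A minimal sketch that triggers the failure, assuming a plain `BatchNorm1d` in training mode:

```python
import torch
from torch.func import vmap

net = torch.nn.BatchNorm1d(3)
net.train()  # training mode updates running_mean / running_var in-place
xs = torch.randn(4, 5, 3)  # 4 vmapped batches of 5 samples each

# Each vmapped call is a valid (5, 3) forward pass, but the module still
# attempts `running_mean.add_(batched_update)`, which functorch rejects.
vmap(net)(xs)  # raises RuntimeError
```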
## How to fix

One of the best supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this.

All of these options assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.

### Option 1: Change the BatchNorm

If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:

```python
GroupNorm(G, C)
```

Here `C` is the same `C` as in the original BatchNorm and `G` is the number of groups to
break `C` into. As such, `C % G == 0`; as a fallback, you can set `C == G`, meaning
each channel will be treated separately.

If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set its
`track_running_stats` flag to False:

```python
BatchNorm2d(64, track_running_stats=False)
```

### Option 2: torchvision parameter

Some torchvision models, like resnet and regnet, can take in a `norm_layer` parameter. These
often default to BatchNorm2d. Instead, you can set it to GroupNorm:

```python
import torchvision
from torch.nn import GroupNorm

torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))
```

Here, once again, `c % g == 0`, so as a fallback, set `g = c`.

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

```python
import torchvision
from functools import partial
from torch.nn import BatchNorm2d

torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
```

### Option 3: functorch's patching

functorch has added some functionality to allow for quick, in-place patching of a module to not
use running stats. Changing the norm layer is more fragile, so we have not offered that. If you
have a net where you want the BatchNorm to not use running stats, you can run
`replace_all_batch_norm_modules_` to update the module in-place:

```python
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
```

### Option 4: eval mode

When run under eval mode, `running_mean` and `running_var` are not updated, so vmap can support this mode:

```python
model.eval()
vmap(model)(x)
model.train()
```

docs/source/func.md (new file)

@@ -0,0 +1,56 @@
# torch.func
```{eval-rst}
.. currentmodule:: torch.func
```
torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.
```{note}
This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
This means that the features generally work (unless otherwise documented)
and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
may change based on user feedback, and we don't yet have full coverage of PyTorch operations.
If you have suggestions on the API or use-cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
```
## What are composable function transforms?
- A "function transform" is a higher-order function that accepts a numerical function
and returns a new function that computes a different quantity.
- {mod}`torch.func` has auto-differentiation transforms (`grad(f)` returns a function that
computes the gradient of `f`), a vectorization/batching transform (`vmap(f)`
returns a function that computes `f` over batches of inputs), and others.
- These function transforms can compose with each other arbitrarily. For example,
  composing `vmap(grad(f))` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today; see the sketch after this list.
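As a minimal sketch of that composition, per-sample gradients of a toy loss:

```python
import torch
from torch.func import grad, vmap

def loss(weights, sample):
    return (sample @ weights).sin().sum()

weights = torch.randn(3)
samples = torch.randn(8, 3)

# grad(loss) differentiates w.r.t. weights for a single sample;
# vmap maps it over the leading dimension of `samples` only.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weights, samples)
assert per_sample_grads.shape == (8, 3)
```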
## Why composable function transforms?
There are a number of use cases that are tricky to do in PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians
Composing {func}`vmap`, {func}`grad`, and {func}`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).
## Read More
```{eval-rst}
.. toctree::
    :maxdepth: 2

    func.whirlwind_tour
    func.api
    func.ux_limitations
    func.migrating
```


@@ -0,0 +1,201 @@
# Migrating from functorch to torch.func
torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.
functorch started as an out-of-tree library over at
the [pytorch/functorch](https://github.com/pytorch/functorch) repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.
As the final step of the upstream, we've decided to migrate from being a top level package
(`functorch`) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
`import functorch` and ask that users migrate to the newest APIs, which we
will maintain going forward. `import functorch` will be kept around to maintain
backwards compatibility for a couple of releases.
## function transforms
The following APIs are drop-in replacements for the corresponding
[functorch APIs](https://pytorch.org/functorch/1.13/functorch.html).
They are fully backwards compatible.
| functorch API | PyTorch API (as of PyTorch 2.0) |
| ----------------------------------- | ---------------------------------------------- |
| functorch.vmap | {func}`torch.vmap` or {func}`torch.func.vmap` |
| functorch.grad | {func}`torch.func.grad` |
| functorch.vjp | {func}`torch.func.vjp` |
| functorch.jvp | {func}`torch.func.jvp` |
| functorch.jacrev | {func}`torch.func.jacrev` |
| functorch.jacfwd | {func}`torch.func.jacfwd` |
| functorch.hessian | {func}`torch.func.hessian` |
| functorch.functionalize | {func}`torch.func.functionalize` |
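In practice the migration is often just an import swap; a minimal sketch:

```python
import torch
from torch.func import jacrev  # previously: from functorch import jacrev

x = torch.randn(3)
jac = jacrev(torch.sin)(x)  # elementwise sin has a diagonal Jacobian
assert torch.allclose(jac, torch.diag(torch.cos(x)))
```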
Furthermore, if you are using torch.autograd.functional APIs, please try out
the {mod}`torch.func` equivalents instead. {mod}`torch.func` function
transforms are more composable and more performant in many cases.
| torch.autograd.functional API | torch.func API (as of PyTorch 2.0) |
| ------------------------------------------- | ---------------------------------------------- |
| {func}`torch.autograd.functional.vjp` | {func}`torch.func.grad` or {func}`torch.func.vjp` |
| {func}`torch.autograd.functional.jvp` | {func}`torch.func.jvp` |
| {func}`torch.autograd.functional.jacobian` | {func}`torch.func.jacrev` or {func}`torch.func.jacfwd` |
| {func}`torch.autograd.functional.hessian` | {func}`torch.func.hessian` |
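A brief sketch comparing the two Jacobian APIs on the same function:

```python
import torch

def f(x):
    return x.sin().sum()

x = torch.randn(3)

old = torch.autograd.functional.jacobian(f, x)  # imperative API
new = torch.func.jacrev(f)(x)                   # composable transform
assert torch.allclose(old, new)
```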
## NN module utilities
We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.
### functorch.make_functional
{func}`torch.func.functional_call` is the replacement for
[functorch.make_functional](https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional)
and
[functorch.make_functional_with_buffers](https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers).
However, it is not a drop-in replacement.
If you're in a hurry, you can use
[helper functions in this gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf)
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using {func}`torch.func.functional_call` directly because it is a more explicit
and flexible API.
Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts parameters and inputs to the model as arguments.
{func}`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters and buffers and inputs.
Here's an example of how to compute gradients of parameters of a model using functorch
vs {mod}`torch.func`:
```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
```
And here's an example of how to compute jacobians of model parameters:
```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
```
Note that, for memory consumption, it is important to carry around only
a single copy of your parameters. `model.named_parameters()` does not copy
the parameters. If, in your model training, you update the parameters of the model
in-place, then the `nn.Module` that is your model holds the single copy of the
parameters and everything is OK.
However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the `model`. In this case, you should change
`model` to not hold memory by converting it to the meta device via
`model.to('meta')`.
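A minimal sketch of that pattern, assuming all parameters are passed in explicitly:

```python
import torch

model = torch.nn.Linear(3, 3)
# Keep the single out-of-place copy of the parameters...
params = {k: v.detach().clone() for k, v in model.named_parameters()}
# ...and release the storage held by the module itself.
model = model.to('meta')

# functional_call still runs the forward pass with the external params.
out = torch.func.functional_call(model, params, (torch.randn(3),))
```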
### functorch.combine_state_for_ensemble
Please use {func}`torch.func.stack_module_state` instead of
[functorch.combine_state_for_ensemble](https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html).
{func}`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, which can then be used with {func}`torch.vmap` and {func}`torch.func.functional_call`
for ensembling.
Here is an example of ensembling over a very simple model:
```python
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```
## functorch.compile
We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
{func}`torch.compile` instead.
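For example, a trivial sketch of the replacement:

```python
import torch

model = torch.nn.Linear(3, 3)
compiled = torch.compile(model)  # supersedes the functorch.compile frontend
out = compiled(torch.randn(2, 3))
```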


@@ -1,207 +0,0 @@
(Deleted: the former reStructuredText source of the functorch-to-torch.func migration guide, now converted to the Markdown version above.)


@@ -1,55 +0,0 @@
(Deleted: the former reStructuredText source of the torch.func overview page, now converted to docs/source/func.md above.)


@@ -117,12 +117,10 @@ class OptimStateKeyType(Enum):
class FullyShardedDataParallel(nn.Module, _FSDPState):
    """A wrapper for sharding module parameters across data parallel workers.

    This is inspired by `Xu et al. <https://arxiv.org/abs/2004.13336>`_ as
    well as the ZeRO Stage 3 from `DeepSpeed <https://www.deepspeed.ai/>`_.
    FullyShardedDataParallel is commonly shortened to FSDP.

    To understand FSDP internals, refer to the
    :ref:`fsdp_notes`.