Convert rst files to md (#155369)

Fixes #155021
Fixes #155158

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155369
Approved by: https://github.com/svekars, https://github.com/malfet

Committed by PyTorch MergeBot
parent 48921721d8
commit 1b032384b1
@ -1,46 +1,75 @@
# FullyShardedDataParallel

```{eval-rst}
.. automodule:: torch.distributed.fsdp
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel
    :members:
```
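
The following is a minimal usage sketch, not generated from the API reference above. It assumes the default process group has already been initialized (for example via `torchrun` and `dist.init_process_group("nccl")`) and that each process has one GPU; the toy module and optimizer are made up for illustration.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumed setup: launched with torchrun, one GPU per rank, and the
# default process group already initialized with the NCCL backend.
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(8, 8).cuda()
fsdp_model = FSDP(model)  # parameters are sharded across data-parallel workers

# Construct the optimizer *after* wrapping, over the sharded parameters.
optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
loss = fsdp_model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()
optim.step()
```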

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.BackwardPrefetch
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardingStrategy
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.MixedPrecision
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.CPUOffload
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.StateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.StateDictSettings
    :members:
```
docs/source/func.api.md (new file, 88 lines)
@ -0,0 +1,88 @@
# torch.func API Reference

```{eval-rst}
.. currentmodule:: torch.func
```

```{eval-rst}
.. automodule:: torch.func
```

## Function Transforms

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize
```

## Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following computes the Jacobian of a function that takes three
values and returns three values:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```

However, if you want to do something like compute a jacobian over the parameters of the model, then there needs to be a way to construct a function where the parameters are the inputs to the function. That's what {func}`functional_call` is for: it accepts an nn.Module, the transformed `parameters`, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters.

Here's how we would compute the Jacobian over the parameters:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
```

If you're looking for information on fixing Batch Norm modules, please follow the
guidance here:

```{eval-rst}
.. toctree::
    :maxdepth: 1

    func.batch_norm
```

## Debug utilities

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap
```
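
As a rough sketch (not from the original page): we assume here that `debug_unwrap` takes a functorch-wrapped tensor (such as the batched tensor seen inside {func}`vmap`) and returns the underlying plain tensor, which can be handy when print-debugging inside a transform.

```python
import torch
from torch.func import vmap, debug_unwrap

def f(x):
    # Inside vmap, x is a functorch-wrapped tensor; debug_unwrap is assumed
    # to peel the wrapper off so the raw values can be inspected.
    print(debug_unwrap(x))
    return x.sin()

vmap(f)(torch.randn(4, 3))
```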
@ -1,83 +1,75 @@
# Patching Batch Norm

## What's happening?

Batch Norm requires in-place updates to running_mean and running_var of the same size as the input.
Functorch does not support an in-place update to a regular tensor that takes in a batched tensor (i.e.
`regular.add_(batched)` is not allowed). So when vmapping over a batch of inputs to a single module,
we end up with an error.
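
A minimal sketch of the failure (not part of the original page), assuming a `BatchNorm1d` module in its default training configuration with running stats enabled:

```python
import torch
from torch.func import vmap

model = torch.nn.BatchNorm1d(3)  # track_running_stats=True by default
# A batch of 5 inputs, each of which is itself a (64, 3) mini-batch.
x = torch.randn(5, 64, 3)

try:
    vmap(model)(x)  # running stats would need regular.add_(batched)
except RuntimeError as e:
    print(e)
```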

## How to fix

One of the best supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this.

All of these options assume that you don't need running stats. If you're using a module, this means
that it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.

### Option 1: Change the BatchNorm

If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:

```python
BatchNorm2d(C, G, track_running_stats=False)
```

Here `C` is the same `C` as in the original BatchNorm. `G` is the number of groups to
break `C` into. As such, `C % G == 0` and as a fallback, you can set `C == G`, meaning
each channel will be treated separately.

If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
`track_running_stats` flag to be False:

```python
BatchNorm2d(64, track_running_stats=False)
```

### Option 2: torchvision parameter

Some torchvision models, like resnet and regnet, can take in a `norm_layer` parameter. These
often default to BatchNorm2d.

Instead you can set it to be GroupNorm:

```python
import torchvision
from torch.nn import GroupNorm

torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))
```

Here, once again, `c % g == 0` so as a fallback, set `g = c`.

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

```python
import torchvision
from functools import partial
from torch.nn import BatchNorm2d

torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
```

### Option 3: functorch's patching

functorch has added some functionality to allow for quick, in-place patching of the module to not
use running stats. Changing the norm layer is more fragile, so we have not offered that. If you
have a net where you want the BatchNorm to not use running stats, you can run
`replace_all_batch_norm_modules_` to update the module in-place:

```python
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
```

### Option 4: eval mode

When run under eval mode, running_mean and running_var will not be updated. Therefore, vmap can support this mode:

```python
model.eval()
vmap(model)(x)
model.train()
```
docs/source/func.md (new file, 56 lines)
@ -0,0 +1,56 @@
# torch.func

```{eval-rst}
.. currentmodule:: torch.func
```

torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.

```{note}
This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
What this means is that the features generally work (unless otherwise documented)
and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
may change under user feedback and we don't have full coverage over PyTorch operations.

If you have suggestions on the API or use-cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
```

## What are composable function transforms?

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- {mod}`torch.func` has auto-differentiation transforms (`grad(f)` returns a function that
  computes the gradient of `f`), a vectorization/batching transform (`vmap(f)`
  returns a function that computes `f` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing `vmap(grad(f))` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today; a small sketch follows below.
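
A minimal sketch of that composition (the toy loss, shapes, and names are made up for illustration):

```python
import torch
from torch.func import grad, vmap

def loss(weights, sample, target):
    # Squared error of a linear prediction for a single sample.
    return ((sample @ weights - target) ** 2).mean()

weights = torch.randn(3)
samples = torch.randn(8, 3)   # a batch of 8 samples
targets = torch.randn(8)

# grad(loss) differentiates with respect to weights for one sample;
# vmap maps it over the batch, yielding one gradient per sample.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weights, samples, targets)
assert per_sample_grads.shape == (8, 3)
```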

## Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing {func}`vmap`, {func}`grad`, and {func}`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).
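
As a small illustrative sketch of the last two items (the function `f` and shapes are made up; this is not from the original page):

```python
import torch
from torch.func import hessian, jacrev, vmap

def f(x):
    return (x ** 3).sum()

x = torch.randn(5)
J = jacrev(f)(x)                                 # shape (5,)
H = hessian(f)(x)                                # shape (5, 5)
batched_H = vmap(hessian(f))(torch.randn(7, 5))  # shape (7, 5, 5)
assert H.shape == (5, 5) and batched_H.shape == (7, 5, 5)
```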

## Read More

```{eval-rst}
.. toctree::
    :maxdepth: 2

    func.whirlwind_tour
    func.api
    func.ux_limitations
    func.migrating
```
docs/source/func.migrating.md (new file, 201 lines)
@ -0,0 +1,201 @@
# Migrating from functorch to torch.func

torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the [pytorch/functorch](https://github.com/pytorch/functorch) repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstream, we've decided to migrate from being a top-level package
(`functorch`) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
`import functorch` and ask that users migrate to the newest APIs, which we
will maintain going forward. `import functorch` will be kept around to maintain
backwards compatibility for a couple of releases.

## function transforms

The following APIs are drop-in replacements for the corresponding
[functorch APIs](https://pytorch.org/functorch/1.13/functorch.html).
They are fully backwards compatible.

| functorch API                       | PyTorch API (as of PyTorch 2.0)                |
| ----------------------------------- | ---------------------------------------------- |
| functorch.vmap                      | {func}`torch.vmap` or {func}`torch.func.vmap`  |
| functorch.grad                      | {func}`torch.func.grad`                        |
| functorch.vjp                       | {func}`torch.func.vjp`                         |
| functorch.jvp                       | {func}`torch.func.jvp`                         |
| functorch.jacrev                    | {func}`torch.func.jacrev`                      |
| functorch.jacfwd                    | {func}`torch.func.jacfwd`                      |
| functorch.hessian                   | {func}`torch.func.hessian`                     |
| functorch.functionalize             | {func}`torch.func.functionalize`               |
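
As a small, made-up illustration of what "drop-in" means here, only the namespace changes:

```python
import torch

def f(x):
    return x.sin().sum()

x = torch.randn(3)

# Previously: functorch.grad(f)(x)
g = torch.func.grad(f)(x)  # same semantics, new namespace
assert torch.allclose(g, x.cos())
```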

Furthermore, if you are using torch.autograd.functional APIs, please try out
the {mod}`torch.func` equivalents instead. {mod}`torch.func` function
transforms are more composable and more performant in many cases.

| torch.autograd.functional API               | torch.func API (as of PyTorch 2.0)                     |
| ------------------------------------------- | ------------------------------------------------------ |
| {func}`torch.autograd.functional.vjp`       | {func}`torch.func.grad` or {func}`torch.func.vjp`      |
| {func}`torch.autograd.functional.jvp`       | {func}`torch.func.jvp`                                  |
| {func}`torch.autograd.functional.jacobian`  | {func}`torch.func.jacrev` or {func}`torch.func.jacfwd` |
| {func}`torch.autograd.functional.hessian`   | {func}`torch.func.hessian`                              |

## NN module utilities

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

### functorch.make_functional

{func}`torch.func.functional_call` is the replacement for
[functorch.make_functional](https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional)
and
[functorch.make_functional_with_buffers](https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers).
However, it is not a drop-in replacement.

If you're in a hurry, you can use
[helper functions in this gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf)
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using {func}`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts parameters and inputs to the model as arguments.
{func}`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of parameters of a model using functorch
vs {mod}`torch.func`:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)
```

And here's an example of how to compute jacobians of model parameters:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
```

Note that, for memory consumption, it is important to only carry
around a single copy of your parameters. `model.named_parameters()` does not copy
the parameters. If in your model training you update the parameters of the model
in-place, then the `nn.Module` that is your model has the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the `model`. In this case, you should change
`model` to not hold memory by converting it to the meta device via
`model.to('meta')`.

### functorch.combine_state_for_ensemble

Please use {func}`torch.func.stack_module_state` instead of
[functorch.combine_state_for_ensemble](https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html).
{func}`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters, and
one of stacked buffers, that can then be used with {func}`torch.vmap` and {func}`torch.func.functional_call`
for ensembling.

Here is an example of how to ensemble over a very simple model:

```python
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```

## functorch.compile

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
{func}`torch.compile` instead.
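
For reference, a minimal sketch of the recommended replacement (the model here is a made-up example):

```python
import torch

model = torch.nn.Linear(3, 3)
compiled_model = torch.compile(model)  # AOTAutograd runs as part of the torch.compile stack
out = compiled_model(torch.randn(2, 3))
```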
@ -117,12 +117,10 @@ class OptimStateKeyType(Enum):
class FullyShardedDataParallel(nn.Module, _FSDPState):
    """A wrapper for sharding module parameters across data parallel workers.

    This is inspired by `Xu et al. <https://arxiv.org/abs/2004.13336>`_ as
    well as the ZeRO Stage 3 from `DeepSpeed <https://www.deepspeed.ai/>`_.
    FullyShardedDataParallel is commonly shortened to FSDP.

    To understand FSDP internals, refer to the
    :ref:`fsdp_notes`.