Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

Convert rst files to md (#155369)

Fixes #155021. Fixes #155158.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155369
Approved by: https://github.com/svekars, https://github.com/malfet

Committed by PyTorch MergeBot. Parent: 48921721d8. Commit: 1b032384b1.
@ -1,46 +1,75 @@
# FullyShardedDataParallel

```{eval-rst}
.. automodule:: torch.distributed.fsdp
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel
    :members:
```
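For orientation, a minimal sketch of the wrapping pattern (hedged: a real run needs an initialized process group, e.g. launched via `torchrun` with one rank per device; the model and shapes here are illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torchrun has set the usual rank/world-size env vars.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Linear(8, 8).cuda()
model = FSDP(model)  # parameters are now sharded across ranks

out = model(torch.randn(4, 8, device="cuda"))
out.sum().backward()
```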
```{eval-rst}
.. autoclass:: torch.distributed.fsdp.BackwardPrefetch
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardingStrategy
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.MixedPrecision
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.CPUOffload
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.StateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardedStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.LocalStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.OptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.FullOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.ShardedOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.LocalOptimStateDictConfig
    :members:
```

```{eval-rst}
.. autoclass:: torch.distributed.fsdp.StateDictSettings
    :members:
```
docs/source/func.api.md (new file, 88 lines)
@ -0,0 +1,88 @@
# torch.func API Reference

```{eval-rst}
.. currentmodule:: torch.func
```

```{eval-rst}
.. automodule:: torch.func
```
## Function Transforms

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize
```
## Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a `torch.nn.Module`.
For example, here is how to compute the Jacobian of a function that takes
three values and returns three values:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```
However, if you want to do something like compute a Jacobian over the parameters of the model, then there needs to be a way to construct a function where the parameters are the inputs to the function. That's what {func}`functional_call` is for: it accepts an nn.Module, the transformed `parameters`, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters.

Here's how we would compute the Jacobian over the parameters:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```
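For reference, since `params` is a dict, the result is a dict with the same keys; for the `Linear(3, 3)` above, each entry's shape is the output shape followed by that parameter's shape:

```python
assert jacobian["weight"].shape == (3, 3, 3)  # (output dim, *weight.shape)
assert jacobian["bias"].shape == (3, 3)       # (output dim, *bias.shape)
```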
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
```

If you're looking for information on fixing Batch Norm modules, please follow the
guidance here:

```{eval-rst}
.. toctree::
    :maxdepth: 1

    func.batch_norm
```
## Debug utilities

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap
```
@ -1,87 +0,0 @@
@ -1,83 +1,75 @@
# Patching Batch Norm

## What's happening?
Batch Norm requires in-place updates to running_mean and running_var of the same size as the input.
Functorch does not support an in-place update to a regular tensor that takes in a batched tensor (i.e.
`regular.add_(batched)` is not allowed). So when vmapping over a batch of inputs to a single module,
we end up with this error.
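A minimal sketch that reproduces the failure (assuming train mode; the module and shapes are illustrative):

```python
import torch
from torch.func import vmap

net = torch.nn.BatchNorm1d(3)  # train mode: updates running stats in-place
xs = torch.randn(5, 2, 3)      # 5 vmapped inputs, each of shape (2, 3)

# Each vmapped call would have to write a batched mean/var into the
# unbatched running_mean/running_var buffers, which is disallowed:
# vmap(net)(xs)  # raises a RuntimeError
```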
## How to fix
One of the best supported ways is to switch BatchNorm for GroupNorm; options 1 and 2 support this.

All of these options assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.
### Option 1: Change the BatchNorm
If you want to change to GroupNorm, anywhere that you have BatchNorm, replace it with:

```python
GroupNorm(G, C)
```

Here `C` is the same `C` as in the original BatchNorm. `G` is the number of groups to
break `C` into. As such, `C % G == 0` and as a fallback, you can set `C == G`, meaning
each channel will be treated separately.
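For example, with 64 channels (the group count here is illustrative):

```python
from torch.nn import GroupNorm

C = 64
norm = GroupNorm(8, C)  # 8 groups; C % 8 == 0
norm = GroupNorm(C, C)  # fallback: one group per channel
```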
If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
`track_running_stats` flag to False:

```python
BatchNorm2d(64, track_running_stats=False)
```
### Option 2: torchvision parameter
Some torchvision models, like resnet and regnet, can take in a `norm_layer` parameter. These
often default to BatchNorm2d.

Instead you can set it to be GroupNorm:

```python
import torchvision
from torch.nn import GroupNorm

# g must evenly divide the channel count c passed in by the model
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))
```

Here, once again, `c % g == 0` so as a fallback, set `g = c`.
If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

```python
import torchvision
from functools import partial
from torch.nn import BatchNorm2d

torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
```
### Option 3: functorch's patching
functorch has added some functionality to allow for quick, in-place patching of the module to not
use running stats. Changing the norm layer is more fragile, so we have not offered that. If you
have a net where you want the BatchNorm to not use running stats, you can run
`replace_all_batch_norm_modules_` to update the module in-place:

```python
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
```
### Option 4: eval mode
When run under eval mode, the running_mean and running_var will not be updated. Therefore, vmap can support this mode:

```python
model.eval()
vmap(model)(x)
model.train()
```
docs/source/func.md (new file, 56 lines)
@ -0,0 +1,56 @@
# torch.func

```{eval-rst}
.. currentmodule:: torch.func
```

torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.

```{note}
This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
This means that the features generally work (unless otherwise documented)
and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
may change under user feedback and we don't have full coverage over PyTorch operations.

If you have suggestions on the API or use cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
```
## What are composable function transforms?

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- {mod}`torch.func` has auto-differentiation transforms (`grad(f)` returns a function that
  computes the gradient of `f`), a vectorization/batching transform (`vmap(f)`
  returns a function that computes `f` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing `vmap(grad(f))` computes a quantity called per-sample-gradients, sketched
  below, that stock PyTorch cannot efficiently compute today.
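As a sketch of that composition (the loss function and shapes here are illustrative):

```python
import torch
from torch.func import grad, vmap

def loss(weight, sample, target):
    return ((sample @ weight - target) ** 2).sum()

weight = torch.randn(3, 3)
samples = torch.randn(64, 3)
targets = torch.randn(64, 3)

# grad differentiates w.r.t. weight; vmap maps over the 64 samples only.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, samples, targets)
assert per_sample_grads.shape == (64, 3, 3)
```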
## Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing {func}`vmap`, {func}`grad`, and {func}`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).
## Read More

```{eval-rst}
.. toctree::
    :maxdepth: 2

    func.whirlwind_tour
    func.api
    func.ux_limitations
    func.migrating
```
docs/source/func.migrating.md (new file, 201 lines)
@ -0,0 +1,201 @@
# Migrating from functorch to torch.func

torch.func, previously known as "functorch", is
[JAX-like](https://github.com/google/jax) composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the [pytorch/functorch](https://github.com/pytorch/functorch) repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstream, we've decided to migrate from being a top-level package
(`functorch`) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
`import functorch` and ask that users migrate to the newest APIs, which we
will maintain going forward. `import functorch` will be kept around to maintain
backwards compatibility for a couple of releases.
## Function transforms

The following APIs are drop-in replacements for the corresponding
[functorch APIs](https://pytorch.org/functorch/1.13/functorch.html).
They are fully backwards compatible.

| functorch API           | PyTorch API (as of PyTorch 2.0)               |
| ----------------------- | --------------------------------------------- |
| functorch.vmap          | {func}`torch.vmap` or {func}`torch.func.vmap` |
| functorch.grad          | {func}`torch.func.grad`                       |
| functorch.vjp           | {func}`torch.func.vjp`                        |
| functorch.jvp           | {func}`torch.func.jvp`                        |
| functorch.jacrev        | {func}`torch.func.jacrev`                     |
| functorch.jacfwd        | {func}`torch.func.jacfwd`                     |
| functorch.hessian       | {func}`torch.func.hessian`                    |
| functorch.functionalize | {func}`torch.func.functionalize`              |
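For example, migrating a `grad` call is just a change of import (the function and input below are illustrative):

```python
import torch
from torch.func import grad  # previously: from functorch import grad

def f(x):
    return (x ** 3).sum()

x = torch.randn(3)
assert torch.allclose(grad(f)(x), 3 * x ** 2)
```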
Furthermore, if you are using torch.autograd.functional APIs, please try out
the {mod}`torch.func` equivalents instead. {mod}`torch.func` function
transforms are more composable and more performant in many cases.

| torch.autograd.functional API              | torch.func API (as of PyTorch 2.0)                      |
| ------------------------------------------ | ------------------------------------------------------- |
| {func}`torch.autograd.functional.vjp`      | {func}`torch.func.grad` or {func}`torch.func.vjp`       |
| {func}`torch.autograd.functional.jvp`      | {func}`torch.func.jvp`                                  |
| {func}`torch.autograd.functional.jacobian` | {func}`torch.func.jacrev` or {func}`torch.func.jacfwd`  |
| {func}`torch.autograd.functional.hessian`  | {func}`torch.func.hessian`                              |
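As a quick sanity check of the correspondence (using elementwise `torch.sin`, whose Jacobian is diagonal):

```python
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev

x = torch.randn(3)
# Both return the (3, 3) Jacobian; the torch.func version additionally
# composes with vmap, grad, and the other transforms.
assert torch.allclose(jacobian(torch.sin, x), jacrev(torch.sin)(x))
```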
## NN module utilities

We've changed the APIs for applying function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.
### functorch.make_functional

{func}`torch.func.functional_call` is the replacement for
[functorch.make_functional](https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional)
and
[functorch.make_functional_with_buffers](https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers).
However, it is not a drop-in replacement.

If you're in a hurry, you can use
[helper functions in this gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf)
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using {func}`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts parameters and inputs to the model as arguments.
{func}`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of a model's parameters using functorch
vs {mod}`torch.func`:
```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)
```
And here's an example of how to compute Jacobians of model parameters:

```python
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params.
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
```
Note that for memory consumption it is important that you only carry
around a single copy of your parameters. `model.named_parameters()` does not copy
the parameters. If in your model training you update the parameters of the model
in-place, then the `nn.Module` that is your model holds the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the `model`. In this case, you should change
`model` to not hold memory by converting it to the meta device via
`model.to('meta')`.
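A sketch of that pattern (variable names are illustrative):

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 3)
inputs = torch.randn(64, 3)

# Keep the single live copy of the parameters in a dict...
params = {k: v.detach().clone() for k, v in model.named_parameters()}
# ...then free the module's own storage by moving it to the meta device.
model.to('meta')

out = functional_call(model, params, (inputs,))
```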
### functorch.combine_state_for_ensemble

Please use {func}`torch.func.stack_module_state` instead of
[functorch.combine_state_for_ensemble](https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html).
{func}`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with {func}`torch.vmap` and {func}`torch.func.functional_call`
for ensembling.

Here is an example of how to ensemble over a very simple model:
```python
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```
## functorch.compile

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
{func}`torch.compile` instead.
@ -1,207 +0,0 @@
@ -1,55 +0,0 @@
@ -117,12 +117,10 @@ class OptimStateKeyType(Enum):
class FullyShardedDataParallel(nn.Module, _FSDPState):
    """A wrapper for sharding module parameters across data parallel workers.

    This is inspired by `Xu et al. <https://arxiv.org/abs/2004.13336>`_ as
    well as the ZeRO Stage 3 from `DeepSpeed <https://www.deepspeed.ai/>`_.
    FullyShardedDataParallel is commonly shortened to FSDP.

    To understand FSDP internals, refer to the
    :ref:`fsdp_notes`.