Fixes #155021 Fixes #155158 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155369 Approved by: https://github.com/svekars, https://github.com/malfet
# torch.func API Reference

```{eval-rst}
.. currentmodule:: torch.func
```

```{eval-rst}
.. automodule:: torch.func
```

## Function Transforms

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize
```
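
As a rough illustration of how these transforms compose, the sketch below wraps {func}`grad` in {func}`vmap` to get one gradient per sample. The `loss` function, the tensor shapes, and the variable names are made up for this example and are not part of the API.

```python
import torch
from torch.func import grad, vmap

def loss(w, x, y):
    # Squared error of a linear prediction for a single sample (hypothetical loss).
    return (x @ w - y) ** 2

w = torch.randn(3)      # shared weights
xs = torch.randn(5, 3)  # 5 samples with 3 features each
ys = torch.randn(5)     # 5 targets

# grad differentiates with respect to the first argument (w) by default;
# vmap maps over the leading dimension of xs and ys while sharing w.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(w, xs, ys)
assert per_sample_grads.shape == (5, 3)
```

Each row of `per_sample_grads` is the gradient of `loss` with respect to `w` for one sample.
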

## Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following computes the Jacobian of a function
that takes three values and returns three values:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)  # Jacobian of the output with respect to x
assert jacobian.shape == (3, 3)
```

However, to compute a Jacobian over the parameters of the model, you need a way to construct a function whose inputs are the parameters. That's what {func}`functional_call` is for: it accepts an `nn.Module`, the transformed `parameters`, and the inputs to the Module's forward pass. It returns the value of running the Module's forward pass with the replaced parameters.

Here's how we would compute the Jacobian over the parameters:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
# jacrev differentiates with respect to the first argument (params) by default
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```
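
Because the parameters are passed as a dictionary, the result has the same structure: one Jacobian per parameter name. The following self-contained sketch repeats the setup above and inspects the result; the `params` variable name is only for illustration.

```python
import torch
from torch.func import functional_call, jacrev

model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
x = torch.randn(3)

def f(params, x):
    return functional_call(model, params, (x,))

jacobian = jacrev(f)(params, x)

# One entry per parameter; each has shape output_shape + parameter_shape.
assert set(jacobian.keys()) == {"weight", "bias"}
assert jacobian["weight"].shape == (3, 3, 3)
assert jacobian["bias"].shape == (3, 3)
```
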

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
```
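
As a minimal sketch of how {func}`stack_module_state` can be combined with {func}`functional_call` and {func}`vmap`, the example below evaluates a small ensemble of models in one call. The ensemble size, the layer sizes, and the helper `call_one` are assumptions made for illustration, and the `"meta"`-device copy is just one way to obtain a stateless module skeleton.

```python
import copy

import torch
from torch.func import functional_call, stack_module_state, vmap

# Four independently initialized models with the same architecture.
models = [torch.nn.Linear(3, 2) for _ in range(4)]

# Stack each parameter and buffer along a new leading dimension.
params, buffers = stack_module_state(models)
assert params["weight"].shape == (4, 2, 3)

# A copy on the "meta" device provides the module structure without real weights.
base = copy.deepcopy(models[0]).to("meta")

def call_one(params, buffers, x):
    # Run the forward pass of `base` with one model's parameters and buffers.
    return functional_call(base, (params, buffers), (x,))

x = torch.randn(8, 3)  # one minibatch shared by every model
out = vmap(call_one, in_dims=(0, 0, None))(params, buffers, x)
assert out.shape == (4, 8, 2)
```

Here `in_dims=(0, 0, None)` maps over the stacked parameters and buffers while passing the same `x` to every model.
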

If you're looking for information on fixing Batch Norm modules, please follow the
guidance here:

```{eval-rst}
.. toctree::
    :maxdepth: 1

    func.batch_norm
```
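
As one hedged illustration of the utility listed above, the sketch below applies {func}`replace_all_batch_norm_modules_` to a small network so that its Batch Norm layer no longer tracks running statistics, then runs the network under {func}`vmap`. The network layout and tensor shapes are made up for this example; see the linked page for the full set of options.

```python
import torch
from torch.func import replace_all_batch_norm_modules_, vmap

net = torch.nn.Sequential(
    torch.nn.Linear(3, 4),
    torch.nn.BatchNorm1d(4),
)

# In-place update: Batch Norm layers stop tracking running statistics.
replace_all_batch_norm_modules_(net)

bn = net[1]
assert bn.running_mean is None
assert bn.track_running_stats is False

# Five independent batches of eight samples; each batch is normalized separately.
xs = torch.randn(5, 8, 3)
out = vmap(net)(xs)
assert out.shape == (5, 8, 4)
```
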

## Debug utilities

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap
```