# torch.func API Reference

```{eval-rst}
.. currentmodule:: torch.func
```

```{eval-rst}
.. automodule:: torch.func
```
## Function Transforms

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize
```
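To show how several of the transforms listed above are called, here is a minimal sketch that applies them to `f(x) = sum(sin(x))`. This function is an arbitrary choice, picked only because its derivatives are known in closed form (gradient `cos(x)`, Hessian `diag(-sin(x))`), so each result can be checked with an `assert`. Only documented `torch.func` call signatures are used.

```python
import torch
from torch.func import grad, grad_and_value, hessian, jvp, linearize, vjp, vmap

def f(x):
    # Scalar-valued function: sum(sin(x))
    return x.sin().sum()

x = torch.randn(5)
v = torch.randn(5)

# grad: gradient of a scalar-valued function, here cos(x)
g = grad(f)(x)
assert torch.allclose(g, x.cos())

# grad_and_value: the gradient and the function value from one call
g2, value = grad_and_value(f)(x)
assert torch.allclose(g2, g) and torch.allclose(value, f(x))

# vmap composed with grad: one gradient per row of a batch
xs = torch.randn(4, 5)
assert torch.allclose(vmap(grad(f))(xs), xs.cos())

# jvp, vjp and linearize: products with the Jacobian of sin, i.e. diag(cos(x))
_, jvp_out = jvp(torch.sin, (x,), (v,))
_, vjp_fn = vjp(torch.sin, x)
(vjp_out,) = vjp_fn(v)
_, jvp_fn = linearize(torch.sin, x)
for product in (jvp_out, vjp_out, jvp_fn(v)):
    assert torch.allclose(product, x.cos() * v)

# hessian: matrix of second derivatives, here diag(-sin(x))
assert torch.allclose(hessian(f)(x), torch.diag(-x.sin()))
```

Because these transforms take and return ordinary functions, they compose freely. `vmap(grad(f))` above is one example, and `hessian` is itself built from nested Jacobian transforms.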
## Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a `torch.nn.Module`.
For example, the following computes the Jacobian of a function that takes three
values and returns three values:
```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```
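Since `f` here is the affine map y = Wx + b, its Jacobian with respect to `x` is simply the weight matrix W. The sketch below checks this. It also computes the same Jacobian with forward-mode `jacfwd`, which should agree with reverse-mode `jacrev`. As a rule of thumb, reverse mode tends to suit functions with more inputs than outputs, and forward mode the opposite.

```python
import torch
from torch.func import jacfwd, jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jac_rev = jacrev(f)(x)  # reverse-mode AD
jac_fwd = jacfwd(f)(x)  # forward-mode AD

# For y = W x + b, dy/dx is the weight matrix W
assert torch.allclose(jac_rev, model.weight)
assert torch.allclose(jac_fwd, model.weight)
```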
However, if you want to compute a Jacobian over the parameters of the model instead, you need a way to construct a function whose inputs are those parameters. That is what {func}`functional_call` is for. It accepts an `nn.Module`, the parameters to substitute (a dict mapping parameter names to tensors, which is what the transform operates on), and the inputs to the module's forward pass. It returns the result of running the module's forward pass with the replaced parameters.
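To illustrate that only the supplied values are used, here is a small sketch that calls `functional_call` with hand-picked replacement tensors. The zero weight and all-ones bias are arbitrary illustrative values. With W = 0 and b = 1 the forward pass must return b, and the module's own registered parameters stay unchanged.

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 3)
x = torch.randn(3)
weight_before = model.weight.detach().clone()

# Illustrative replacement values: zero weight, bias of ones
new_params = {"weight": torch.zeros(3, 3), "bias": torch.ones(3)}
out = functional_call(model, new_params, (x,))

# y = 0 @ x + 1, so the output is the bias regardless of x
assert torch.allclose(out, torch.ones(3))
# The module's registered parameters are not modified
assert torch.equal(model.weight, weight_before)
```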
Here's how we would compute the Jacobian over the parameters:
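```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
# The result is a dict with one Jacobian per parameter name
jacobian = jacrev(f)(dict(model.named_parameters()), x)
```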
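The Jacobian returned here is not a single tensor. It mirrors the structure of the `params` dictionary, holding one Jacobian per parameter whose shape is the output shape followed by that parameter's shape. The sketch below repeats the example and checks those shapes against the closed-form derivatives of y = Wx + b.

```python
import torch
from torch.func import functional_call, jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

# One entry per parameter: output shape (3,) followed by the parameter's shape
assert set(jacobian) == {"weight", "bias"}
assert jacobian["weight"].shape == (3, 3, 3)
assert jacobian["bias"].shape == (3, 3)

# y_i = sum_j W_ij x_j + b_i, so dy_i/db_k = delta_ik
assert torch.allclose(jacobian["bias"], torch.eye(3))
# and dy_i/dW_kj = delta_ik * x_j
expected_weight_jac = torch.einsum("ik,j->ikj", torch.eye(3), x)
assert torch.allclose(jacobian["weight"], expected_weight_jac)
```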
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
```
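`stack_module_state` pairs naturally with `functional_call` and `vmap` for model ensembling. It stacks the parameters and buffers of several modules of the same structure along a new leading dimension, and `vmap` then runs one forward pass per set of weights. The sketch below follows that pattern. The ensemble size, the layer shapes, and the `meta`-device copy used as a stateless template are illustrative choices.

```python
import copy

import torch
from torch.func import functional_call, stack_module_state, vmap

num_models = 4
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]

# Dicts of tensors whose leading dimension indexes the models
params, buffers = stack_module_state(models)

# A stateless template: its own weights live on the meta device and are
# never used, because functional_call substitutes the stacked ones
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

x = torch.randn(5, 3)
# Map over the model dimension of params/buffers; every model sees the same x
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

assert params["weight"].shape == (num_models, 2, 3)
assert out.shape == (num_models, 5, 2)
```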
If you're looking for information on fixing Batch Norm modules so that they work
with `vmap`, please follow the guidance here:

```{toctree}
:maxdepth: 1

func.batch_norm
```
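The `replace_all_batch_norm_modules_` utility listed above is one such fix. It updates a module in place so that every Batch Norm layer stops tracking running statistics, after which the model can be run under `vmap`. In the minimal sketch below, the small `Sequential` model and the input shapes are arbitrary choices.

```python
import torch
from torch.func import replace_all_batch_norm_modules_, vmap

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
replace_all_batch_norm_modules_(model)  # modifies model in place

bn = model[1]
# Running statistics are removed, so Batch Norm always uses batch statistics
assert bn.running_mean is None and bn.running_var is None
assert not bn.track_running_stats

# Three independent batches of 8 samples; vmap runs the model on each batch
x = torch.randn(3, 8, 4)
out = vmap(model)(x)
assert out.shape == (3, 8, 4)
```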
## Debug utilities

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap
```