Files
pytorch/docs/source/func.api.rst
rzou 15b1ac3e86 Add torch.func.debug_unwrap (#146528)
Use it to unwrap any functorch-wrapped tensor. I don't recommend using
the output in a program since it breaks the semantics of the transforms,
but it seems useful for debugging.

I will note that some people have wanted to get intermediate values out
of an e.g. grad transform, so this might be a way to do that...

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146528
Approved by: https://github.com/Chillee
2025-02-06 18:48:09 +00:00

88 lines
1.9 KiB
ReStructuredText

torch.func API Reference
========================
.. currentmodule:: torch.func
.. automodule:: torch.func
Function Transforms
-------------------
.. autosummary::
:toctree: generated
:nosignatures:
vmap
grad
grad_and_value
vjp
jvp
linearize
jacrev
jacfwd
hessian
functionalize
Utilities for working with torch.nn.Modules
-------------------------------------------
In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following is an example of computing a jacobian of a function
that takes three values and returns three values:
.. code-block:: python
model = torch.nn.Linear(3, 3)
def f(x):
return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
However, if you want to do something like compute a jacobian over the parameters
of the model, then there needs to be a way to construct a function where the
parameters are the inputs to the function.
That's what :func:`functional_call` is for:
it accepts an nn.Module, the transformed ``parameters``, and the inputs to the
Module's forward pass. It returns the value of running the Module's forward pass
with the replaced parameters.
Here's how we would compute the Jacobian over the parameters
.. code-block:: python
model = torch.nn.Linear(3, 3)
def f(params, x):
return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
.. autosummary::
:toctree: generated
:nosignatures:
functional_call
stack_module_state
replace_all_batch_norm_modules_
If you're looking for information on fixing Batch Norm modules, please follow the
guidance here
.. toctree::
:maxdepth: 1
func.batch_norm
Debug utilities
---------------
.. autosummary::
:toctree: generated
:nosignatures:
debug_unwrap