mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 00:54:52 +08:00
This PR fixes typo and URL (`http -> https`) in `rst` files under `docs` directory Pull Request resolved: https://github.com/pytorch/pytorch/pull/92762 Approved by: https://github.com/H-Huang
56 lines
2.2 KiB
ReStructuredText
torch.func
==========

.. currentmodule:: torch.func

torch.func, previously known as "functorch", is
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

.. note::
   This library is currently in `beta <https://pytorch.org/blog/pytorch-feature-classification-changes/#beta>`_.
   What this means is that the features generally work (unless otherwise documented)
   and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
   may change based on user feedback and we don't have full coverage of PyTorch operations.

   If you have suggestions on the API or use-cases you'd like to be covered, please
   open a GitHub issue or reach out. We'd love to hear about how you're using the library.

What are composable function transforms?
----------------------------------------

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- :mod:`torch.func` has auto-differentiation transforms (``grad(f)`` returns a function that
  computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
  returns a function that computes ``f`` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today.

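As a minimal sketch of the three points above, the snippet below applies ``grad``, ``vmap``, and their composition to a toy function. The function ``f`` and the tensors ``x`` and ``batch`` are made up for illustration and are not part of the library.

```python
import torch
from torch.func import grad, vmap

def f(x):
    # Illustrative scalar-valued function of a 1-D tensor (an assumption, not a library function).
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0])

# grad(f) returns a new function that computes the gradient of f at its input.
g = grad(f)(x)                      # tensor([2., 4., 6.]), i.e. 2 * x

# vmap(f) returns a new function that applies f to each row of a batch.
batch = torch.arange(6.0).reshape(2, 3)
vals = vmap(f)(batch)               # tensor([ 5., 50.]), one value per row

# Composing the two gives one gradient per sample in the batch.
per_sample = vmap(grad(f))(batch)   # shape (2, 3), equal to 2 * batch
```

Here ``vmap(grad(f))`` evaluates the gradient of ``f`` independently for every row of ``batch``, without an explicit Python loop over samples.
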
Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the `JAX framework <https://github.com/google/jax>`_.

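To make a few of these use cases concrete, here is a hedged sketch covering per-sample gradients of a model's loss and (batched) Jacobians and Hessians. The toy ``Linear`` model, the random data, and the helper functions ``sample_loss`` and ``cube`` are assumptions chosen for illustration; ``functional_call``, ``jacrev``, and ``hessian`` are additional :mod:`torch.func` utilities beyond the three transforms named above.

```python
import torch
from torch.func import functional_call, grad, hessian, jacrev, vmap

torch.manual_seed(0)

# Hypothetical toy model and data, used only for illustration.
model = torch.nn.Linear(3, 1)
params = {name: p.detach() for name, p in model.named_parameters()}
inputs = torch.randn(4, 3)    # batch of 4 samples
targets = torch.randn(4, 1)

def sample_loss(params, x, y):
    # Squared-error loss for a single sample; functional_call runs the
    # model with the explicitly passed parameters.
    pred = functional_call(model, params, (x.unsqueeze(0),))
    return ((pred - y.unsqueeze(0)) ** 2).mean()

# Per-sample gradients: differentiate with respect to params (argument 0),
# then map over the batch dimension of x and y while sharing params.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(
    params, inputs, targets
)

def cube(x):
    # Illustrative elementwise function with a diagonal Jacobian.
    return x ** 3

x = torch.tensor([1.0, 2.0, 3.0])
jac = jacrev(cube)(x)                        # (3, 3) Jacobian, diag(3 * x**2)
hess = hessian(lambda v: cube(v).sum())(x)   # (3, 3) Hessian, diag(6 * x)

# Batched Jacobian: one (3, 3) Jacobian per row of inputs.
batched_jac = vmap(jacrev(cube))(inputs)     # shape (4, 3, 3)
```

``per_sample_grads`` is a dictionary with the same keys as ``params``, where each entry gains a leading batch dimension of size 4.
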
Read More
---------

.. toctree::
   :maxdepth: 2

   func.whirlwind_tour
   func.api
   func.ux_limitations
   func.migrating