[functorch] Better intro to functorch (pytorch/functorch#688)
For our docs landing page. Fixes https://github.com/pytorch/functorch/issues/605
@@ -7,11 +7,8 @@ functorch
functorch is a library of `JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

It aims to provide composable vmap and grad transforms that work with PyTorch modules
and PyTorch autograd with good eager-mode performance.

.. note::
   This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
   This library is currently in `beta <https://pytorch.org/blog/pytorch-feature-classification-changes/#beta>`_.
   What this means is that the features generally work (unless otherwise documented)
   and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
   may change based on user feedback and we don't have full coverage over PyTorch operations.
@@ -19,6 +16,20 @@ and PyTorch autograd with good eager-mode performance.
If you have suggestions on the API or use-cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
What are composable function transforms?
----------------------------------------

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- functorch has auto-differentiation transforms (``grad(f)`` returns a function that
  computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
  returns a function that computes ``f`` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today (see the sketch below).
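A minimal sketch of the ``vmap(grad(f))`` composition above; the `loss` function, names, and shapes here are illustrative only and are not part of the functorch docs:

```python
import torch
from functorch import grad, vmap

# Hypothetical per-sample loss for a single (x, y) pair and a shared weight.
def loss(weight, x, y):
    return ((x * weight - y) ** 2).sum()

weight = torch.randn(3)
xs, ys = torch.randn(8, 3), torch.randn(8)  # a batch of 8 samples

# grad(loss) differentiates w.r.t. the first argument (weight);
# vmap maps that computation over the batch dimension of xs and ys.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weight, xs, ys)
print(per_sample_grads.shape)  # torch.Size([8, 3]): one gradient per sample
```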
Why composable function transforms?
-----------------------------------
@@ -36,7 +47,7 @@ This idea of composable function transforms comes from the `JAX framework <https
Read More
---------

For a whirlwind tour of how to use the transforms, please check out `this section in our README <https://github.com/pytorch/functorch/blob/main/README.md#what-are-the-transforms>`_. For installation instructions or the API reference, please check below.
Check out our `whirlwind tour <whirlwind_tour>`_ or some of our tutorials mentioned below.


.. toctree::
@@ -7,7 +7,15 @@
"source": [
"# Whirlwind Tour\n",
"\n",
"functorch is [JAX](https://github.com/google/jax)-like composable function transforms for PyTorch. In this whirlwind tour, we'll introduce all the functorch transforms.\n",
"\n",
"## What is functorch?\n",
"\n",
"functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n",
"- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n",
"- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n",
"- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n",
"\n",
"Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns an [FX graph](https://pytorch.org/docs/stable/fx.html) (optionally containing a backward pass); compiling that graph with various backends is one path you can take.\n",
"\n",
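A minimal sketch of the AOT Autograd transform described above, assuming the `aot_function` entry point in `functorch.compile`; the toy `print_compile` backend just prints the captured graph rather than lowering it:

```python
import torch
from functorch.compile import aot_function

def fn(a, b):
    return (a * b).sum()

# A toy "compiler": print the captured FX graph and return it unchanged.
# A real backend would lower the graph instead of just printing it.
def print_compile(fx_module, example_inputs):
    fx_module.graph.print_tabular()
    return fx_module

aot_fn = aot_function(fn, fw_compiler=print_compile, bw_compiler=print_compile)
out = aot_fn(torch.randn(3, requires_grad=True), torch.randn(3))
out.backward()  # the backward graph is traced and handed to bw_compiler
```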
"\n",
|
||||
"## Why composable function transforms?\n",
|
||||
@@ -18,21 +26,38 @@
"- efficiently computing Jacobians and Hessians\n",
"- efficiently computing batched Jacobians and Hessians\n",
"\n",
"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).\n",
"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n",
"\n",
"## What are the transforms?\n",
"\n",
"Right now, we support the following transforms:\n",
"### grad (gradient computation)\n",
"\n",
"- `grad`, `vjp`, `jvp`,\n",
"- `jacrev`, `jacfwd`, `hessian`\n",
"- `vmap`\n",
"`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. the first input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f920b923",
"metadata": {},
"outputs": [],
"source": [
"from functorch import grad\n",
"x = torch.randn([])\n",
"cos_x = grad(lambda x: torch.sin(x))(x)\n",
"assert torch.allclose(cos_x, x.cos())\n",
"\n",
"Furthermore, we have some utilities for working with PyTorch modules.\n",
"- `make_functional(model)`\n",
"- `make_functional_with_buffers(model)`\n",
"\n",
"### vmap\n",
"# Second-order gradients\n",
"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
"assert torch.allclose(neg_sin_x, -x.sin())"
]
},
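A minimal sketch of the module utilities mentioned above: `make_functional(model)` splits an `nn.Module` into a stateless callable plus its parameters so they can flow through the transforms. The linear model here is illustrative:

```python
import torch
import torch.nn as nn
from functorch import make_functional

model = nn.Linear(3, 3)
x = torch.randn(2, 3)

# make_functional returns a stateless callable and the module's parameters,
# so the parameters can be fed through transforms like grad and vmap.
fmodel, params = make_functional(model)
out = fmodel(params, x)
assert torch.allclose(out, model(x))
```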
{
"cell_type": "markdown",
"id": "ef3b2d85",
"metadata": {},
"source": [
"### vmap (auto-vectorization)\n",
"\n",
"Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
"\n",
@@ -43,8 +68,8 @@
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f920b923",
"execution_count": null,
"id": "6ebac649",
"metadata": {},
"outputs": [],
"source": [
@@ -62,44 +87,17 @@
"result = vmap(model)(examples)"
]
},
{
"cell_type": "markdown",
"id": "ef3b2d85",
"metadata": {},
"source": [
"### grad\n",
"\n",
"`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. By default, it computes the gradients of the output of `func` w.r.t. `inputs[0]`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6ebac649",
"metadata": {},
"outputs": [],
"source": [
"from functorch import grad\n",
"x = torch.randn([])\n",
"cos_x = grad(lambda x: torch.sin(x))(x)\n",
"assert torch.allclose(cos_x, x.cos())\n",
"\n",
"# Second-order gradients\n",
"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
"assert torch.allclose(neg_sin_x, -x.sin())"
]
},
{
"cell_type": "markdown",
"id": "5161e6d2",
"metadata": {},
"source": [
"When composed with vmap, grad can be used to compute per-sample-gradients:"
"When composed with `grad`, `vmap` can be used to compute per-sample-gradients:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "ffb2fcb1",
"metadata": {},
"outputs": [],
@@ -128,14 +126,14 @@
"id": "11d711af",
"metadata": {},
"source": [
"### vjp\n",
"### vjp (vector-Jacobian product)\n",
"\n",
"The `vjp` transform applies `func` to `inputs` and returns a new function that computes vjps given some `cotangents` Tensors."
"The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "ad48f9d4",
"metadata": {},
"outputs": [],
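A minimal sketch of the `vjp` transform described above, using `torch.sin` as the example function (the names and shapes are illustrative):

```python
import torch
from functorch import vjp

x = torch.randn(3)
cotangents = torch.ones(3)

# vjp runs the function at x and returns a closure that multiplies the
# transposed Jacobian by the given cotangents (the reverse-mode AD building block).
out, vjp_fn = vjp(torch.sin, x)
(grad_x,) = vjp_fn(cotangents)
assert torch.allclose(grad_x, x.cos())  # J^T v with v = ones is cos(x) here
```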
@@ -154,14 +152,14 @@
"id": "e0221270",
"metadata": {},
"source": [
"### jvp\n",
"### jvp (Jacobian-vector product)\n",
"\n",
"The `jvp` transform computes Jacobian-vector products and is also known as \"forward-mode AD\". Unlike most other transforms, it is not a higher-order function: it returns the outputs of `func(inputs)` as well as the jvps."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "f3772f43",
"metadata": {},
"outputs": [],
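A minimal sketch of the `jvp` transform described above; note that primals and tangents are passed as tuples (again, the names and shapes are illustrative):

```python
import torch
from functorch import jvp

x = torch.randn(3)
tangent = torch.ones(3)

# jvp evaluates the function at the primal point and pushes the tangent
# through forward-mode AD, returning both the output and the JVP.
out, jvp_out = jvp(torch.sin, (x,), (tangent,))
assert torch.allclose(out, x.sin())
assert torch.allclose(jvp_out, x.cos())  # J v with v = ones is cos(x) here
```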
@@ -187,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "20f53be2",
"metadata": {},
"outputs": [],
@@ -209,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "97d6c382",
"metadata": {},
"outputs": [],
@@ -229,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "a8c1dedb",
"metadata": {},
"outputs": [],
@@ -251,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "1e511139",
"metadata": {},
"outputs": [],
@@ -274,7 +272,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "fd1765df",
"metadata": {},
"outputs": [],
@@ -315,7 +313,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
"version": "3.7.4"
}
},
"nbformat": 4,