[functorch] Better intro to functorch (pytorch/functorch#688)

For our docs landing page.

Fixes https://github.com/pytorch/functorch/issues/605
Richard Zou
2022-04-11 13:56:45 -04:00
committed by Jon Janzen
parent f61d825462
commit b2f03d7f7d
2 changed files with 67 additions and 58 deletions

View File

@@ -7,11 +7,8 @@ functorch
functorch is `JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.
It aims to provide composable vmap and grad transforms that work with PyTorch modules
and PyTorch autograd with good eager-mode performance.
.. note::
This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
This library is currently in `beta <https://pytorch.org/blog/pytorch-feature-classification-changes/#beta>`_.
What this means is that the features generally work (unless otherwise documented)
and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
may change based on user feedback and we don't have full coverage of PyTorch operations.
@@ -19,6 +16,20 @@ and PyTorch autograd with good eager-mode performance.
If you have suggestions on the API or use-cases you'd like to be covered, please
open a GitHub issue or reach out. We'd love to hear about how you're using the library.
What are composable function transforms?
----------------------------------------
- A "function transform" is a higher-order function that accepts a numerical function
and returns a new function that computes a different quantity.
- functorch has auto-differentiation transforms (``grad(f)`` returns a function that
computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
returns a function that computes ``f`` over batches of inputs), and others.
- These function transforms can compose with each other arbitrarily. For example,
composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that
stock PyTorch cannot efficiently compute today.
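For a concrete sketch of that composition (the toy ``loss`` function and shapes below are illustrative, not part of the library):

```python
import torch
from functorch import grad, vmap

# A toy per-example loss: the dot product of one sample with the weights.
def loss(weights, sample):
    return torch.dot(sample, weights)

weights = torch.randn(3)
samples = torch.randn(8, 3)  # a batch of 8 samples

# grad(loss) differentiates w.r.t. the first argument (weights) for ONE sample;
# vmap maps that function over the batch dimension of `samples`.
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weights, samples)
assert per_sample_grads.shape == (8, 3)
```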
Why composable function transforms?
-----------------------------------
@@ -36,7 +47,7 @@ This idea of composable function transforms comes from the `JAX framework <https
Read More
---------
For a whirlwind tour of how to use the transforms, please check out `this section in our README <https://github.com/pytorch/functorch/blob/main/README.md#what-are-the-transforms>`_. For installation instructions or the API reference, please check below.
Check out our `whirlwind tour <whirlwind_tour>`_ or some of our tutorials mentioned below.
.. toctree::

View File

@@ -7,7 +7,15 @@
"source": [
"# Whirlwind Tour\n",
"\n",
"functorch is [JAX](https://github.com/google/jax)-like composable function transforms for PyTorch. In this whirlwind tour, we'll introduce all the functorch transforms.\n",
"\n",
"## What is functorch?\n",
"\n",
"functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n",
"- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n",
"- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n",
"- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n",
"\n",
"Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns to you an [FX graph](https://pytorch.org/docs/stable/fx.html) (that optionally contains a backward pass), of which compilation via various backends is one path you can take.\n",
"\n",
"\n",
"## Why composable function transforms?\n",
@@ -18,21 +26,38 @@
"- efficiently computing Jacobians and Hessians\n",
"- efficiently computing batched Jacobians and Hessians\n",
"\n",
"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).\n",
"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n",
"\n",
"## What are the transforms?\n",
"\n",
"Right now, we support the following transforms:\n",
"### grad (gradient computation)\n",
"\n",
"- `grad`, `vjp`, `jvp`,\n",
"- `jacrev`, `jacfwd`, `hessian`\n",
"- `vmap`\n",
"`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. to the first input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f920b923",
"metadata": {},
"outputs": [],
"source": [
"from functorch import grad\n",
"x = torch.randn([])\n",
"cos_x = grad(lambda x: torch.sin(x))(x)\n",
"assert torch.allclose(cos_x, x.cos())\n",
"\n",
"Furthermore, we have some utilities for working with PyTorch modules.\n",
"- `make_functional(model)`\n",
"- `make_functional_with_buffers(model)`\n",
"\n",
"### vmap\n",
"# Second-order gradients\n",
"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
"assert torch.allclose(neg_sin_x, -x.sin())"
]
},
{
"cell_type": "markdown",
"id": "ef3b2d85",
"metadata": {},
"source": [
"### vmap (auto-vectorization)\n",
"\n",
"Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
"\n",
@@ -43,8 +68,8 @@
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f920b923",
"execution_count": null,
"id": "6ebac649",
"metadata": {},
"outputs": [],
"source": [
@@ -62,44 +87,17 @@
"result = vmap(model)(examples)"
]
},
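Since most of the vmap model example above is elided by the hunk, here is a minimal illustrative sketch of the same idea on a plain function (`my_dot` is a made-up helper):

```python
import torch
from functorch import vmap

def my_dot(a, b):
    return torch.dot(a, b)  # torch.dot is defined only for 1-D tensors

a = torch.randn(32, 5)
b = torch.randn(32, 5)

# vmap maps my_dot over dim 0 of both inputs, so no manual batching is needed.
batched = vmap(my_dot)(a, b)
assert batched.shape == (32,)
```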
{
"cell_type": "markdown",
"id": "ef3b2d85",
"metadata": {},
"source": [
"### grad\n",
"\n",
"`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. By default, it computes the gradients of the output of `func` w.r.t. to `inputs[0]`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6ebac649",
"metadata": {},
"outputs": [],
"source": [
"from functorch import grad\n",
"x = torch.randn([])\n",
"cos_x = grad(lambda x: torch.sin(x))(x)\n",
"assert torch.allclose(cos_x, x.cos())\n",
"\n",
"# Second-order gradients\n",
"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
"assert torch.allclose(neg_sin_x, -x.sin())"
]
},
{
"cell_type": "markdown",
"id": "5161e6d2",
"metadata": {},
"source": [
"When composed with vmap, grad can be used to compute per-sample-gradients:"
"When composed with `grad`, `vmap` can be used to compute per-sample-gradients:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "ffb2fcb1",
"metadata": {},
"outputs": [],
@@ -128,14 +126,14 @@
"id": "11d711af",
"metadata": {},
"source": [
"### vjp\n",
"### vjp (vector-Jacobian product)\n",
"\n",
"The `vjp` transform applies `func` to `inputs` and returns a new function that computes vjps given some `cotangents` Tensors."
"The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors."
]
},
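The vjp code cell's body is elided by the hunk below; here is a minimal illustrative sketch of the `vjp` call pattern:

```python
import torch
from functorch import vjp

def f(x):
    return x.sin()

x = torch.randn(5)
out, vjp_fn = vjp(f, x)            # runs f(x) and returns a vjp function
(grad_x,) = vjp_fn(torch.ones(5))  # cotangent-times-Jacobian
assert torch.allclose(grad_x, x.cos())
```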
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "ad48f9d4",
"metadata": {},
"outputs": [],
@@ -154,14 +152,14 @@
"id": "e0221270",
"metadata": {},
"source": [
"### jvp\n",
"### jvp (Jacobian-vector product)\n",
"\n",
"The `jvp` transforms computes Jacobian-vector-products and is also known as \"forward-mode AD\". It is not a higher-order function unlike most other transforms, but it returns the outputs of `func(inputs)` as well as the jvps."
]
},
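The jvp code cell's body is likewise elided; here is a minimal illustrative sketch of the `jvp` call pattern:

```python
import torch
from functorch import jvp

def f(x):
    return x.sin()

x = torch.randn(5)
tangent = torch.ones(5)
# jvp returns f(x) and the Jacobian-vector product in a single call.
out, jvp_out = jvp(f, (x,), (tangent,))
assert torch.allclose(jvp_out, x.cos())
```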
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "f3772f43",
"metadata": {},
"outputs": [],
@@ -187,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "20f53be2",
"metadata": {},
"outputs": [],
@@ -209,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "97d6c382",
"metadata": {},
"outputs": [],
@@ -229,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "a8c1dedb",
"metadata": {},
"outputs": [],
@@ -251,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "1e511139",
"metadata": {},
"outputs": [],
@@ -274,7 +272,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "fd1765df",
"metadata": {},
"outputs": [],
@@ -315,7 +313,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
"version": "3.7.4"
}
},
"nbformat": 4,