:github_url: https://github.com/pytorch/functorch

functorch
===================================

functorch provides `JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

It aims to provide composable `vmap` and `grad` transforms that work with PyTorch modules
and PyTorch autograd with good eager-mode performance.

**This library is currently under heavy development; if you have suggestions on the API or use cases you'd like covered, please open a GitHub issue or reach out. We'd love to hear how you're using the library.**

Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample gradients (or other per-sample quantities); see the sketch after this list
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

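As a minimal sketch of the first item, per-sample gradients fall out of composing `grad` with `vmap`. The toy linear model, tensor shapes, and names below are illustrative, not part of the library:

.. code-block:: python

    import torch
    from functorch import grad, vmap

    # A toy per-sample loss: weights are shared; x and y are one sample.
    def loss_fn(weights, x, y):
        pred = x @ weights
        return ((pred - y) ** 2).mean()

    weights = torch.randn(3)
    xs = torch.randn(8, 3)  # a batch of 8 samples
    ys = torch.randn(8)

    # grad(loss_fn) differentiates with respect to the first argument;
    # vmap maps the resulting function over the batch dimension of xs and ys.
    per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weights, xs, ys)
    print(per_sample_grads.shape)  # torch.Size([8, 3])
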
Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the `JAX framework <https://github.com/google/jax>`_.

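To make the composition concrete, here is a short sketch (the function `f` is a made-up example) that builds a Hessian by composing the Jacobian transforms, and batched Hessians by layering `vmap` on top:

.. code-block:: python

    import torch
    from functorch import jacfwd, jacrev, vmap

    def f(x):
        return (x ** 3).sum()

    x = torch.randn(4)

    # The Jacobian of the gradient is the Hessian: compose reverse-mode
    # and forward-mode Jacobian transforms.
    hess = jacfwd(jacrev(f))(x)
    print(hess.shape)  # torch.Size([4, 4])

    # Batched Hessians: vmap the composed transform over a batch of inputs.
    xs = torch.randn(8, 4)
    batched_hess = vmap(jacfwd(jacrev(f)))(xs)
    print(batched_hess.shape)  # torch.Size([8, 4, 4])
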
Read More
---------

For a whirlwind tour of how to use the transforms, please check out `this section in our README <https://github.com/pytorch/functorch/blob/main/README.md#what-are-the-transforms>`_. For installation instructions or the API reference, please check below.

.. toctree::
   :maxdepth: 1

   Install <install>
   functorch API Reference <functorch>