pytorch/functorch
Klaus Zimmermann f16053f0c9 Switch to standard pep517 sdist generation (#152098)
Generate the source tarball with PEP 517-conformant build tools instead of the custom routine currently in place.

Closes #150461.

The current procedure for generating the source tarball consists of creating a source tree by manually copying and pruning source files.

This PR replaces that with a call to the standard [build tool](https://build.pypa.io/en/stable/), which works with the build backend to produce an sdist. For that to work correctly, the build backend also needs to be configured. In the case of PyTorch, the backend is currently (the legacy version of) the setuptools backend, whose sdist behavior is configured mostly via the `MANIFEST.in` file.

The resulting source distribution can be used to install directly from source with `pip install ./torch-{version}.tar.gz` or to build wheels directly from source with `pip wheel ./torch-{version}.tar.gz`; both should be considered experimental for now.

## Issues

### sdist name
According to PEP 517, the name of the source distribution file must coincide with the project name; [more precisely](https://peps.python.org/pep-0517/#source-distributions), the source distribution of a project that generates `{NAME}-{...}.whl` wheels is required to be named `{NAME}-{...}.tar.gz`. Currently, the source tarball is called `pytorch-{...}.tar.gz`, but the generated wheels and Python package are called `torch-{...}`.
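To illustrate the naming rule, the following minimal sketch compares the project name in an sdist filename with the one in a wheel filename. It assumes that both filenames already follow the normalized naming conventions of the packaging specifications. The helper names are hypothetical, and the version and platform tags in the usage lines are illustrative only.

```python
import re

def normalize_name(name: str) -> str:
    # Runs of "-", "_" and "." are treated as equivalent, and case is ignored.
    return re.sub(r"[-_.]+", "_", name).lower()

def sdist_project_name(filename: str) -> str:
    # "{name}-{version}.tar.gz": a normalized version contains no "-",
    # so everything before the last "-" is the name.
    stem = filename[: -len(".tar.gz")]
    return stem.rsplit("-", 1)[0]

def wheel_project_name(filename: str) -> str:
    # "{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl":
    # the name is the first "-"-separated field.
    return filename.split("-", 1)[0]

def sdist_matches_wheel(sdist: str, wheel: str) -> bool:
    return normalize_name(sdist_project_name(sdist)) == normalize_name(wheel_project_name(wheel))

wheel = "torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl"  # illustrative filename
print(sdist_matches_wheel("torch-2.8.0.tar.gz", wheel))    # True
print(sdist_matches_wheel("pytorch-2.8.0.tar.gz", wheel))  # False
```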

### Symbolic Links
The source tree currently contains a small number of symbolic links. This [has been seen as problematic](https://github.com/pypa/pip/issues/5919), largely because of lack of support on Windows, but also because of [a problem in setuptools](https://github.com/pypa/setuptools/issues/4937). Particularly unfortunate is a circular symlink in the third-party `ittapi` module, which cannot be resolved by replacing it with a copy.

PEP 721 (now integrated in the [Source Distribution Format Specification](https://packaging.python.org/en/latest/specifications/source-distribution-format/#source-distribution-archive-features)) allows for symbolic links, but only if they don't point outside the destination directory and if they don't contain `../` in their target.
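The two conditions can be expressed as a small check on the link target. This is a minimal sketch with a hypothetical helper `symlink_allowed`, not a tool used in this PR. It rejects absolute targets and any target that contains a `..` component. Together, these two conditions guarantee that a relative link cannot leave the destination directory. The example pairs are taken from the table of symbolic links in this description.

```python
from pathlib import PurePosixPath

def symlink_allowed(target: str) -> bool:
    """Check a symlink target against the sdist rules: no escape, no '..'."""
    path = PurePosixPath(target)
    # An absolute target always points outside the unpacked archive.
    if path.is_absolute():
        return False
    # A relative target without any ".." component resolves below the link's
    # own directory, so it cannot escape the destination directory either.
    return ".." not in path.parts

# (link, target) pairs from the symlink table.
examples = {
    ".dockerignore": ".gitignore",
    "third_party/XNNPACK/tools/xngen": "xngen.py",
    "third_party/opentelemetry-cpp/buildscripts/pre-merge-commit": "./pre-commit",
    "docs/requirements.txt": "../.ci/docker/requirements-docs.txt",
    "third_party/ittapi/rust/ittapi-sys/c-library": "../../",
}
for link, target in examples.items():
    print(f"{link} -> {target}: {'ok' if symlink_allowed(target) else 'rejected'}")
```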

The list of symbolic links currently is as follows:

<details>

|source|target|problem|solution|
|-|-|-|-|
| `.dockerignore` | `.gitignore` |  ok (individual file) ||
| `docs/requirements.txt` | `../.ci/docker/requirements-docs.txt` |`..` in target|swap source and target[^1]|
| `functorch/docs/source/notebooks` | `../../notebooks/` |`..` in target|swap source and target[^1]|
| `.github/ci_commit_pins/triton.txt` | `../../.ci/docker/ci_commit_pins/triton.txt` |  ok (omitted from sdist)||
| `third_party/flatbuffers/docs/source/CONTRIBUTING.md` | `../../CONTRIBUTING.md` |`..` in target|omit from sdist[^2]|
| `third_party/flatbuffers/java/src/test/java/DictionaryLookup` | `../../../../tests/DictionaryLookup` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/MyGame` | `../../../../tests/MyGame` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/NamespaceA` | `../../../../tests/namespace_test/NamespaceA` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/NamespaceC` | `../../../../tests/namespace_test/NamespaceC` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/optional_scalars` | `../../../../tests/optional_scalars` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/java/src/test/java/union_vector` | `../../../../tests/union_vector` |`..` in target|omit from sdist[^3]|
| `third_party/flatbuffers/kotlin/benchmark/src/jvmMain/java` | `../../../../java/src/main/java` |`..` in target|omit from sdist[^3]|
| `third_party/ittapi/rust/ittapi-sys/c-library` | `../../` |`..` in target|omit from sdist[^4]|
| `third_party/ittapi/rust/ittapi-sys/LICENSES` | `../../LICENSES` |`..` in target|omit from sdist[^4]|
| `third_party/opentelemetry-cpp/buildscripts/pre-merge-commit` | `./pre-commit` | ok (individual file)||
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-cmake/sample_client.cc` | `../../push/tests/integration/sample_client.cc` |`..` in target|omit from sdist[^5]|
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-cmake/sample_server.cc` | `../../pull/tests/integration/sample_server.cc` |`..` in target|omit from sdist[^5]|
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-pkgconfig/sample_client.cc` | `../../push/tests/integration/sample_client.cc` |`..` in target|omit from sdist[^5]|
| `third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-pkgconfig/sample_server.cc` | `../../pull/tests/integration/sample_server.cc` |`..` in target|omit from sdist[^5]|
| `third_party/XNNPACK/tools/xngen` | `xngen.py` |  ok (individual file)||

</details>

Introducing symbolic links inside the `.ci/docker` folder creates a new problem, however: Docker's `COPY` instruction cannot follow symlinks that point outside the build context. We work around this by using `tar ch` to dereference the symlinks before handing the context over to `docker build`.
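Python's standard `tarfile` module can achieve the same effect as `tar ch` through its `dereference` option. The following is a minimal sketch with a hypothetical helper name, not the script used in this PR.

```python
import io
import os
import tarfile

def dereferenced_context(context_dir: str) -> bytes:
    """Pack context_dir into an in-memory tar archive. For each symlink, store
    the contents of the file it points to instead of the link (like `tar ch`)."""
    buf = io.BytesIO()
    # dereference=True makes tarfile follow symlinks when it stats and reads entries.
    with tarfile.open(fileobj=buf, mode="w", dereference=True) as tar:
        for entry in sorted(os.listdir(context_dir)):
            # Put each top-level entry at the archive root; directories recurse.
            tar.add(os.path.join(context_dir, entry), arcname=entry)
    return buf.getvalue()
```

The resulting archive could then be passed to `docker build -` on standard input, which accepts a tar archive as the build context.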

[^1]: These resources can be naturally considered to be part of the docs, so moving the actual files into the place of the current symlinks and replacing them with (unproblematic) symlinks can be said to improve semantics as well.

[^2]: The flatbuffers docs already use the original file rather than the symlink, and starting with flatbuffers-25.1.21 the symlink has been replaced by the actual file as part of a documentation overhaul.

[^3]: These resources are flatbuffers tests for Java and Kotlin and can be omitted from our sdist.

[^4]: We don't need to ship the Rust bindings for ittapi.

[^5]: These are demonstration examples of how to link against prometheus-cpp using CMake and can be omitted.

### NCCL
NCCL used to be included as a submodule. However, with #146073 (first released in v2.7.0-rc1), the submodule was removed. It was replaced with a build-time checkout procedure in `tools/build_pytorch_libs.py`, which checks out the required version of NCCL from the upstream repository based on a commit pin recorded in `.ci/docker/ci_commit_pins/nccl-cu{11,12}.txt`.
This means that a crucial third-party dependency is missing from the source distribution. Because the `.ci` folder is also omitted from the source distribution, the build-time download cannot be used either.
However, it *is* possible to use a system-provided NCCL by setting the `USE_SYSTEM_NCCL` environment variable, which is now also the default for the official PyTorch wheels.
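The following is a minimal sketch of building a wheel from the sdist against the system NCCL. It assumes a CUDA build environment with NCCL already installed. `sdist_wheel_command` is a hypothetical helper that only assembles the command line and environment, and the file name is a placeholder.

```python
import os
import sys

def sdist_wheel_command(sdist_path: str, use_system_nccl: bool = True):
    """Return (argv, env) for building a wheel from a torch sdist with pip."""
    env = dict(os.environ)
    if use_system_nccl:
        # The sdist lacks .ci/, so the build-time NCCL checkout is unavailable;
        # ask the build to use a system-provided NCCL instead.
        env["USE_SYSTEM_NCCL"] = "1"
    argv = [sys.executable, "-m", "pip", "wheel", sdist_path]
    return argv, env

argv, env = sdist_wheel_command("torch-X.Y.Z.tar.gz")
# To actually build: subprocess.run(argv, env=env, check=True)
```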

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152098
Approved by: https://github.com/atalman
2025-06-30 19:07:34 +00:00

functorch

Why functorch? | Install guide | Transformations | Documentation | Future Plans

This library is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.

functorch provides JAX-like composable function transforms for PyTorch.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

Composing vmap, grad, vjp, and jvp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.

Install

There are two ways to install functorch:

  1. functorch from source
  2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

Installing functorch from source


Using Colab

Follow the instructions in this Colab notebook

Locally

As of 9/21/2022, functorch comes installed alongside a nightly PyTorch binary. Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/ for instructions.

Once you've done that, run a quick sanity check in Python:

```python
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

functorch development setup

As of 9/21/2022, functorch comes installed alongside PyTorch and is in the PyTorch source tree. Install PyTorch from source; you will then be able to import functorch.

Run some tests to make sure everything works:

```bash
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

AOTAutograd has some additional optional requirements. You can install them via:

```bash
pip install networkx
```

To run functorch tests, please install our test dependencies (expecttest, pyyaml).

Installing functorch beta (compatible with recent PyTorch releases)


Using Colab

Follow the instructions here

pip

Prerequisite: Install PyTorch

```bash
pip install functorch
```

Finally, run a quick sanity check in Python:

```python
import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```

What are the transforms?

Right now, we support the following transforms:

  • grad, vjp, jvp,
  • jacrev, jacfwd, hessian
  • vmap

Furthermore, we have some utilities for working with PyTorch modules.

  • make_functional(model)
  • make_functional_with_buffers(model)

vmap

Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.

vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.

vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:

```python
import torch
from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
```
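The following sketch makes these semantics concrete. It reuses the toy linear model, with weights that do not require grad, and checks that vmap matches an explicit Python loop over examples. It also shows `in_dims` selecting a batch dimension other than the default 0.

```python
import torch
from functorch import vmap

feature_size = 5
weights = torch.randn(feature_size)

def model(feature_vec):
    # Written for a single example: a 1-D feature vector.
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(3, feature_size)

# Default in_dims=0: map over rows, equivalent to a Python loop over examples.
batched = vmap(model)(examples)
looped = torch.stack([model(e) for e in examples])

# in_dims=1: the same examples stored column-wise, shape (feature_size, 3).
batched_cols = vmap(model, in_dims=1)(examples.t())
```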

grad

grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradients of the output of func with respect to inputs[0].

```python
import torch
from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```
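By default, grad differentiates with respect to the first argument. Its `argnums` parameter selects a different argument or a tuple of arguments. The following small sketch uses an illustrative function whose gradients are known in closed form:

```python
import torch
from functorch import grad

def f(x, y):
    return (x * y).sum()

x = torch.randn(4)
y = torch.randn(4)

dx = grad(f)(x, y)                        # default argnums=0: d/dx sum(x*y) = y
dy = grad(f, argnums=1)(x, y)             # d/dy sum(x*y) = x
dx2, dy2 = grad(f, argnums=(0, 1))(x, y)  # tuple of argnums -> tuple of gradients
```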

When composed with vmap, grad can be used to compute per-sample-gradients:

```python
import torch
from functorch import vmap, grad
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```
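You can check that this computes per-sample gradients by comparing it with one ordinary `torch.autograd.grad` call per example. The following sketch redefines the toy model so that it runs on its own. The loop serves only as a reference and is much slower than vmap for large batches.

```python
import torch
from functorch import vmap, grad

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: one ordinary autograd call per example.
reference = []
for example, target in zip(examples, targets):
    w = weights.detach().requires_grad_()
    (g,) = torch.autograd.grad(compute_loss(w, example, target), w)
    reference.append(g)
reference = torch.stack(reference)
```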

vjp

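The vjp transform applies func to inputs and returns the outputs together with a new function that computes vector-Jacobian products (vjps) given some cotangent Tensors.

```python
import torch
from functorch import vjp
outputs, vjp_fn = vjp(torch.sin, torch.randn(5))  # vjp(func, *primals)
(vjps,) = vjp_fn(torch.ones(5))  # vjp_fn(*cotangents) returns one vjp per primal
```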
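For an elementwise function such as torch.sin, the Jacobian is diagonal, so the vjp with a cotangent v is simply v * cos(x). The following sketch checks this. It also compares the result against `torch.autograd.grad` with `grad_outputs`, which computes the same vector-Jacobian product using regular autograd.

```python
import torch
from functorch import vjp

x = torch.randn(5)
v = torch.randn(5)  # cotangent: same shape as the output of torch.sin

outputs, vjp_fn = vjp(torch.sin, x)
(vjp_x,) = vjp_fn(v)

# The Jacobian of sin is diag(cos(x)), so v^T J = v * cos(x).
expected = v * x.cos()

# The same quantity from torch.autograd, passing v as grad_outputs.
x_ref = x.clone().requires_grad_()
(ref,) = torch.autograd.grad(torch.sin(x_ref), x_ref, grad_outputs=v)
```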

jvp

The jvp transform computes Jacobian-vector products and is also known as "forward-mode AD". Unlike most other transforms, it is not a higher-order function. Instead, it directly returns the outputs of func(inputs) together with the jvps.

```python
import torch
from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
```
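Equivalently, the jvp output is the Jacobian multiplied by the tangent vector. The following sketch materializes the Jacobian with jacrev only to check that identity. In practice, the point of jvp is to avoid building the full Jacobian.

```python
import torch
from functorch import jvp, jacrev

x = torch.randn(5)
v = torch.randn(5)  # tangent: same shape as the input

out, jvp_out = jvp(torch.sin, (x,), (v,))

# Built here only to check the result: J @ v should equal the jvp.
J = jacrev(torch.sin)(x)
```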

jacrev, jacfwd, and hessian

The jacrev transform takes a function func and returns a new function that takes in x and returns the Jacobian of func with respect to x, computed using reverse-mode AD. For func = torch.sin, the Jacobian is diagonal:

```python
import torch
from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

jacrev can be composed with vmap to produce batched Jacobians:

```python
import torch
from functorch import vmap, jacrev
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```
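Each 5x5 slice of the batched result is the Jacobian of one row, diag(cos(x[i])). If jacrev is instead applied directly to the whole (64, 5) input, it treats the input as a single 320-element vector and also returns every cross-example entry, as the following sketch shows:

```python
import torch
from functorch import vmap, jacrev

x = torch.randn(64, 5)

# Per-example Jacobians: shape (64, 5, 5).
batched = vmap(jacrev(torch.sin))(x)

# Whole-batch Jacobian: output shape + input shape = (64, 5, 64, 5),
# almost entirely zeros because rows do not interact.
full = jacrev(torch.sin)(x)
```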

jacfwd is a drop-in replacement for jacrev that computes Jacobians using forward-mode AD:

```python
import torch
from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```

Composing jacrev with itself or with jacfwd produces Hessians:

```python
import torch
from functorch import jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```

hessian is a convenience function that combines jacfwd and jacrev:

```python
import torch
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```
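For f(x) = sum(sin(x)), the Hessian is diagonal with entries -sin(x_i), so all three formulations from this section can be checked against that closed form:

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
expected = torch.diag(-x.sin())  # second derivative of sin is -sin; no cross terms

hess = hessian(f)(x)
hess_fwd_rev = jacfwd(jacrev(f))(x)
hess_rev_rev = jacrev(jacrev(f))(x)
```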

Tracing through the transformations

We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).

```python
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)
```

The printed code looks like this:

```python
def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul
```
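make_fx returns a callable `torch.fx.GraphModule`, so the captured graph can be run again on new inputs. The following is a minimal sketch. To be safe, it assumes that the new input has the same shape as the one used for tracing.

```python
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

grad_f = make_fx(grad(f))(torch.randn(100))

# Replaying the captured graph on a fresh input of the same shape
# should give the gradient of sum(sin(x)), i.e. cos(x).
x_new = torch.randn(100)
replayed = grad_f(x_new)
```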

Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:

  • model ensembling, where all of your weights and buffers have an additional dimension
  • per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.

  • make_functional(model) returns a functional version of model and the model.parameters()
  • make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers().

Here's an example where we compute per-sample-gradients using an nn.Linear layer:

```python
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```
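The result is a tuple with one entry per parameter (weight, then bias), and each entry has a leading batch dimension of 64. The following sketch repeats the example so that it runs on its own. It then spot-checks the first sample against a regular backward pass through the original module.

```python
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)
func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

# Reference for sample 0: ordinary autograd through the original module.
model.zero_grad()
torch.mean((model(data[0]) - targets[0]) ** 2).backward()
```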

If you're making an ensemble of models, you may find combine_state_for_ensemble useful.
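The following is a minimal ensembling sketch, which assumes several models of the same architecture. combine_state_for_ensemble stacks their parameters and buffers along a new leading dimension. vmap then evaluates all models at once on a shared minibatch, using in_dims=None for the data. The model sizes here are arbitrary.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 4
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]
minibatch = torch.randn(8, 3)

# Stack every model's parameters (and buffers) along a new leading dimension.
fmodel, params, buffers = combine_state_for_ensemble(models)

# in_dims=(0, 0, None): map over the stacked weights, share the same minibatch.
predictions = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
# predictions has shape (num_models, 8, 2)
```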

Documentation

For more documentation, see our docs website.

Debugging

  • torch._C._functorch.dump_tensor: dumps the dispatch keys on the stack
  • torch._C._functorch._set_vmap_fallback_warning_enabled(False): disables the vmap fallback warning, if the warning spam bothers you

Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or trying our project out.

License

Functorch has a BSD-style license, as found in the LICENSE file.

Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

```bibtex
@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
```