[functorch] Update README, setup.py

functorch/.gitignore
@@ -2,3 +2,4 @@ build/
dist/
functorch.egg-info/
*__pycache__*
functorch/version.py

README.md
@@ -1 +1,163 @@
# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Transformations**](#what-are-the-transforms)
| [**Install guide**](#install)
| [**Future Plans**](#future-plans)

`functorch` is a prototype of [JAX-like](https://github.com/google/jax)
composable FUNCtion transforms for pyTORCH.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance. Because
this project requires some investment, we'd love to hear from and work with
early adopters to shape the design. Please reach out on the issue tracker
if you're interested in using this for your project.

## Why composable function transforms?

There are a number of use cases that are tricky to do in
PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jacrev`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules:
- `make_functional_with_buffers` (see the sketch below)
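
A minimal sketch of `make_functional_with_buffers`, assuming it splits a module
into a pure function plus that module's parameters and buffers, with the function
taking them as leading arguments (the prototype's exact signature may differ):

```py
>>> import torch
>>> import torch.nn as nn
>>> from functorch import make_functional_with_buffers
>>>
>>> model = nn.Linear(5, 3)
>>> # func is stateless: parameters and buffers are passed in explicitly
>>> func, params, buffers = make_functional_with_buffers(model)
>>> x = torch.randn(2, 5)
>>> out = func(params, buffers, x)  # equivalent to model(x)
>>> assert out.shape == (2, 3)
```
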
### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`, leading to a simpler modeling experience:

```py
>>> import torch
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = vmap(model)(examples)
```

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradients of the output of `func` with respect to `inputs[0]`.

```py
>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(torch.sin)(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(torch.sin))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
>>> from functorch import vmap, grad
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> # Map over the batch dimension of examples and targets, but share weights
>>> # (in_dims=None) so that grad differentiates w.r.t. weights per example.
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)
```

### vjp and jacrev

```py
>>> from functorch import vjp
>>> outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents)
```
The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `cotangents` Tensors.
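
As a concrete sketch: the vjp of `torch.sin` at `x`, applied to a cotangent `v`,
is `v * x.cos()`. This assumes `vjp_fn` returns one gradient per input, as the
signature above suggests:

```py
>>> from functorch import vjp
>>> x = torch.randn(5)
>>> v = torch.randn(5)  # the cotangent
>>> outputs, vjp_fn = vjp(torch.sin, x)
>>> vjps = vjp_fn(v)
>>> # the Jacobian of elementwise sin is diag(cos(x)), so the vjp is v * cos(x)
>>> assert torch.allclose(vjps[0], v * x.cos())
```
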
```py
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(x.cos())
>>> assert torch.allclose(jacobian, expected)
```
Use `jacrev` to compute the jacobian. This can be composed with `vmap` to produce
batched jacobians:

```py
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
```

`jacrev` can be composed with itself to produce hessians:
```py
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
```
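
Since `f(x) = x.sin().sum()` has gradient `x.cos()`, its Hessian is diagonal
with entries `-x.sin()`; the result above can be checked against that:

```py
>>> # the Hessian of sum(sin(x)) is diag(-sin(x))
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
```
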
## Install

### Binaries

Coming soon!

### From Source

`functorch` is a PyTorch C++ Extension module. To install,

- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
  Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
  are on the branch. TODO: we should recommend a commit hash that is known to be stable
- Run `python setup.py install`

Then, try running some tests to make sure all is OK:
```
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help -- please send us
your use cases by starting a conversation in the issue tracker or by trying out
the prototype.

setup.py
@@ -1,27 +1,62 @@
import distutils
import distutils.command.clean
import shutil
import glob
import os
import subprocess
from setuptools import setup, find_packages
from torch.utils.cpp_extension import (
    CppExtension,
    BuildExtension,
)

cwd = os.path.dirname(os.path.abspath(__file__))
version_txt = os.path.join(cwd, 'version.txt')
with open(version_txt, 'r') as f:
    version = f.readline().strip()

# class clean(distutils.command.clean.clean):
#     def run(self):
#         with open(".gitignore", "r") as f:
#             ignores = f.read()
#             for wildcard in filter(None, ignores.split("\n")):
#                 for filename in glob.glob(wildcard):
#                     try:
#                         os.remove(filename)
#                     except OSError:
#                         shutil.rmtree(filename, ignore_errors=True)
#
#         # It's an old-style class in Python 2.7...
#         distutils.command.clean.clean.run(self)

try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
except Exception:
    sha = 'Unknown'
package_name = 'functorch'

if os.getenv('BUILD_VERSION'):
    version = os.getenv('BUILD_VERSION')
elif sha != 'Unknown':
    version += '+' + sha[:7]


def write_version_file():
    version_path = os.path.join(cwd, 'functorch', 'version.py')
    with open(version_path, 'w') as f:
        f.write("__version__ = '{}'\n".format(version))
        f.write("git_version = {}\n".format(repr(sha)))

# TODO: is there a way to specify that either of the following is the requirement:
# 1. a pytorch nightly
# 2. a specific hash of PyTorch?
# pytorch_dep = 'torch'
# if os.getenv('PYTORCH_VERSION'):
#     pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')
#
# requirements = [
#     pytorch_dep,
# ]


class clean(distutils.command.clean.clean):
    def run(self):
        with open(".gitignore", "r") as f:
            ignores = f.read()
            for wildcard in filter(None, ignores.split("\n")):
                for filename in glob.glob(wildcard):
                    try:
                        os.remove(filename)
                    except OSError:
                        shutil.rmtree(filename, ignore_errors=True)

        # It's an old-style class in Python 2.7...
        distutils.command.clean.clean.run(self)


def get_extensions():

@@ -31,8 +66,9 @@ def get_extensions():
    extra_link_args = []
    extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
    if int(os.environ.get("DEBUG", 0)):
    # if True:
    debug_mode = os.getenv('DEBUG', '0') == '1'
    if debug_mode:
        print("Compiling in debug mode")
        extra_compile_args = {
            "cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
        extra_link_args = ["-O0", "-g"]

@@ -61,12 +97,24 @@ def get_extensions():
    return ext_modules


setup(
    name='functorch',
    url="https://github.com/zou3519/functorch",
    packages=find_packages(),
    ext_modules=get_extensions(),
    cmdclass={
        # "clean": clean,
        "build_ext": BuildExtension
    })


if __name__ == '__main__':
    print("Building wheel {}-{}".format(package_name, version))
    write_version_file()

    setup(
        # Metadata
        name=package_name,
        version=version,
        author='PyTorch Core Team',
        url="https://github.com/zou3519/functorch",
        description='prototype of composable function transforms for PyTorch',
        license='BSD',

        # Package info
        packages=find_packages(),
        # install_requires=requirements,
        ext_modules=get_extensions(),
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            'clean': clean,
        })