From 7001a2f1e4d8382b4c255342b4b53eae9389f8b1 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Tue, 27 Apr 2021 11:10:33 -0700
Subject: [PATCH] [functorch] Update README, setup.py

---
 functorch/.gitignore |   1 +
 functorch/README.md  | 164 ++++++++++++++++++++++++++++++++++++++++++-
 functorch/setup.py   |  96 ++++++++++++++++++-------
 3 files changed, 236 insertions(+), 25 deletions(-)

diff --git a/functorch/.gitignore b/functorch/.gitignore
index e38e12de85ed..9f5c3b244949 100644
--- a/functorch/.gitignore
+++ b/functorch/.gitignore
@@ -2,3 +2,4 @@ build/
 dist/
 functorch.egg-info/
 *__pycache__*
+functorch/version.py

diff --git a/functorch/README.md b/functorch/README.md
index 165de8c7430f..11a1a5f7b741 100644
--- a/functorch/README.md
+++ b/functorch/README.md
@@ -1 +1,163 @@
-# functorch
\ No newline at end of file
+# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Transformations**](#what-are-the-transforms)
| [**Install guide**](#install)
| [**Future Plans**](#future-plans)

`functorch` is a prototype of [JAX-like](https://github.com/google/jax)
composable FUNCtion transforms for pyTORCH.

It aims to provide composable `vmap` and `grad` transforms that work with
PyTorch modules and PyTorch autograd with good eager-mode performance. Because
this project requires some investment, we'd love to hear from and work with
early adopters to shape the design. Please reach out on the issue tracker
if you're interested in using this for your project.

## Why composable function transforms?

There are a number of use cases that are tricky to do in
PyTorch today:
- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## What are the transforms?

Right now, we support the following transforms:
- `grad`, `vjp`, `jacrev`
- `vmap`

Furthermore, we have some utilities for working with PyTorch modules:
- `make_functional_with_buffers` (a usage sketch appears at the end of this README)

### vmap

Note: `vmap` imposes restrictions on the code that it can be used on.
For more details, please read its docstring.

`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor
operations in `func`. `vmap(func)` returns a new function that maps `func` over
some dimension (default: 0) of each Tensor in `inputs`.

`vmap` is useful for hiding batch dimensions: one can write a function `func`
that runs on examples and then lift it to a function that can take batches of
examples with `vmap(func)`, leading to a simpler modeling experience:

```py
>>> import torch
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = vmap(model)(examples)
```

### grad

`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes
the gradient of the output of `func` w.r.t. `inputs[0]`.
```py
>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(torch.sin)(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(torch.sin))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```py
>>> from functorch import vmap, grad
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> # Gradients are taken w.r.t. weights (inputs[0]); only examples and targets are batched.
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)
```

### vjp and jacrev

```py
>>> from functorch import vjp
>>> outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents)
```
The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `cotangents` Tensors.

```py
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(x.cos())
>>> assert torch.allclose(jacobian, expected)
```
Use `jacrev` to compute the Jacobian. This can be composed with `vmap` to produce
batched Jacobians:

```py
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
```

`jacrev` can be composed with itself to produce Hessians:
```py
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
```

## Install

### Binaries

Coming soon!

### From Source

`functorch` is a PyTorch C++ Extension module. To install it:

- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
are on the branch you build. TODO: we should recommend a commit hash that is known to be stable
- Run `python setup.py install`

Then, run some tests to make sure everything is OK:
```
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
design details. To figure out the details, we need your help -- please send us
your use cases by starting a conversation on the issue tracker or by trying out
the prototype.
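The module utility listed above, `make_functional_with_buffers`, does not have an
example in this README yet. Below is a minimal sketch of how it is intended to compose
with `vmap`. It assumes the utility returns a stateless function together with the
module's parameters and buffers, called as `func(params, buffers, *inputs)`; the exact
signature in this prototype may differ, so please check its docstring.

```py
>>> import torch
>>> import torch.nn as nn
>>> from functorch import make_functional_with_buffers, vmap
>>>
>>> model = nn.Linear(5, 1)
>>> # Assumed convention: a pure function plus the module's params and buffers.
>>> func, params, buffers = make_functional_with_buffers(model)
>>>
>>> examples = torch.randn(3, 5)
>>> # Share the same params/buffers across the batch; map only over the examples.
>>> out = vmap(func, in_dims=(None, None, 0))(params, buffers, examples)
>>> assert out.shape == (3, 1)
```

Because the parameters become explicit inputs instead of module state, the same pattern
composes with `grad` to compute per-sample gradients of an `nn.Module`, mirroring the
`compute_loss` example above.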
diff --git a/functorch/setup.py b/functorch/setup.py
index 443ef4b027ce..7bb58cdb90bd 100644
--- a/functorch/setup.py
+++ b/functorch/setup.py
@@ -1,27 +1,62 @@
-import distutils
+import distutils.command.clean
 import shutil
 import glob
 import os
+import subprocess

 from setuptools import setup, find_packages
 from torch.utils.cpp_extension import (
     CppExtension,
     BuildExtension,
 )

+cwd = os.path.dirname(os.path.abspath(__file__))
+version_txt = os.path.join(cwd, 'version.txt')
+with open(version_txt, 'r') as f:
+    version = f.readline().strip()

-# class clean(distutils.command.clean.clean):
-#     def run(self):
-#         with open(".gitignore", "r") as f:
-#             ignores = f.read()
-#             for wildcard in filter(None, ignores.split("\n")):
-#                 for filename in glob.glob(wildcard):
-#                     try:
-#                         os.remove(filename)
-#                     except OSError:
-#                         shutil.rmtree(filename, ignore_errors=True)
+try:
+    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
+except Exception:
+    sha = 'Unknown'
+package_name = 'functorch'
+
+if os.getenv('BUILD_VERSION'):
+    version = os.getenv('BUILD_VERSION')
+elif sha != 'Unknown':
+    version += '+' + sha[:7]
+
+
+def write_version_file():
+    version_path = os.path.join(cwd, 'functorch', 'version.py')
+    with open(version_path, 'w') as f:
+        f.write("__version__ = '{}'\n".format(version))
+        f.write("git_version = {}\n".format(repr(sha)))
+
+# TODO: is there a way to specify that either of the following is the requirement:
+# 1. a pytorch nightly
+# 2. a specific hash of PyTorch?
+# pytorch_dep = 'torch'
+# if os.getenv('PYTORCH_VERSION'):
+#     pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')
 #
-#         # It's an old-style class in Python 2.7...
-#         distutils.command.clean.clean.run(self)
+# requirements = [
+#     pytorch_dep,
+# ]
+
+
+class clean(distutils.command.clean.clean):
+    def run(self):
+        with open(".gitignore", "r") as f:
+            ignores = f.read()
+            for wildcard in filter(None, ignores.split("\n")):
+                for filename in glob.glob(wildcard):
+                    try:
+                        os.remove(filename)
+                    except OSError:
+                        shutil.rmtree(filename, ignore_errors=True)
+
+        # It's an old-style class in Python 2.7...
+        distutils.command.clean.clean.run(self)


 def get_extensions():
@@ -31,8 +66,9 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}

-    if int(os.environ.get("DEBUG", 0)):
-        # if True:
+    debug_mode = os.getenv('DEBUG', '0') == '1'
+    if debug_mode:
+        print("Compiling in debug mode")
         extra_compile_args = {
             "cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
         extra_link_args = ["-O0", "-g"]
@@ -61,12 +97,24 @@ def get_extensions():
     return ext_modules


-setup(
-    name='functorch',
-    url="https://github.com/zou3519/functorch",
-    packages=find_packages(),
-    ext_modules=get_extensions(),
-    cmdclass={
-        # "clean": clean,
-        "build_ext": BuildExtension
-    })
+if __name__ == '__main__':
+    print("Building wheel {}-{}".format(package_name, version))
+    write_version_file()
+
+    setup(
+        # Metadata
+        name=package_name,
+        version=version,
+        author='PyTorch Core Team',
+        url="https://github.com/zou3519/functorch",
+        description='prototype of composable function transforms for PyTorch',
+        license='BSD',
+
+        # Package info
+        packages=find_packages(),
+        # install_requires=requirements,
+        ext_modules=get_extensions(),
+        cmdclass={
+            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
+            'clean': clean,
+        })