From 8277d74e42beac26054fcd51264d92d0a40ec85f Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Tue, 27 Apr 2021 11:52:08 -0700
Subject: [PATCH] [functorch] Update readme, fix unsqueeze batch rule

---
 functorch/README.md                           | 52 +++++++++----------
 functorch/functorch/csrc/BatchRulesHelper.cpp |  7 ++-
 functorch/functorch/csrc/BatchRulesViews.cpp  |  7 ++-
 functorch/test/test_vmap.py                   | 34 ++++++++++++
 4 files changed, 71 insertions(+), 29 deletions(-)

diff --git a/functorch/README.md b/functorch/README.md
index 11a1a5f7b741..0cfc0f2952ef 100644
--- a/functorch/README.md
+++ b/functorch/README.md
@@ -1,8 +1,8 @@
 # functorch
 
 [**Why functorch?**](#why-composable-function-transforms)
-| [**Transformations**](#what-are-the-transforms) | [**Install guide**](#install)
+| [**Transformations**](#what-are-the-transforms) | [**Future Plans**](#future-plans)
 
 `functorch` is a prototype of [JAX-like](https://github.com/google/jax)
@@ -28,6 +28,27 @@ Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
 without designing a separate subsystem for each. This idea of composable
 function transforms comes from the [JAX framework](https://github.com/google/jax).
 
+## Install
+
+### Binaries
+
+Coming soon!
+
+### From Source
+
+`functorch` is a PyTorch C++ Extension module. To install,
+
+- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
+Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
+are on the branch. TODO: we should recommend a commit hash that is known to be stable
+- Run `python setup.py install`. You can use `DEBUG=1` to compile in debug mode.
+
+Then, try to run some tests to make sure all is OK:
+```
+pytest test/test_vmap.py -v
+pytest test/test_eager_transforms.py -v
+```
+
 ## What are the transforms?
 
 Right now, we support the following transforms:
@@ -81,7 +102,7 @@ the gradients of the output of func w.r.t. `inputs[0]`.
 ```
 
 When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
-```
+```py
 >>> from functorch import vmap
 >>> batch_size, feature_size = 3, 5
 >>> weights = torch.randn(feature_size, requires_grad=True)
@@ -109,7 +130,7 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
 
 The `vjp` transform applies `func` to `inputs` and returns a new function that
 computes vjps given some `cotangents` Tensors.
-```
+```py
 >>> from functorch import jacrev
 >>> x = torch.randn(5)
 >>> jacobian = jacrev(torch.sin)(x)
@@ -119,14 +140,14 @@ computes vjps given some `cotangents` Tensors.
 
 Use `jacrev` to compute the jacobian. This can be composed with vmap to produce
 batched jacobians:
-```
+```py
 >>> x = torch.randn(64, 5)
 >>> jacobian = vmap(jacrev(torch.sin))(x)
 >>> assert jacobian.shape == (64, 5, 5)
 ```
 
 `jacrev` can be composed with itself to produce hessians:
-```
+```py
 >>> def f(x):
 >>>   return x.sin().sum()
 >>>
 >>> x = torch.randn(5)
 >>> hessian = jacrev(jacrev(f))(x)
 ```
 
-## Install
-
-### Binaries
-
-Coming soon!
-
-### From Source
-
-`functorch` is a PyTorch C++ Extension module. To install,
-
-- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
-Be sure to make sure the changes from https://github.com/pytorch/pytorch/pull/56824
-are on the branch. TODO: we should recommend a commit hash that is known to be stable
-- Run `python setup.py install`
-
-Then, try to run some tests to make sure all is OK:
-```
-pytest test/test_vmap.py -v
-pytest test/test_eager_transforms.py -v
-```
-
 ## Future Plans
 
 In the end state, we'd like to upstream this into PyTorch once we iron out the
diff --git a/functorch/functorch/csrc/BatchRulesHelper.cpp b/functorch/functorch/csrc/BatchRulesHelper.cpp
index 4719aa9b0386..a67d8f33b7ec 100644
--- a/functorch/functorch/csrc/BatchRulesHelper.cpp
+++ b/functorch/functorch/csrc/BatchRulesHelper.cpp
@@ -25,9 +25,14 @@ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val)
 }
 
 int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
+  // NB: assumes the batch dim is at the front of the tensor
   optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
   auto rank = rankWithoutBatchDim(tensor, bdim);
-  return maybe_wrap_dim(rank, logical_dim) + 1;
+  auto wrapped_dim = maybe_wrap_dim(rank, logical_dim);
+  if (has_batch_dim) {
+    return wrapped_dim + 1;
+  }
+  return wrapped_dim;
 }
 
 }}
diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index c242f261fda6..1145bc7eb88b 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -88,8 +88,11 @@ std::tuple<Tensor, optional<int64_t>> unsqueeze_batch_rule(
     int64_t dim) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto rank = rankWithoutBatchDim(self, self_bdim);
-  dim = maybe_wrap_dim(dim, rank + 1) + 1;
-  return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
+  dim = maybe_wrap_dim(dim, rank + 1);
+  if (self_bdim) {
+    dim += 1;
+  }
+  return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index e9fb2dd7e34e..eb3e99da67ae 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -1667,6 +1667,40 @@ class TestVmapOperators(Namespace.TestVmapBase):
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
 
+    def test_unsqueeze(self):
+        op = torch.unsqueeze
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+
+        # unsqueeze dim 0
+        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
+
+        # unsqueeze last dim (positive)
+        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
+
+        # unsqueeze last dim (negative)
+        test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))
+
+        # nested vmaps
+        def unsqueeze_0(x):
+            return torch.unsqueeze(x, 0)
+
+        def unsqueeze_last(x):
+            return torch.unsqueeze(x, -1)
+
+        # bdims in canonical order
+        test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2),))
+        test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))
+
+        # wild bdims
+        test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
+        test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
+        test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
+        test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
+
     def test_movedim(self):
         op = torch.movedim
         test = self._vmap_view_test
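The sketch below is an illustrative addendum, not part of the patch: it shows the user-visible behavior that the `unsqueeze_batch_rule` change and the new `test_unsqueeze` cases above are about. It assumes `functorch` has been built and installed as described in the README's "From Source" section.

```py
# Minimal sketch (not part of the patch); assumes functorch is importable.
import torch
from functorch import vmap

B0 = 7
x = torch.rand(B0, 2, 5)  # a batch of 7 examples, each of logical shape (2, 5)

# Unsqueezing dim 0 per example turns each (2, 5) into (1, 2, 5), so the
# batched result is (B0, 1, 2, 5). The batch rule keeps the batch dim at the
# front of the physical tensor and shifts the requested dim past it, which is
# what the new `if (self_bdim) { dim += 1; }` branch does.
out = vmap(lambda t: torch.unsqueeze(t, 0))(x)
assert out.shape == (B0, 1, 2, 5)

# Negative dims refer to each example's own dimensions (rank 2 here), not to
# the physical batched tensor, so dim=-1 adds a trailing dim per example.
out = vmap(lambda t: torch.unsqueeze(t, -1))(x)
assert out.shape == (B0, 2, 5, 1)
```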