[functorch] Update readme, fix unsqueeze batch rule

Author: Richard Zou
Date: 2021-04-27 11:52:08 -07:00
Committed by: Jon Janzen
Parent: 7001a2f1e4
Commit: 8277d74e42
4 changed files with 71 additions and 29 deletions

View File

@@ -1,8 +1,8 @@
 # functorch
 [**Why functorch?**](#why-composable-function-transforms)
-| [**Transformations**](#what-are-the-transforms)
 | [**Install guide**](#install)
+| [**Transformations**](#what-are-the-transforms)
 | [**Future Plans**](#future-plans)

 `functorch` is a prototype of [JAX-like](https://github.com/google/jax)
@@ -28,6 +28,27 @@ Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
 without designing a separate subsystem for each. This idea of composable function
 transforms comes from the [JAX framework](https://github.com/google/jax).

+## Install
+
+### Binaries
+
+Coming soon!
+
+### From Source
+
+`functorch` is a PyTorch C++ Extension module. To install,
+
+- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
+  Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
+  are on the branch. TODO: we should recommend a commit hash that is known to be stable
+- Run `python setup.py install`. You can use `DEBUG=1` to compile in debug mode.
+
+Then, try to run some tests to make sure all is OK:
+```
+pytest test/test_vmap.py -v
+pytest test/test_eager_transforms.py -v
+```
+
 ## What are the transforms?

 Right now, we support the following transforms:
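
Beyond the pytest commands added above, a quick way to confirm the build works end to end is to compose a couple of transforms directly. A minimal smoke-test sketch (not part of this commit; the shapes and loss function are illustrative and assume the install steps above succeeded):

```py
# Illustrative post-install smoke test; assumes `functorch` built and installed
# as described above. Not part of this commit.
import torch
from functorch import vmap, grad

def loss(w):
    # simple scalar-valued function so `grad` returns a tensor shaped like `w`
    return (w * w).sum()

weights = torch.randn(4, 3)
per_sample_grads = vmap(grad(loss))(weights)  # gradient of each row, independently
assert per_sample_grads.shape == (4, 3)
```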
@@ -81,7 +102,7 @@ the gradients of the output of func w.r.t. `inputs[0]`.
 ```

 When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
-```
+```py
 >>> from functorch import vmap
 >>> batch_size, feature_size = 3, 5
 >>> weights = torch.randn(feature_size, requires_grad=True)
@@ -109,7 +130,7 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
 The `vjp` transform applies `func` to `inputs` and returns a new function that
 computes vjps given some `cotangents` Tensors.
-```
+```py
 >>> from functorch import jacrev
 >>> x = torch.randn(5)
 >>> jacobian = jacrev(torch.sin)(x)
@@ -119,14 +140,14 @@ computes vjps given some `cotangents` Tensors.
 Use `jacrev` to compute the jacobian. This can be composed with vmap to produce
 batched jacobians:
-```
+```py
 >>> x = torch.randn(64, 5)
 >>> jacobian = vmap(jacrev(torch.sin))(x)
 >>> assert jacobian.shape == (64, 5, 5)
 ```

 `jacrev` can be composed with itself to produce hessians:
-```
+```py
 >>> def f(x):
 >>>   return x.sin().sum()
 >>>
@@ -134,27 +155,6 @@ batched jacobians:
 >>> hessian = jacrev(jacrev(f))(x)
 ```

-## Install
-
-### Binaries
-
-Coming soon!
-
-### From Source
-
-`functorch` is a PyTorch C++ Extension module. To install,
-
-- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
-  Be sure to make sure the changes from https://github.com/pytorch/pytorch/pull/56824
-  are on the branch. TODO: we should recommend a commit hash that is known to be stable
-- Run `python setup.py install`
-
-Then, try to run some tests to make sure all is OK:
-```
-pytest test/test_vmap.py -v
-pytest test/test_eager_transforms.py -v
-```
 ## Future Plans

 In the end state, we'd like to upstream this into PyTorch once we iron out the

View File

@@ -25,9 +25,14 @@ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val)
 }

 int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
+  // NB: assumes the batch dim is at the front of the tensor
   optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
   auto rank = rankWithoutBatchDim(tensor, bdim);
-  return maybe_wrap_dim(rank, logical_dim) + 1;
+  auto wrapped_dim = maybe_wrap_dim(rank, logical_dim);
+  if (has_batch_dim) {
+    return wrapped_dim + 1;
+  }
+  return wrapped_dim;
 }

 }}
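
For readers skimming the C++: the hunk above makes `getPhysicalDim` apply the +1 offset for the leading batch dimension only when a batch dimension is actually present, instead of unconditionally. A rough Python model of that bookkeeping (the helper below is hypothetical and simplifies the dim wrapping; it is not the functorch C++ API):

```py
# Hypothetical model of the logical->physical dim mapping fixed above.
# Physical layout is [batch dim (optional), logical dims...], so logical dims
# shift right by one only when a batch dim sits at the front.
def get_physical_dim(logical_rank, has_batch_dim, logical_dim):
    # wrap a possibly-negative logical dim into [0, logical_rank)
    wrapped = logical_dim + logical_rank if logical_dim < 0 else logical_dim
    # the old code added 1 unconditionally; only shift when batched
    return wrapped + 1 if has_batch_dim else wrapped

assert get_physical_dim(2, True, -1) == 2   # physical (B, 2, 5): last logical dim is physical dim 2
assert get_physical_dim(2, False, -1) == 1  # physical (2, 5), no batch dim: stays at dim 1
```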

View File

@@ -88,8 +88,11 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
     int64_t dim) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto rank = rankWithoutBatchDim(self, self_bdim);
-  dim = maybe_wrap_dim(dim, rank + 1) + 1;
-  return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
+  dim = maybe_wrap_dim(dim, rank + 1);
+  if (self_bdim) {
+    dim += 1;
+  }
+  return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
 }

 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
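
In user-facing terms, the rule above wraps `dim` against the per-example rank and shifts it past the batch dimension only when one is being tracked, and it now unsqueezes the batch-dim-at-front tensor `self_` rather than `self`. A quick behavioral check (illustrative, not part of the commit; assumes the functorch build described in the README):

```py
# Behavior the unsqueeze batch rule above is meant to provide (illustrative).
import torch
from functorch import vmap

x = torch.rand(7, 2, 5)  # batch of 7 examples, each of logical shape (2, 5)

# unsqueeze the last logical dim: each (2, 5) example becomes (2, 5, 1)
out_last = vmap(lambda t: torch.unsqueeze(t, -1))(x)
assert out_last.shape == (7, 2, 5, 1)

# unsqueeze logical dim 0: the new axis lands just after the batch dim
out_front = vmap(lambda t: torch.unsqueeze(t, 0))(x)
assert out_front.shape == (7, 1, 2, 5)
```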

View File

@@ -1667,6 +1667,40 @@ class TestVmapOperators(Namespace.TestVmapBase):
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)

+    def test_unsqueeze(self):
+        op = torch.unsqueeze
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+
+        # unsqueeze dim 0
+        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
+
+        # unsqueeze last dim (positive)
+        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
+
+        # unsqueeze last dim (negative)
+        test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))
+
+        # nested vmaps
+        def unsqueeze_0(x):
+            return torch.unsqueeze(x, 0)
+
+        def unsqueeze_last(x):
+            return torch.unsqueeze(x, -1)
+
+        # bdims in canonical order
+        test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2),))
+        test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))
+
+        # wild bdims
+        test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
+        test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
+        test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
+        test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
+
     def test_movedim(self):
         op = torch.movedim
         test = self._vmap_view_test