[functorch] Update readme, fix unsqueeze batch rule

This commit is contained in:
Richard Zou
2021-04-27 11:52:08 -07:00
committed by Jon Janzen
parent 7001a2f1e4
commit 8277d74e42
4 changed files with 71 additions and 29 deletions

View File

@ -1,8 +1,8 @@
# functorch
[**Why functorch?**](#why-composable-function-transforms)
| [**Transformations**](#what-are-the-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Future Plans**](#future-plans)
`functorch` is a prototype of [JAX-like](https://github.com/google/jax)
@ -28,6 +28,27 @@ Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).
## Install
### Binaries
Coming soon!
### From Source
`functorch` is a PyTorch C++ Extension module. To install,
- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
Be sure to make sure the changes from https://github.com/pytorch/pytorch/pull/56824
are on the branch. TODO: we should recommend a commit hash that is known to be stable
- Run `python setup.py install`. You can use `DEBUG=1` to compile in debug mode.
Then, try to run some tests to make sure all is OK:
```
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```
## What are the transforms?
Right now, we support the following transforms:
@ -81,7 +102,7 @@ the gradients of the output of func w.r.t. to `inputs[0]`.
```
When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```
```py
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
@ -109,7 +130,7 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `contangents` Tensors.
```
```py
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
@ -119,14 +140,14 @@ computes vjps given some `contangents` Tensors.
Use `jacrev` to compute the jacobian. This can be composed with vmap to produce
batched jacobians:
```
```py
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
```
`jacrev` can be composed with itself to produce hessians:
```
```py
>>> def f(x):
>>> return x.sin().sum()
>>>
@ -134,27 +155,6 @@ batched jacobians:
>>> hessian = jacrev(jacrev(f))(x)
```
## Install
### Binaries
Coming soon!
### From Source
`functorch` is a PyTorch C++ Extension module. To install,
- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
Be sure to make sure the changes from https://github.com/pytorch/pytorch/pull/56824
are on the branch. TODO: we should recommend a commit hash that is known to be stable
- Run `python setup.py install`
Then, try to run some tests to make sure all is OK:
```
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```
## Future Plans
In the end state, we'd like to upstream this into PyTorch once we iron out the

View File

@ -25,9 +25,14 @@ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val)
}
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
// NB: assumes the batch dim is at the front of the tensor
optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
auto rank = rankWithoutBatchDim(tensor, bdim);
return maybe_wrap_dim(rank, logical_dim) + 1;
auto wrapped_dim = maybe_wrap_dim(rank, logical_dim);
if (has_batch_dim) {
return wrapped_dim + 1;
}
return wrapped_dim;
}
}}

View File

@ -88,8 +88,11 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
int64_t dim) {
auto self_ = moveBatchDimToFront(self, self_bdim);
auto rank = rankWithoutBatchDim(self, self_bdim);
dim = maybe_wrap_dim(dim, rank + 1) + 1;
return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
dim = maybe_wrap_dim(dim, rank + 1);
if (self_bdim) {
dim += 1;
}
return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {

View File

@ -1667,6 +1667,40 @@ class TestVmapOperators(Namespace.TestVmapBase):
with self.assertRaisesRegex(RuntimeError, msg):
vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
def test_unsqueeze(self):
op = torch.unsqueeze
test = self._vmap_view_test
B0, B1, B2 = 7, 11, 13
# unsqueeze dim 0
test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
# unsqueeze last dim (positive)
test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
# unsqueeze last dim (negative)
test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))
# nested vmaps
def unsqueeze_0(x):
return torch.unsqueeze(x, 0)
def unsqueeze_last(x):
return torch.unsqueeze(x, -1)
# bdims in canonical order
test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2), ))
test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))
# wild bdims
test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
def test_movedim(self):
op = torch.movedim
test = self._vmap_view_test