Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[functorch] Update readme, fix unsqueeze batch rule
@ -1,8 +1,8 @@
# functorch

[**Why functorch?**](#why-composable-function-transforms)
| [**Transformations**](#what-are-the-transforms)
| [**Install guide**](#install)
| [**Transformations**](#what-are-the-transforms)
| [**Future Plans**](#future-plans)

`functorch` is a prototype of [JAX-like](https://github.com/google/jax)
@ -28,6 +28,27 @@ Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
without designing a separate subsystem for each. This idea of composable function
transforms comes from the [JAX framework](https://github.com/google/jax).

## Install

### Binaries

Coming soon!

### From Source

`functorch` is a PyTorch C++ Extension module. To install:

- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
  Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
  are on your branch. TODO: we should recommend a commit hash that is known to be stable.
- Run `python setup.py install`. You can use `DEBUG=1` to compile in debug mode.

Then, try to run some tests to make sure all is OK:
```
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

## What are the transforms?

Right now, we support the following transforms:
@ -81,7 +102,7 @@ the gradients of the output of func w.r.t. `inputs[0]`.
```

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
```
```py
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
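The hunk cuts the example off after its first few lines. As a self-contained sketch of the same per-sample-gradient pattern (the `compute_loss` model and the data below are illustrative stand-ins, not code from this commit):

```py
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

def compute_loss(weights, example, target):
    # Toy linear model followed by a squared-error loss.
    y = example @ weights
    return ((y - target) ** 2).mean()

# grad differentiates w.r.t. the first argument (weights); vmap maps over the
# batch dimension of examples/targets while broadcasting the shared weights.
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)
assert per_sample_grads.shape == (batch_size, feature_size)
```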
@ -109,7 +130,7 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
The `vjp` transform applies `func` to `inputs` and returns a new function that
computes vjps given some `cotangents` Tensors.

```
```py
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
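The code visible in this hunk is the neighbouring `jacrev` example; the `vjp` call itself is not shown. A rough sketch of how `vjp` is used, assuming the `(output, vjp_fn)` convention that functorch exposes (treat the exact signature as an assumption, not something this diff confirms):

```py
import torch
from functorch import vjp

x = torch.randn(5)
# vjp applies torch.sin to x and returns the output together with a
# function mapping a cotangent tensor to the vector-Jacobian product.
out, vjp_fn = vjp(torch.sin, x)
cotangent = torch.ones_like(out)
(grad_x,) = vjp_fn(cotangent)
assert torch.allclose(grad_x, x.cos())
```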
@ -119,14 +140,14 @@ computes vjps given some `cotangents` Tensors.
Use `jacrev` to compute the jacobian. This can be composed with vmap to produce
batched jacobians:

```
```py
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
```

`jacrev` can be composed with itself to produce hessians:
```
```py
>>> def f(x):
>>>   return x.sin().sum()
>>>
@ -134,27 +155,6 @@ batched jacobians:
>>> hessian = jacrev(jacrev(f))(x)
```
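Pulled together from the two hunks above (the shape of `x` is not visible in the diff, so `torch.randn(5)` is assumed, and the final check is not in the README), the hessian example amounts to:

```py
import torch
from functorch import jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)  # assumed shape; the diff elides this line
hessian = jacrev(jacrev(f))(x)
# For f(x) = sum(sin(x)) the Hessian is diagonal with entries -sin(x_i).
assert torch.allclose(hessian, torch.diag(-x.sin()))
```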

## Install

### Binaries

Coming soon!

### From Source

`functorch` is a PyTorch C++ Extension module. To install:

- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
  Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
  are on your branch. TODO: we should recommend a commit hash that is known to be stable.
- Run `python setup.py install`

Then, try to run some tests to make sure all is OK:
```
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```

## Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the
@ -25,9 +25,14 @@ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val)
}

int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
  // NB: assumes the batch dim is at the front of the tensor
  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  return maybe_wrap_dim(logical_dim, rank) + 1;
  auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
  if (has_batch_dim) {
    return wrapped_dim + 1;
  }
  return wrapped_dim;
}

}}
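The change above makes `getPhysicalDim` add the +1 batch-dim offset only when the tensor actually carries a batch dimension at the front; previously the offset was applied unconditionally. A hypothetical Python rendering of the same dim arithmetic (not part of the commit, simplified to plain integers):

```py
def get_physical_dim(logical_dim: int, rank_without_bdim: int, has_batch_dim: bool) -> int:
    # Wrap a possibly-negative logical dim into [0, rank_without_bdim).
    wrapped = logical_dim % rank_without_bdim if logical_dim < 0 else logical_dim
    # Skip past the leading batch dim only when one is present.
    return wrapped + 1 if has_batch_dim else wrapped

# A tensor that vmap sees as (2, 5) but that is physically (B, 2, 5):
assert get_physical_dim(1, 2, has_batch_dim=True) == 2
assert get_physical_dim(-1, 2, has_batch_dim=True) == 2
# Without a batch dim, logical and physical dims coincide:
assert get_physical_dim(1, 2, has_batch_dim=False) == 1
```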
@ -88,8 +88,11 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
    int64_t dim) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, rank + 1) + 1;
  return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
  dim = maybe_wrap_dim(dim, rank + 1);
  if (self_bdim) {
    dim += 1;
  }
  return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
@ -1667,6 +1667,40 @@ class TestVmapOperators(Namespace.TestVmapBase):
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)

    def test_unsqueeze(self):
        op = torch.unsqueeze
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # unsqueeze dim 0
        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))

        # unsqueeze last dim (positive)
        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))

        # unsqueeze last dim (negative)
        test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
        test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))

        # nested vmaps
        def unsqueeze_0(x):
            return torch.unsqueeze(x, 0)

        def unsqueeze_last(x):
            return torch.unsqueeze(x, -1)

        # bdims in canonical order
        test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2),))
        test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))

        # wild bdims
        test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
        test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
        test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
        test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)

    def test_movedim(self):
        op = torch.movedim
        test = self._vmap_view_test