[functorch] Update readme, fix unsqueeze batch rule
@@ -1,8 +1,8 @@
 # functorch
 
 [**Why functorch?**](#why-composable-function-transforms)
-| [**Transformations**](#what-are-the-transforms)
 | [**Install guide**](#install)
+| [**Transformations**](#what-are-the-transforms)
 | [**Future Plans**](#future-plans)
 
 `functorch` is a prototype of [JAX-like](https://github.com/google/jax)
@@ -28,6 +28,27 @@ Composing `vmap`, `grad`, and `vjp` transforms allows us to express the above
 without designing a separate subsystem for each. This idea of composable function
 transforms comes from the [JAX framework](https://github.com/google/jax).
 
+## Install
+
+### Binaries
+
+Coming soon!
+
+### From Source
+
+`functorch` is a PyTorch C++ Extension module. To install,
+
+- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
+Make sure the changes from https://github.com/pytorch/pytorch/pull/56824
+are on the branch. TODO: we should recommend a commit hash that is known to be stable
+- Run `python setup.py install`. You can use `DEBUG=1` to compile in debug mode.
+
+Then, try to run some tests to make sure all is OK:
+```
+pytest test/test_vmap.py -v
+pytest test/test_eager_transforms.py -v
+```
+
 ## What are the transforms?
 
 Right now, we support the following transforms:
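Not part of the commit, but a quick sanity check on the install steps above: a minimal smoke test, assuming the build succeeded and that `functorch` exposes `vmap` and `grad` as the README describes.

```py
# Post-install smoke test (illustrative; not part of the diff).
import torch
from functorch import vmap, grad

# vmap: apply a per-example function across a leading batch dimension.
x, y = torch.randn(3, 5), torch.randn(3, 5)
assert vmap(torch.dot)(x, y).shape == (3,)

# grad: gradient of a scalar-valued function; d/dz sum(sin(z)) = cos(z).
z = torch.randn(5)
assert torch.allclose(grad(lambda t: t.sin().sum())(z), z.cos())
```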
@@ -81,7 +102,7 @@ the gradients of the output of func w.r.t. to `inputs[0]`.
 ```
 
 When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
-```
+```py
 >>> from functorch import vmap
 >>> batch_size, feature_size = 3, 5
 >>> weights = torch.randn(feature_size, requires_grad=True)
@@ -109,7 +130,7 @@ When composed with `vmap`, `grad` can be used to compute per-sample-gradients:
 The `vjp` transform applies `func` to `inputs` and returns a new function that
 computes vjps given some `cotangents` Tensors.
 
-```
+```py
 >>> from functorch import jacrev
 >>> x = torch.randn(5)
 >>> jacobian = jacrev(torch.sin)(x)
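The vjp prose in this hunk has no example of its own, so here is a hedged sketch, assuming `functorch.vjp` follows the `(outputs, vjp_fn)` convention:

```py
# Hedged sketch of vjp (assumes vjp(func, *inputs) returns (outputs, vjp_fn)).
import torch
from functorch import vjp

x = torch.randn(5)
outputs, vjp_fn = vjp(torch.sin, x)  # outputs == torch.sin(x)
# vjp_fn maps cotangents (same shape as outputs) back to gradients w.r.t. the
# inputs; for the elementwise sin, ones as cotangents should recover x.cos().
vjps = vjp_fn(torch.ones_like(outputs))
```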
@@ -119,14 +140,14 @@ computes vjps given some `cotangents` Tensors.
 Use `jacrev` to compute the jacobian. This can be composed with vmap to produce
 batched jacobians:
 
-```
+```py
 >>> x = torch.randn(64, 5)
 >>> jacobian = vmap(jacrev(torch.sin))(x)
 >>> assert jacobian.shape == (64, 5, 5)
 ```
 
 `jacrev` can be composed with itself to produce hessians:
-```
+```py
 >>> def f(x):
 >>>   return x.sin().sum()
 >>>
@@ -134,27 +155,6 @@ batched jacobians:
 >>> hessian = jacrev(jacrev(f))(x)
 ```
 
-## Install
-
-### Binaries
-
-Coming soon!
-
-### From Source
-
-`functorch` is a PyTorch C++ Extension module. To install,
-
-- Install [PyTorch from source](https://github.com/pytorch/pytorch#from-source).
-Be sure to make sure the changes from https://github.com/pytorch/pytorch/pull/56824
-are on the branch. TODO: we should recommend a commit hash that is known to be stable
-- Run `python setup.py install`
-
-Then, try to run some tests to make sure all is OK:
-```
-pytest test/test_vmap.py -v
-pytest test/test_eager_transforms.py -v
-```
-
 ## Future Plans
 
 In the end state, we'd like to upstream this into PyTorch once we iron out the
@@ -25,9 +25,14 @@ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val)
 }
 
 int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
+  // NB: assumes the batch dim is at the front of the tensor
   optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
   auto rank = rankWithoutBatchDim(tensor, bdim);
-  return maybe_wrap_dim(rank, logical_dim) + 1;
+  auto wrapped_dim = maybe_wrap_dim(rank, logical_dim);
+  if (has_batch_dim) {
+    return wrapped_dim + 1;
+  }
+  return wrapped_dim;
 }
 
 }}
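To make the new branch concrete, here is a small Python mirror of the dim arithmetic; the function name and the standalone wrapping logic are illustrative, not part of the codebase. The logical dim is wrapped against the logical (batch-free) rank, and the extra `+ 1` that skips the front batch dim is now applied only when a batch dim is actually present.

```py
# Illustrative mirror of getPhysicalDim's new control flow (hypothetical helper).
def get_physical_dim(logical_rank: int, has_batch_dim: bool, logical_dim: int) -> int:
    # Wrap negative dims the way maybe_wrap_dim does: -1 becomes logical_rank - 1, etc.
    wrapped_dim = logical_dim + logical_rank if logical_dim < 0 else logical_dim
    assert 0 <= wrapped_dim < logical_rank
    # Skip past the batch dim only if one sits at the front of the physical tensor.
    return wrapped_dim + 1 if has_batch_dim else wrapped_dim

assert get_physical_dim(3, True, -1) == 3   # batched: physical layout is (B, d0, d1, d2)
assert get_physical_dim(3, False, -1) == 2  # unbatched: physical == logical
```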
@@ -88,8 +88,11 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
     int64_t dim) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto rank = rankWithoutBatchDim(self, self_bdim);
-  dim = maybe_wrap_dim(dim, rank + 1) + 1;
-  return { self.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
+  dim = maybe_wrap_dim(dim, rank + 1);
+  if (self_bdim) {
+    dim += 1;
+  }
+  return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
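At the user level, the fixed rule wraps negative dims against the logical rank plus one (the valid range for unsqueeze) before adding the batch-dim offset, and it unsqueezes `self_`, the tensor with its batch dim moved to the front. A hedged end-to-end sketch, mirroring the "wild bdims" tests added below; the shapes are illustrative:

```py
import torch
from functorch import vmap  # assumes the source install described in the README

B0 = 7
x = torch.rand(2, 3, B0)  # the batch dim lives at physical index 2 here

# Per example, unsqueeze(-1) maps a (2, 3) slice to (2, 3, 1); vmap stacks the
# results along a new leading batch dim, giving (B0, 2, 3, 1).
out = vmap(lambda t: torch.unsqueeze(t, -1), in_dims=2)(x)
assert out.shape == (B0, 2, 3, 1)
```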
@@ -1667,6 +1667,40 @@ class TestVmapOperators(Namespace.TestVmapBase):
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
 
+    def test_unsqueeze(self):
+        op = torch.unsqueeze
+        test = self._vmap_view_test
+        B0, B1, B2 = 7, 11, 13
+
+        # unsqueeze dim 0
+        test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
+
+        # unsqueeze last dim (positive)
+        test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
+
+        # unsqueeze last dim (negative)
+        test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
+        test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))
+
+        # nested vmaps
+        def unsqueeze_0(x):
+            return torch.unsqueeze(x, 0)
+
+        def unsqueeze_last(x):
+            return torch.unsqueeze(x, -1)
+
+        # bdims in canonical order
+        test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2), ))
+        test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))
+
+        # wild bdims
+        test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
+        test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
+        test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
+        test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
+
     def test_movedim(self):
         op = torch.movedim
         test = self._vmap_view_test