[functorch] Added dot/add.Scalar/mul.Scalar/etc. batching rules and added functools.wraps to grad

Horace He
2021-04-28 17:02:47 -07:00
committed by Jon Janzen
parent ac9be17a87
commit 918ede7a85
5 changed files with 28 additions and 17 deletions

View File

@@ -93,11 +93,11 @@ the gradients of the output of func w.r.t. `inputs[0]`.
```py
>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(torch.sin)(x)
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(torch.sin))(x)
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
```
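One user-visible effect of the `functools.wraps` change in the Python diff below (an illustration, not part of the README diff): the wrapper returned by `grad` now carries the wrapped function's metadata.

```py
>>> from functorch import grad
>>> def f(x):
...     return x.sin()
>>> grad(f).__name__  # preserved by the new @wraps(f)
'f'
```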

View File

@@ -1,5 +1,5 @@
import torch
from functools import partial
from functools import partial, wraps
import collections
import torch.nn as nn
import torch.nn.functional as F
@@ -133,6 +133,7 @@ def grad_with_value(f, diff_argnums=(0,), has_aux=False):
return wrapper
def grad(f, diff_argnums=(0,), has_aux=False):
@wraps(f)
def wrapper(*args):
results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
if has_aux:
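For context on what `@wraps(f)` buys here, a minimal standard-library sketch, independent of functorch's code (`my_transform` and `cube` are just illustrative names): `functools.wraps` copies metadata such as `__name__` and `__doc__` from the wrapped function onto the wrapper, and records the original under `__wrapped__`.

```py
from functools import wraps

def my_transform(f):
    @wraps(f)  # copy f's __name__, __doc__, etc. onto the wrapper
    def wrapper(*args):
        return f(*args)
    return wrapper

def cube(x):
    """Return x ** 3."""
    return x ** 3

assert my_transform(cube).__name__ == "cube"
assert my_transform(cube).__doc__ == "Return x ** 3."
assert my_transform(cube).__wrapped__ is cube
```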

View File

@@ -79,6 +79,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);
using TensorScalarScalarType = Tensor (*)(const Tensor&, const Scalar&, const Scalar&);
#define BINARY_POINTWISE_BATCH_RULE_SCALAR(op) \
binary_pointwise_batch_rule<TensorTensorScalarType, &op, const Scalar&>
@@ -87,23 +88,25 @@
#define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
#define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
#define SINGLE_ARG(...) __VA_ARGS__
BINARY_POINTWISE_WITH_SCALAR(add);
BINARY_POINTWISE_WITH_SCALAR(sub);
BINARY_POINTWISE_WITH_SCALAR(rsub);
BINARY_POINTWISE(mul);
VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::add, const Scalar&, const Scalar&>));
VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::sub, const Scalar&, const Scalar&>));
VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::mul, const Scalar&>));
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::div, const Scalar&>));
BINARY_POINTWISE(div);
VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
// at::pow has three out-of-place overloads
#define POW_BATCH_RULE binary_pointwise_batch_rule<TensorTensorType, &at::pow>
VMAP_SUPPORT("pow.Tensor_Tensor", POW_BATCH_RULE);
#undef POW_BATCH_RULE
#define POW_BATCH_RULE basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>
VMAP_SUPPORT("pow.Tensor_Scalar", POW_BATCH_RULE);
#undef POW_BATCH_RULE
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::pow>));
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
#undef SINGLE_ARG
#undef BINARY_POINTWISE_BATCH_RULE_SCALAR
#undef BINARY_POINTWISE_BATCH_RULE
#undef BINARY_POINTWISE_WITH_SCALAR
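At the Python level, these registrations mean `vmap` can go through the `.Scalar` overloads directly. A hedged usage sketch (the asserts hold regardless of which overload the calls actually dispatch to; the comment assumes the `.Scalar` path):

```py
import torch
from functorch import vmap

x = torch.randn(3, 5)

# tensor-with-Python-number calls, assumed to hit add.Scalar / mul.Scalar / pow.Tensor_Scalar
assert torch.allclose(vmap(lambda t: torch.add(t, 2.0))(x), x + 2.0)
assert torch.allclose(vmap(lambda t: torch.mul(t, 3.0))(x), x * 3.0)
assert torch.allclose(vmap(lambda t: torch.pow(t, 2.0))(x), x ** 2)
```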

View File

@@ -21,9 +21,16 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
};
}
std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
auto A_ = moveBatchDimToFront(A, A_bdim);
auto B_ = moveBatchDimToFront(B, B_bdim);
return {at::matmul(A_, B_.t()), 0};
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("slogdet", slogdet_batch_rule);
VMAP_SUPPORT("dot", dot_batch_rule);
}
}}
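As a usage-level sanity check (a sketch, not part of the diff): with one batched operand, a vmapped `dot` is just a matrix-vector product, which is what the rule above reduces to via `at::matmul`.

```py
import torch
from functorch import vmap

A = torch.randn(4, 3)  # batch of 4 vectors
b = torch.randn(3)     # shared, unbatched vector

# map over the first argument only: each row of A is dotted with b
batched = vmap(torch.dot, in_dims=(0, None))(A, b)
assert torch.allclose(batched, A @ b)
```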

View File

@@ -27,7 +27,7 @@ namespace at { namespace functorch {
// `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors**
// involved in this case; there exists some plumbing that automatically unwraps
// BatchedTensors before calling the batch rule.
//
// To write the logic of the batch rule: think about the semantics of the
// `sum` operation if `self` had an additional dimension (indicated by self_bdim):
// - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual
@@ -47,15 +47,15 @@ namespace at { namespace functorch {
// VMAP_SUPPORT("sum.int", sum_batch_rule);
// ...
// }
//
// Note [Reusing batch rules to add vmap support for a complicated operator]
// Can't figure out how to write a batch rule for a big operation? If the
// operation can be expressed as a composition of other operations that do have
// batch rules, then that is another way to add vmap support. For example,
// consider the following schema
// func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
// and assume we already have batching rules for basic arithmetic operators.
//
// To add vmap support, define a decomposition using the same signature:
// Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1,
// const Tensor& tensor2, const Scalar& value) {
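The decomposition idea from the note above, sketched in Python instead of C++ (same math, not the code in this file): `addcmul(self, tensor1, tensor2, value)` is `self + value * tensor1 * tensor2`, so expressing it in terms of ops that already have batch rules gives vmap support for free.

```py
import torch
from functorch import vmap

def addcmul_decomp(self, tensor1, tensor2, value=1):
    # same math as aten::addcmul, built from ops that already have batch rules
    return self + value * tensor1 * tensor2

x, t1, t2 = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)
assert torch.allclose(vmap(addcmul_decomp)(x, t1, t2), torch.addcmul(x, t1, t2))
```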
@@ -73,7 +73,7 @@ namespace at { namespace functorch {
// TODO: This is kinda complicated. Saving this for a future date.
std::tuple<Tensor, optional<int64_t>> flatten_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t start_dim, int64_t end_dim) {
auto self_ = moveBatchDimToFront(self, self_bdim);
@@ -86,7 +86,7 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t dim) {
auto self_ = moveBatchDimToFront(self, self_bdim);
auto rank = rankWithoutBatchDim(self, self_bdim);
dim = maybe_wrap_dim(dim, rank + 1);
if (self_bdim) {