[functorch] Added dot/add.Scalar/mul.Scalar/etc. batching rules and added functools.wraps to grad

This commit is contained in:
Horace He
2021-04-28 17:02:47 -07:00
committed by Jon Janzen
parent ac9be17a87
commit 918ede7a85
5 changed files with 28 additions and 17 deletions

View File

@ -93,11 +93,11 @@ the gradients of the output of func w.r.t. to `inputs[0]`.
```py
>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(torch.sin)(x)
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(torch.sin))(x)
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
```

View File

@ -1,5 +1,5 @@
import torch
from functools import partial
from functools import partial, wraps
import collections
import torch.nn as nn
import torch.nn.functional as F
@ -133,6 +133,7 @@ def grad_with_value(f, diff_argnums=(0,), has_aux=False):
return wrapper
def grad(f, diff_argnums=(0,), has_aux=False):
@wraps(f)
def wrapper(*args):
results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
if has_aux:

View File

@ -79,6 +79,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);
using TensorScalarScalarType = Tensor (*)(const Tensor&, const Scalar&, const Scalar&);
#define BINARY_POINTWISE_BATCH_RULE_SCALAR(op) \
binary_pointwise_batch_rule<TensorTensorScalarType, &op, const Scalar&>
@ -87,23 +88,25 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
#define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
#define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
#define SINGLE_ARG(...) __VA_ARGS__
BINARY_POINTWISE_WITH_SCALAR(add);
BINARY_POINTWISE_WITH_SCALAR(sub);
BINARY_POINTWISE_WITH_SCALAR(rsub);
BINARY_POINTWISE(mul);
VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::add, const Scalar&, const Scalar&>));
VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::sub, const Scalar&, const Scalar&>));
VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::mul, const Scalar&>));
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::div, const Scalar&>));
BINARY_POINTWISE(div);
VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
// at::pow has three out-of-place overloads
#define POW_BATCH_RULE binary_pointwise_batch_rule<TensorTensorType, &at::pow>
VMAP_SUPPORT("pow.Tensor_Tensor", POW_BATCH_RULE);
#undef POW_BATCH_RULE
#define POW_BATCH_RULE basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>
VMAP_SUPPORT("pow.Tensor_Scalar", POW_BATCH_RULE);
#undef POW_BATCH_RULE
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::pow>));
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
#undef SINGLE_ARG
#undef BINARY_POINTWISE_BATCH_RULE_SCALAR
#undef BINARY_POINTWISE_BATCH_RULE
#undef BINARY_POINTWISE_WITH_SCALAR

View File

@ -21,9 +21,16 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
};
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("slogdet", slogdet_batch_rule);
std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
auto A_ = moveBatchDimToFront(A, A_bdim);
auto B_ = moveBatchDimToFront(B, B_bdim);
return {at::matmul(A_, B_.t()), 0};
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("slogdet", slogdet_batch_rule);
VMAP_SUPPORT("dot", dot_batch_rule);
}
}}