Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch] Added dot/add.Scalar/mul.Scalar/etc. batching rules and added functools.wraps to grad
````diff
@@ -93,11 +93,11 @@ the gradients of the output of func w.r.t. `inputs[0]`.
 ```py
 >>> from functorch import grad
 >>> x = torch.randn([])
->>> cos_x = grad(torch.sin)(x)
+>>> cos_x = grad(lambda x: torch.sin(x))(x)
 >>> assert torch.allclose(cos_x, x.cos())
 >>>
 >>> # Second-order gradients
->>> neg_sin_x = grad(grad(torch.sin))(x)
+>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
 >>> assert torch.allclose(neg_sin_x, -x.sin())
 ```
````
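Since `grad` returns an ordinary Python function, it composes to any order. A small sketch (not part of this commit; it only assumes the `functorch.grad` API shown above) extending the README example to a third derivative:

```py
>>> from functorch import grad
>>> import torch
>>> x = torch.randn([])
>>> # d^3/dx^3 sin(x) = -cos(x): each application of grad differentiates once more.
>>> neg_cos_x = grad(grad(grad(lambda x: torch.sin(x))))(x)
>>> assert torch.allclose(neg_cos_x, -x.cos())
```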
```diff
@@ -1,5 +1,5 @@
 import torch
-from functools import partial
+from functools import partial, wraps
 import collections
 import torch.nn as nn
 import torch.nn.functional as F
```
```diff
@@ -133,6 +133,7 @@ def grad_with_value(f, diff_argnums=(0,), has_aux=False):
     return wrapper
 
 def grad(f, diff_argnums=(0,), has_aux=False):
+    @wraps(f)
     def wrapper(*args):
         results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
         if has_aux:
```
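For context, `functools.wraps` copies the wrapped function's metadata (`__name__`, `__doc__`, and friends) onto the wrapper, so the function `grad` returns now identifies itself as `f` rather than as an anonymous `wrapper`. A minimal standalone sketch (hypothetical names, not the commit's code):

```py
from functools import wraps

def f(x):
    """Squares x."""
    return x * x

def grad_like(f):
    # Without @wraps, the returned function reports __name__ == "wrapper".
    @wraps(f)
    def wrapper(*args):
        return f(*args)  # stand-in for the real gradient machinery
    return wrapper

assert grad_like(f).__name__ == "f"
assert grad_like(f).__doc__ == "Squares x."
```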
```diff
@@ -79,6 +79,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
   using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
   using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);
+  using TensorScalarScalarType = Tensor (*)(const Tensor&, const Scalar&, const Scalar&);
 
   #define BINARY_POINTWISE_BATCH_RULE_SCALAR(op) \
     binary_pointwise_batch_rule<TensorTensorScalarType, &op, const Scalar&>
```
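These function-pointer aliases exist because ops like `at::add` are overload sets: taking `&at::add` is ambiguous until a specific signature is named, and the alias pins down which overload a template receives. A minimal standalone C++ sketch (illustrative names, not the commit's code):

```cpp
// Overload resolution via a function-pointer type alias used as a
// template parameter, mirroring binary_pointwise_batch_rule<TensorTensorType, &at::pow>.
#include <iostream>

int add(int a, int b) { return a + b; }
double add(double a, double b) { return a + b; }

using IntBinaryFn = int (*)(int, int);

// The non-type parameter's declared type selects the matching overload of &add.
template <typename FnType, FnType Fn>
int apply_twice(int x) { return Fn(Fn(x, x), x); }

int main() {
  std::cout << apply_twice<IntBinaryFn, &add>(2) << "\n";  // int overload: (2+2)+2 = 6
  return 0;
}
```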
```diff
@@ -87,23 +88,25 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 
   #define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
   #define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
+  #define SINGLE_ARG(...) __VA_ARGS__
 
   BINARY_POINTWISE_WITH_SCALAR(add);
   BINARY_POINTWISE_WITH_SCALAR(sub);
   BINARY_POINTWISE_WITH_SCALAR(rsub);
   BINARY_POINTWISE(mul);
+  VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::add, const Scalar&, const Scalar&>));
+  VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::sub, const Scalar&, const Scalar&>));
+  VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::mul, const Scalar&>));
+  VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::div, const Scalar&>));
   BINARY_POINTWISE(div);
   VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
 
   // at::pow has three out-of-place overloads
-  #define POW_BATCH_RULE binary_pointwise_batch_rule<TensorTensorType, &at::pow>
-  VMAP_SUPPORT("pow.Tensor_Tensor", POW_BATCH_RULE);
-  #undef POW_BATCH_RULE
-  #define POW_BATCH_RULE basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>
-  VMAP_SUPPORT("pow.Tensor_Scalar", POW_BATCH_RULE);
-  #undef POW_BATCH_RULE
+  VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::pow>));
+  VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
   VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
 
+  #undef SINGLE_ARG
   #undef BINARY_POINTWISE_BATCH_RULE_SCALAR
   #undef BINARY_POINTWISE_BATCH_RULE
   #undef BINARY_POINTWISE_WITH_SCALAR
```
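The `SINGLE_ARG` macro solves a preprocessor quirk: macro arguments are split on top-level commas, so a template instantiation with several parameters looks to a macro like several arguments. Wrapping it in `SINGLE_ARG(...)` protects the commas with parentheses and re-emits them via `__VA_ARGS__`. A minimal standalone sketch (hypothetical `DECLARE` macro, not the commit's code):

```cpp
// Why SINGLE_ARG is needed: without it, the comma inside the template
// argument list would be parsed as a macro-argument separator.
#include <map>
#include <string>

#define DECLARE(type, name) type name
#define SINGLE_ARG(...) __VA_ARGS__

int main() {
  // DECLARE(std::map<std::string, int>, m);           // error: 3 args given, 2 expected
  DECLARE(SINGLE_ARG(std::map<std::string, int>), m);  // OK: comma re-joined by expansion
  m["answer"] = 42;
  return 0;
}
```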
```diff
@@ -21,9 +21,16 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
   };
 }
 
-TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
-  VMAP_SUPPORT("slogdet", slogdet_batch_rule);
+std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
+  auto A_ = moveBatchDimToFront(A, A_bdim);
+  auto B_ = moveBatchDimToFront(B, B_bdim);
+  return {at::matmul(A_, B_.t()), 0};
+}
+
+
+TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
+  VMAP_SUPPORT("slogdet", slogdet_batch_rule);
+  VMAP_SUPPORT("dot", dot_batch_rule);
 }
 }}
```
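The rule moves any batch dimension to the front and then relies on `matmul`; since `t()` is a no-op on a 1-D tensor, an unbatched operand passes through unchanged and the batched-times-unbatched case reduces to a matrix-vector product. A small PyTorch sketch of the shapes involved (illustrative only, not the commit's code):

```py
import torch

A = torch.randn(5, 3)  # batched side: 5 vectors of length 3
b = torch.randn(3)     # unbatched side; b.t() is a no-op on a 1-D tensor

# One operand batched: a matrix-vector product yields the 5 per-example dots.
out = torch.matmul(A, b.t())
expected = torch.stack([torch.dot(A[i], b) for i in range(5)])
assert torch.allclose(out, expected)

# Both operands batched: matmul(A, B.t()) is the full (5, 5) matrix of all
# pairwise dots; the per-example results sit on its diagonal, i.e. (A * B).sum(-1).
B = torch.randn(5, 3)
pairwise = torch.matmul(A, B.t())
assert torch.allclose(pairwise.diagonal(), (A * B).sum(-1))
```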