From 918ede7a8590604491369ff441993c8b23e05f3b Mon Sep 17 00:00:00 2001
From: Horace He
Date: Wed, 28 Apr 2021 17:02:47 -0700
Subject: [PATCH] [functorch] Added dot/add.Scalar/mul.Scalar/etc. batching
 rules and added functools.wraps to grad

---
 functorch/README.md                              |  4 ++--
 functorch/functorch/_src/eager_transforms.py     |  3 ++-
 functorch/functorch/csrc/BatchRulesBinaryOps.cpp | 15 +++++++++------
 .../functorch/csrc/BatchRulesLinearAlgebra.cpp   | 11 +++++++++--
 functorch/functorch/csrc/BatchRulesViews.cpp     | 12 ++++++------
 5 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/functorch/README.md b/functorch/README.md
index a532dbbdf1e7..b64d13f4fe4d 100644
--- a/functorch/README.md
+++ b/functorch/README.md
@@ -93,11 +93,11 @@ the gradients of the output of func w.r.t. to `inputs[0]`.
 
 ```py
 >>> from functorch import grad
 >>> x = torch.randn([])
->>> cos_x = grad(torch.sin)(x)
+>>> cos_x = grad(lambda x: torch.sin(x))(x)
 >>> assert torch.allclose(cos_x, x.cos())
 >>>
 >>> # Second-order gradients
->>> neg_sin_x = grad(grad(torch.sin))(x)
+>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
 >>> assert torch.allclose(neg_sin_x, -x.sin())
 ```

diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index 7257630f9d84..cbb8a463622f 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -1,5 +1,5 @@
 import torch
-from functools import partial
+from functools import partial, wraps
 import collections
 import torch.nn as nn
 import torch.nn.functional as F
@@ -133,6 +133,7 @@ def grad_with_value(f, diff_argnums=(0,), has_aux=False):
     return wrapper
 
 def grad(f, diff_argnums=(0,), has_aux=False):
+    @wraps(f)
     def wrapper(*args):
         results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
         if has_aux:

diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
index 1144f10a5395..58354110f733 100644
--- a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -79,6 +79,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
   using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
   using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);
+  using TensorScalarScalarType = Tensor (*)(const Tensor&, const Scalar&, const Scalar&);
 
 #define BINARY_POINTWISE_BATCH_RULE_SCALAR(op) \
   binary_pointwise_batch_rule<TensorTensorScalarType, &op, const Scalar&>
@@ -87,23 +88,25 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 #define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
 #define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
+#define SINGLE_ARG(...) __VA_ARGS__
 
   BINARY_POINTWISE_WITH_SCALAR(add);
   BINARY_POINTWISE_WITH_SCALAR(sub);
   BINARY_POINTWISE_WITH_SCALAR(rsub);
   BINARY_POINTWISE(mul);
+  VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::add, const Scalar&, const Scalar&>));
+  VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarScalarType, &at::sub, const Scalar&, const Scalar&>));
+  VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::mul, const Scalar&>));
+  VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::div, const Scalar&>));
   BINARY_POINTWISE(div);
   VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
 
   // at::pow has three out-of-place overloads
-#define POW_BATCH_RULE binary_pointwise_batch_rule<TensorTensorType, &at::pow>
-  VMAP_SUPPORT("pow.Tensor_Tensor", POW_BATCH_RULE);
-#undef POW_BATCH_RULE
-#define POW_BATCH_RULE basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>
-  VMAP_SUPPORT("pow.Tensor_Scalar", POW_BATCH_RULE);
-#undef POW_BATCH_RULE
+  VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::pow>));
+  VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
   VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
 
+#undef SINGLE_ARG
 #undef BINARY_POINTWISE_BATCH_RULE_SCALAR
 #undef BINARY_POINTWISE_BATCH_RULE
 #undef BINARY_POINTWISE_WITH_SCALAR

diff --git a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
index fb62d11f5405..d22599d4d1a8 100644
--- a/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
+++ b/functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
@@ -21,9 +21,16 @@ slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
   };
 }
 
-TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
-  VMAP_SUPPORT("slogdet", slogdet_batch_rule);
+std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) {
+  auto A_ = moveBatchDimToFront(A, A_bdim);
+  auto B_ = moveBatchDimToFront(B, B_bdim);
+  return {at::matmul(A_, B_.t()), 0};
 }
+
+TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
+  VMAP_SUPPORT("slogdet", slogdet_batch_rule);
+  VMAP_SUPPORT("dot", dot_batch_rule);
+}
 
 }}

diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp
index fad03ef3e6e2..d70b13fb76c7 100644
--- a/functorch/functorch/csrc/BatchRulesViews.cpp
+++ b/functorch/functorch/csrc/BatchRulesViews.cpp
@@ -27,7 +27,7 @@ namespace at { namespace functorch {
 // `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors**
 // involved in this case; there exists some plumbing that automatically unwraps
 // BatchedTensors before calling the batch rule.
-// 
+//
 // To write the logic of the batch rule: think about the semantics of the
 // `sum` operation if `self` had an additional dimension (indicated by self_bdim):
 // - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual
@@ -47,15 +47,15 @@ namespace at { namespace functorch {
 //   VMAP_SUPPORT("sum.int", sum_batch_rule);
 //   ...
 // }
-// 
+//
 // Note [Reusing batch rules to add vmap support for a complicated operator]
 // Can't figure out how to write a batch rule for a big operation? If the
 // operation can be expressed as a composition of other operations that do have
 // batch rules, then that is another way to add vmap support. For example,
-// consider the following schema 
+// consider the following schema
 //   func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
 // and assume we already have batching rules for basic arithmetic operators.
-// 
+//
 // To add vmap support, define a decomposition using the same signature:
 //   Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1,
 //                         const Tensor& tensor2, const Scalar& value) {
@@ -73,7 +73,7 @@ namespace at { namespace functorch {
 // TODO: This is kinda complicated. Saving this for a future date.
 
 std::tuple<Tensor, optional<int64_t>> flatten_batch_rule(
-    const Tensor& self, 
+    const Tensor& self,
     optional<int64_t> self_bdim,
     int64_t start_dim, int64_t end_dim) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
@@ -86,7 +86,7 @@ std::tuple<Tensor, optional<int64_t>> unsqueeze_batch_rule(
     const Tensor& self,
     optional<int64_t> self_bdim,
     int64_t dim) {
-  auto self_ = moveBatchDimToFront(self, self_bdim); 
+  auto self_ = moveBatchDimToFront(self, self_bdim);
   auto rank = rankWithoutBatchDim(self, self_bdim);
   dim = maybe_wrap_dim(dim, rank + 1);
   if (self_bdim) {
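
A minimal sketch of what the new batch rules enable, exercised through the public `vmap` API. This is not part of the patch itself; it assumes a functorch build that includes this commit:

```py
>>> import torch
>>> from functorch import vmap
>>> # "dot" now has a batch rule; here only the left operand is batched.
>>> x = torch.randn(3, 5)
>>> y = torch.randn(5)
>>> out = vmap(torch.dot, in_dims=(0, None))(x, y)
>>> assert torch.allclose(out, x @ y)
>>> # Tensor-Scalar overloads (e.g. mul.Scalar) also have batch rules now.
>>> assert torch.allclose(vmap(lambda t: t * 2.0)(x), x * 2.0)
```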
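The `functools.wraps` change is observable from Python: `grad(f)` now carries `f`'s metadata instead of reporting the inner `wrapper`. A sketch of the expected behavior under the same assumptions:

```py
>>> import torch
>>> from functorch import grad
>>> def f(x):
...     return torch.sin(x)
>>> grad(f).__name__
'f'
```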