From 985b35c23d4766415f0c0044c574549013a4201d Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Fri, 23 Apr 2021 12:29:29 -0700
Subject: [PATCH] [functorch] new_blah hack

---
 .../functorch/csrc/BatchingRegistrations.cpp  | 102 ++++++++++++++++++
 functorch/functorch/csrc/new_blah_hacks.cpp   |  58 ++++++++++
 functorch/setup.py                            |   6 +-
 functorch/test/test_eager_transforms.py       |  19 +++-
 4 files changed, 180 insertions(+), 5 deletions(-)
 create mode 100644 functorch/functorch/csrc/new_blah_hacks.cpp

diff --git a/functorch/functorch/csrc/BatchingRegistrations.cpp b/functorch/functorch/csrc/BatchingRegistrations.cpp
index afac8fd4cfde..e49095ce80f3 100644
--- a/functorch/functorch/csrc/BatchingRegistrations.cpp
+++ b/functorch/functorch/csrc/BatchingRegistrations.cpp
@@ -1499,6 +1499,107 @@ std::tuple<Tensor, optional<int64_t>> binary_pointwise_batch_rule(
   return { std::move(result), std::move(result_batch_dim) };
 }
 
+Tensor matmul_decomposed(
+    const Tensor& tensor1,
+    const Tensor& tensor2) {
+  auto dim_tensor1 = tensor1.dim();
+  auto dim_tensor2 = tensor2.dim();
+
+  if (dim_tensor1 == 1 && dim_tensor2 == 1) {
+    return tensor1.dot(tensor2);
+  } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
+    return tensor1.mv(tensor2);
+  } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
+    return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
+  } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
+    return tensor1.mm(tensor2);
+  } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
+    // optimization: use mm instead of bmm by folding tensor1's batch into
+    // its leading matrix dimension.
+
+    Tensor t2 = dim_tensor2 == 1 ? tensor2.unsqueeze(-1) : tensor2;
+    auto size1 = tensor1.sizes();
+    auto size2 = t2.sizes();
+    std::vector<int64_t> output_size;
+    output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
+    if (dim_tensor2 > 1) {
+      output_size.push_back(size2[dim_tensor2 - 1]);
+    }
+
+    // fold the batch into the first dimension
+    Tensor t1 = tensor1.reshape({-1, size1[size1.size() - 1]});
+    Tensor output = t1.mm(t2).view(output_size);
+    return output;
+  } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
+    // optimization: transpose the inner dimensions of the arguments, call
+    // matmul on the swapped arguments, then transpose the inner dimensions
+    // of the result.
+    const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
+    const int64_t m = tensor1.size(-1);
+    const int64_t p = tensor2.size(-1);
+
+    const Tensor t2_T = tensor2.transpose(-1, -2);
+    const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
+    const Tensor res_T = at::matmul(t2_T, t1_T);
+
+    if (dim_tensor1 == 2) {
+      return res_T.transpose(-1, -2);
+    }
+    else {
+      std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
+      shape.push_back(p);
+
+      Tensor res = res_T.reshape(shape);
+      return res;
+    }
+  } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
+    // We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
+    // we track m1 vs m2 separately even though they must match for nicer error messages
+    int64_t n = dim_tensor1 > 1 ? tensor1.size(-2) : 1;
+    int64_t m1 = tensor1.size(-1);
+    IntArrayRef batch_tensor1(tensor1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0));
+    int64_t m2 = dim_tensor2 > 1 ? tensor2.size(-2) : 1;
+    int64_t p = tensor2.size(-1);
+    IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0));
+
+    // expand the batch portion (i.e. cut off matrix dimensions and expand rest)
+    std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+
+    std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+    tensor1_expand_size.insert(tensor1_expand_size.end(), {n, m1});
+
+    std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+    tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});
+
+    const int64_t expand_batch_product =
+        c10::multiply_integers(expand_batch_portion);
+
+    std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
+    tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
+
+    std::vector<int64_t> tensor2_bmm_view({expand_batch_product});
+    tensor2_bmm_view.insert(tensor2_bmm_view.end(), {m2, p});
+
+    // flatten expanded batches
+    Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(tensor1_bmm_view);
+    Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(tensor2_bmm_view);
+
+    // reshape batches back into result
+    std::vector<int64_t> output_shape(expand_batch_portion);
+    if (dim_tensor1 > 1) {
+      output_shape.push_back(n);
+    }
+    if (dim_tensor2 > 1) {
+      output_shape.push_back(p);
+    }
+
+    Tensor output = tensor1_expanded.bmm(tensor2_expanded).view(output_shape);
+    return output;
+  }
+
+  AT_ERROR("both arguments to matmul need to be at least 1D, but they are ",
+           dim_tensor1, "D and ", dim_tensor2, "D");
+}
 
 
 TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
@@ -1574,6 +1675,7 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // m.impl("addmm", addmm_batching_rule);
+  m.impl("matmul", matmul_decomposed);
 
   // // clamp operations
   // m.impl("clamp", clamp_batching_rule);
diff --git a/functorch/functorch/csrc/new_blah_hacks.cpp b/functorch/functorch/csrc/new_blah_hacks.cpp
new file mode 100644
index 000000000000..696a1eed6f81
--- /dev/null
+++ b/functorch/functorch/csrc/new_blah_hacks.cpp
@@ -0,0 +1,58 @@
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/Constants.h>
+
+#include <ATen/ATen.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/library.h>
+
+namespace at { namespace functorch {
+
+#define NEW_BLAH_HACK(new_blah) \
+  static Tensor new_blah##_hack( \
+      const Tensor& self, \
+      IntArrayRef size, \
+      c10::optional<ScalarType> dtype, \
+      c10::optional<Layout> layout, \
+      c10::optional<Device> device, \
+      c10::optional<bool> pin_memory \
+      ) { \
+    static auto op = c10::Dispatcher::singleton() \
+      .findSchemaOrThrow("functorch::"#new_blah"_hack", "") \
+      .typed<decltype(new_blah##_hack)>(); \
+    return op.call(self, size, dtype, layout, device, pin_memory); \
+  } \
+  static Tensor new_blah##_hack_impl( \
+      const Tensor& self, \
+      IntArrayRef size, \
+      c10::optional<ScalarType> dtype, \
+      c10::optional<Layout> layout, \
+      c10::optional<Device> device, \
+      c10::optional<bool> pin_memory \
+      ) { \
+    auto layer = maybeCurrentDynamicLayer(); \
+    if (!layer.has_value()) { \
+      return self.new_blah(size, dtype, layout, device, pin_memory); \
+    } \
+    AutoNonVariableTypeMode dispatch_after_grad_guard; \
+    c10::impl::ExcludeDispatchKeyGuard dispatch_after_vmap_guard(kBatchedKey); \
+    return new_blah##_hack(self, size, dtype, layout, device, pin_memory); \
+  }
+
+NEW_BLAH_HACK(new_zeros);
+NEW_BLAH_HACK(new_empty);
+
+#undef NEW_BLAH_HACK
+
+TORCH_LIBRARY(functorch, m) {
+  m.def("new_zeros_hack", new_zeros_hack_impl);
+  m.def("new_empty_hack", new_empty_hack_impl);
+}
+
+TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
+  m.impl("new_zeros", new_zeros_hack);
+  m.impl("new_empty", new_empty_hack);
+}
+
+
+}}
+
diff --git a/functorch/setup.py b/functorch/setup.py
index b2f84d6da6c0..443ef4b027ce 100644
--- a/functorch/setup.py
+++ b/functorch/setup.py
@@ -30,9 +30,9 @@ def get_extensions():
     define_macros = []
     extra_link_args = []
-    # extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
-    # if int(os.environ.get("DEBUG", 0)):
-    if True:
+    extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
+    if int(os.environ.get("DEBUG", 0)):
+    # if True:
         extra_compile_args = {
             "cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
         extra_link_args = ["-O0", "-g"]
 
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
index b9686be08b23..2883214fefb9 100644
--- a/functorch/test/test_eager_transforms.py
+++ b/functorch/test/test_eager_transforms.py
@@ -254,14 +254,29 @@ class TestVmapOfGrad(TestCase):
         N = 3
         C = 5
 
-        def foo(x, y):
+        def foo(y, x):
             result = x.new_zeros((C,))
             result.copy_(y)
             return result.sum()
 
         x = torch.randn(N, device=device)
         y = torch.randn(N, C, device=device)
-        result = vmap(grad(foo))(x, y)
+        result = vmap(grad(foo))(y, x)
+        self.assertEqual(result, torch.ones_like(y))
+
+    def test_new_empty_materializes_tensor(self, device):
+        N = 3
+        C = 5
+
+        def foo(y, x):
+            result = x.new_empty((C,))
+            result.copy_(y)
+            return result.sum()
+
+        x = torch.randn(N, device=device)
+        y = torch.randn(N, C, device=device)
+        result = vmap(grad(foo))(y, x)
+        self.assertEqual(result, torch.ones_like(y))
 
     def test_per_sample_grads_simple(self, device):
         def compute_loss(weight, x, t):
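
Usage sketch (not part of the patch): a minimal standalone example of what the updated
tests above exercise, assuming a functorch build with this change applied. vmap and grad
are functorch's public transforms; the shapes and the foo function mirror the test.

    import torch
    from functorch import grad, vmap

    N, C = 3, 5

    def foo(y, x):
        # x.new_zeros is intercepted by the new_zeros hack registered above, so
        # even under vmap(grad(...)) it produces a materialized tensor that
        # copy_ can write a per-sample slice of y into.
        result = x.new_zeros((C,))
        result.copy_(y)
        return result.sum()

    x = torch.randn(N)
    y = torch.randn(N, C)
    # grad differentiates foo w.r.t. its first argument (y), and vmap maps over
    # dim 0 of both inputs, so this should equal torch.ones_like(y), matching
    # the assertion added in the test.
    per_sample_grads = vmap(grad(foo))(y, x)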