[functorch] new_blah hack

This commit is contained in:
Richard Zou
2021-04-23 12:29:29 -07:00
committed by Jon Janzen
parent 8201dfc2d5
commit 985b35c23d
4 changed files with 180 additions and 5 deletions

View File

@ -1499,6 +1499,107 @@ std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
return { std::move(result), std::move(result_batch_dim) };
}
// Decomposes at::matmul into dot / mv / mm / bmm calls selected by the
// dimensionality of the two inputs, so that vmap only needs batching rules
// for those primitive ops (this function is registered for "matmul" below).
// NOTE(review): this looks like a copy of ATen's native matmul
// decomposition — confirm against upstream LinearAlgebra.cpp and keep in
// sync if that changes.
Tensor matmul_decomposed(
const Tensor& tensor1,
const Tensor& tensor2) {
auto dim_tensor1 = tensor1.dim();
auto dim_tensor2 = tensor2.dim();
// 1D x 1D: inner product (returns a 0-D tensor).
if (dim_tensor1 == 1 && dim_tensor2 == 1) {
return tensor1.dot(tensor2);
// 2D x 1D: matrix-vector product.
} else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
return tensor1.mv(tensor2);
// 1D x 2D: promote the vector to a 1 x m matrix, mm, then drop the
// added leading dim.
} else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
// 2D x 2D: plain matrix multiply.
} else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
return tensor1.mm(tensor2);
} else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
// optimization: use mm instead of bmm by folding tensor1's batch into
// its leading matrix dimension.
// A 1-D rhs is treated as an m x 1 matrix; the extra size-1 output dim
// is omitted from output_size (the `dim_tensor2 > 1` check below).
Tensor t2 = dim_tensor2 == 1 ? tensor2.unsqueeze(-1) : tensor2;
auto size1 = tensor1.sizes();
auto size2 = t2.sizes();
std::vector<int64_t> output_size;
output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
if (dim_tensor2 > 1) {
output_size.push_back(size2[dim_tensor2 - 1]);
}
// fold the batch into the first dimension
Tensor t1 = tensor1.reshape({-1, size1[size1.size() - 1]});
Tensor output = t1.mm(t2).view(output_size);
return output;
} else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
// optimization: transpose the inner dimensions of the arguments, call
// matmul on the swapped arguments, then transpose the inner dimensions
// of the result.
// This reuses the (>=3D, <=2D) fast path above via (t2^T @ t1^T)^T.
const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
const int64_t m = tensor1.size(-1);
const int64_t p = tensor2.size(-1);
const Tensor t2_T = tensor2.transpose(-1, -2);
const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
const Tensor res_T = at::matmul(t2_T, t1_T);
if (dim_tensor1 == 2) {
return res_T.transpose(-1, -2);
}
else {
// 1-D lhs: the n == 1 row dim vanishes; result shape is
// tensor2's batch dims followed by p.
std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
shape.push_back(p);
Tensor res = res_T.reshape(shape);
return res;
}
} else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
// We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
// we track m1 vs m2 separately even though they must match for nicer error messages
int64_t n = dim_tensor1 > 1 ? tensor1.size(-2) : 1;
int64_t m1 = tensor1.size(-1);
IntArrayRef batch_tensor1(tensor1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0));
int64_t m2 = dim_tensor2 > 1 ? tensor2.size(-2) : 1;
int64_t p = tensor2.size(-1);
IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0));
// expand the batch portion (i.e. cut off matrix dimensions and expand rest)
std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
tensor1_expand_size.insert(tensor1_expand_size.end(), {n, m1});
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});
const int64_t expand_batch_product =
c10::multiply_integers(expand_batch_portion);
// bmm views: collapse all broadcast batch dims into one leading dim.
std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
std::vector<int64_t> tensor2_bmm_view({expand_batch_product});
tensor2_bmm_view.insert(tensor2_bmm_view.end(), {m2, p});
// flatten expanded batches
Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(tensor1_bmm_view);
Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(tensor2_bmm_view);
// reshape batches back into result
// A 1-D input contributes no n (resp. p) dim to the output, matching
// matmul's documented shape rules.
std::vector<int64_t> output_shape(expand_batch_portion);
if (dim_tensor1 > 1) {
output_shape.push_back(n);
}
if (dim_tensor2 > 1) {
output_shape.push_back(p);
}
Tensor output = tensor1_expanded.bmm(tensor2_expanded).view(output_shape);
return output;
}
// 0-D (scalar) operands are not supported by matmul.
AT_ERROR("both arguments to matmul need to be at least 1D, but they are ",
dim_tensor1, "D and ", dim_tensor2, "D");
}
TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
@ -1574,6 +1675,7 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
m.impl("view_as", native::view_as); // composite wrt autograd
// m.impl("addmm", addmm_batching_rule);
m.impl("matmul", matmul_decomposed);
//
// clamp operations
// m.impl("clamp", clamp_batching_rule);

View File

@ -0,0 +1,58 @@
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/Constants.h>
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at { namespace functorch {
// NEW_BLAH_HACK(f) generates two functions for a Tensor factory method f
// (here: new_zeros, new_empty):
//
//   * f_hack       — a thin redispatching wrapper: looks up the custom
//                    "functorch::f_hack" schema in the dispatcher and calls
//                    through it, so the call is routed by dispatch keys
//                    rather than invoked directly.
//   * f_hack_impl  — the actual implementation. When no functorch dynamic
//                    layer is active it simply forwards to self.f(...).
//                    Otherwise it disables Autograd recording
//                    (AutoNonVariableTypeMode) and excludes the batching
//                    key (kBatchedKey), then re-enters f_hack so the call
//                    is re-dispatched under those guards.
//
// NOTE(review): comments are intentionally kept outside the macro body —
// a // comment before a trailing backslash would swallow the line
// continuation.
#define NEW_BLAH_HACK(new_blah) \
static Tensor new_blah##_hack( \
const Tensor& self, \
IntArrayRef size, \
c10::optional<ScalarType> dtype, \
c10::optional<Layout> layout, \
c10::optional<Device> device, \
c10::optional<bool> pin_memory \
) { \
static auto op = c10::Dispatcher::singleton() \
.findSchemaOrThrow("functorch::"#new_blah"_hack", "") \
.typed<decltype(new_blah##_hack)>(); \
return op.call(self, size, dtype, layout, device, pin_memory); \
} \
static Tensor new_blah##_hack_impl( \
const Tensor& self, \
IntArrayRef size, \
c10::optional<ScalarType> dtype, \
c10::optional<Layout> layout, \
c10::optional<Device> device, \
c10::optional<bool> pin_memory \
) { \
auto layer = maybeCurrentDynamicLayer(); \
if (!layer.has_value()) { \
return self.new_blah(size, dtype, layout, device, pin_memory); \
} \
AutoNonVariableTypeMode dispatch_after_grad_guard; \
c10::impl::ExcludeDispatchKeyGuard dispatch_after_vmap_guard(kBatchedKey); \
return new_blah##_hack(self, size, dtype, layout, device, pin_memory); \
}
// Instantiate the wrappers for the two factory ops this hack covers.
NEW_BLAH_HACK(new_zeros);
NEW_BLAH_HACK(new_empty);
#undef NEW_BLAH_HACK
// Define the custom "functorch::*_hack" ops and bind their implementations.
TORCH_LIBRARY(functorch, m) {
m.def("new_zeros_hack", new_zeros_hack_impl);
m.def("new_empty_hack", new_empty_hack_impl);
}
// At the DynamicLayerFront key, reroute the real aten factory ops into the
// hack wrappers defined above.
TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
m.impl("new_zeros", new_zeros_hack);
m.impl("new_empty", new_empty_hack);
}
}}

View File

@ -30,9 +30,9 @@ def get_extensions():
define_macros = []
extra_link_args = []
# extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
# if int(os.environ.get("DEBUG", 0)):
if True:
extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
if int(os.environ.get("DEBUG", 0)):
# if True:
extra_compile_args = {
"cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
extra_link_args = ["-O0", "-g"]

View File

@ -254,14 +254,29 @@ class TestVmapOfGrad(TestCase):
N = 3
C = 5
def foo(x, y):
def foo(y, x):
result = x.new_zeros((C,))
result.copy_(y)
return result.sum()
x = torch.randn(N, device=device)
y = torch.randn(N, C, device=device)
result = vmap(grad(foo))(x, y)
result = vmap(grad(foo))(y, x)
self.assertEqual(result, torch.ones_like(y))
def test_new_empty_materializes_tensor(self, device):
    # Verifies that x.new_empty() under vmap(grad(...)) yields a real,
    # materialized tensor: it can be copy_'d into and summed, and the
    # gradient w.r.t. y comes out as all-ones (d(sum(y))/dy == 1).
    # NOTE(review): indentation restored here — the diff view had
    # flattened it; confirm against the original file.
    N = 3
    C = 5

    def foo(y, x):
        # copy_ makes the (uninitialized) new_empty result depend on y,
        # so the gradient flows back to y regardless of initial contents.
        result = x.new_empty((C,))
        result.copy_(y)
        return result.sum()

    x = torch.randn(N, device=device)
    y = torch.randn(N, C, device=device)
    # grad is taken w.r.t. foo's first argument (y) — presumably the
    # default argnums=0; vmap maps over the leading N dimension.
    result = vmap(grad(foo))(y, x)
    self.assertEqual(result, torch.ones_like(y))
def test_per_sample_grads_simple(self, device):
def compute_loss(weight, x, t):