Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[functorch] new_blah hack
@@ -1499,6 +1499,107 @@ std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
   return { std::move(result), std::move(result_batch_dim) };
 }
 
+Tensor matmul_decomposed(
+    const Tensor& tensor1,
+    const Tensor& tensor2) {
+  auto dim_tensor1 = tensor1.dim();
+  auto dim_tensor2 = tensor2.dim();
+
+  if (dim_tensor1 == 1 && dim_tensor2 == 1) {
+    return tensor1.dot(tensor2);
+  } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
+    return tensor1.mv(tensor2);
+  } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
+    return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
+  } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
+    return tensor1.mm(tensor2);
+  } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
+    // optimization: use mm instead of bmm by folding tensor1's batch into
+    // its leading matrix dimension.
+
+    Tensor t2 = dim_tensor2 == 1 ? tensor2.unsqueeze(-1) : tensor2;
+    auto size1 = tensor1.sizes();
+    auto size2 = t2.sizes();
+    std::vector<int64_t> output_size;
+    output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
+    if (dim_tensor2 > 1) {
+      output_size.push_back(size2[dim_tensor2 - 1]);
+    }
+
+    // fold the batch into the first dimension
+    Tensor t1 = tensor1.reshape({-1, size1[size1.size() - 1]});
+    Tensor output = t1.mm(t2).view(output_size);
+    return output;
+  } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
+    // optimization: transpose the inner dimensions of the arguments, call
+    // matmul on the swapped arguments, then transpose the inner dimensions
+    // of the result.
+    const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
+    const int64_t m = tensor1.size(-1);
+    const int64_t p = tensor2.size(-1);
+
+    const Tensor t2_T = tensor2.transpose(-1, -2);
+    const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
+    const Tensor res_T = at::matmul(t2_T, t1_T);
+
+    if (dim_tensor1 == 2) {
+      return res_T.transpose(-1, -2);
+    } else {
+      std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
+      shape.push_back(p);
+
+      Tensor res = res_T.reshape(shape);
+      return res;
+    }
+  } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
+    // We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be
+    // lists of batch dimensions); we track m1 vs m2 separately, even though
+    // they must match, for nicer error messages.
+    int64_t n = dim_tensor1 > 1 ? tensor1.size(-2) : 1;
+    int64_t m1 = tensor1.size(-1);
+    IntArrayRef batch_tensor1(tensor1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0));
+    int64_t m2 = dim_tensor2 > 1 ? tensor2.size(-2) : 1;
+    int64_t p = tensor2.size(-1);
+    IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0));
+
+    // expand the batch portion (i.e. cut off matrix dimensions and expand rest)
+    std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+
+    std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+    tensor1_expand_size.insert(tensor1_expand_size.end(), {n, m1});
+
+    std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+    tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});
+
+    const int64_t expand_batch_product =
+        c10::multiply_integers(expand_batch_portion);
+
+    std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
+    tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
+
+    std::vector<int64_t> tensor2_bmm_view({expand_batch_product});
+    tensor2_bmm_view.insert(tensor2_bmm_view.end(), {m2, p});
+
+    // flatten expanded batches
+    Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(tensor1_bmm_view);
+    Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(tensor2_bmm_view);
+
+    // reshape batches back into result
+    std::vector<int64_t> output_shape(expand_batch_portion);
+    if (dim_tensor1 > 1) {
+      output_shape.push_back(n);
+    }
+    if (dim_tensor2 > 1) {
+      output_shape.push_back(p);
+    }
+
+    Tensor output = tensor1_expanded.bmm(tensor2_expanded).view(output_shape);
+    return output;
+  }
+
+  AT_ERROR("both arguments to matmul need to be at least 1D, but they are ",
+           dim_tensor1, "D and ", dim_tensor2, "D");
+}
 
 TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
@@ -1574,6 +1675,7 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // m.impl("addmm", addmm_batching_rule);
+  m.impl("matmul", matmul_decomposed);
   //
   // clamp operations
   // m.impl("clamp", clamp_batching_rule);
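Decomposing matmul this way means the batching layer only needs rules for the primitives the decomposition bottoms out in (dot, mv, mm, bmm), not for matmul itself. As a sanity check of the batch-folding branch, here is a small standalone ATen sketch (illustrative only, not part of the commit; the shapes are made up):

#include <ATen/ATen.h>

int main() {
  // The dim_tensor1 >= 3, dim_tensor2 == 2 fast path: fold tensor1's batch
  // dims (4, 3) into its row dimension, run a single mm, then view back.
  at::Tensor a = at::randn({4, 3, 5, 6});
  at::Tensor b = at::randn({6, 7});
  at::Tensor folded = a.reshape({-1, 6}).mm(b).view({4, 3, 5, 7});
  // The folded result should agree with the composite op.
  TORCH_CHECK(at::allclose(folded, at::matmul(a, b)));
  return 0;
}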
functorch/functorch/csrc/new_blah_hacks.cpp (new file, 58 lines)
@@ -0,0 +1,58 @@
+#include <functorch/csrc/DynamicLayer.h>
+#include <functorch/csrc/Constants.h>
+
+#include <torch/library.h>
+#include <ATen/ATen.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace at { namespace functorch {
+
+#define NEW_BLAH_HACK(new_blah) \
+  static Tensor new_blah##_hack( \
+      const Tensor& self, \
+      IntArrayRef size, \
+      c10::optional<ScalarType> dtype, \
+      c10::optional<Layout> layout, \
+      c10::optional<Device> device, \
+      c10::optional<bool> pin_memory \
+      ) { \
+    static auto op = c10::Dispatcher::singleton() \
+      .findSchemaOrThrow("functorch::"#new_blah"_hack", "") \
+      .typed<decltype(new_blah##_hack)>(); \
+    return op.call(self, size, dtype, layout, device, pin_memory); \
+  } \
+  static Tensor new_blah##_hack_impl( \
+      const Tensor& self, \
+      IntArrayRef size, \
+      c10::optional<ScalarType> dtype, \
+      c10::optional<Layout> layout, \
+      c10::optional<Device> device, \
+      c10::optional<bool> pin_memory \
+      ) { \
+    auto layer = maybeCurrentDynamicLayer(); \
+    if (!layer.has_value()) { \
+      return self.new_blah(size, dtype, layout, device, pin_memory); \
+    } \
+    AutoNonVariableTypeMode dispatch_after_grad_guard; \
+    c10::impl::ExcludeDispatchKeyGuard dispatch_after_vmap_guard(kBatchedKey); \
+    return new_blah##_hack(self, size, dtype, layout, device, pin_memory); \
+  }
+
+NEW_BLAH_HACK(new_zeros);
+NEW_BLAH_HACK(new_empty);
+
+#undef NEW_BLAH_HACK
+
+TORCH_LIBRARY(functorch, m) {
+  m.def("new_zeros_hack", new_zeros_hack_impl);
+  m.def("new_empty_hack", new_empty_hack_impl);
+}
+
+TORCH_LIBRARY_IMPL(aten, DynamicLayerFront, m) {
+  m.impl("new_zeros", new_zeros_hack);
+  m.impl("new_empty", new_empty_hack);
+}
+
+}}
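To make the control flow easier to follow, here is roughly what NEW_BLAH_HACK(new_zeros) expands to (a hand-expanded sketch; the formatting and comments are mine, not part of the commit). new_zeros calls arriving at the DynamicLayerFront key are rerouted through the custom functorch::new_zeros_hack operator; its implementation either falls through to the plain factory method when no transform is active, or strips the grad/vmap dispatch keys so the new tensor is materialized as a regular, non-batched tensor:

// Hand-expanded sketch of NEW_BLAH_HACK(new_zeros) -- illustration only.
static Tensor new_zeros_hack(
    const Tensor& self, IntArrayRef size,
    c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
    c10::optional<Device> device, c10::optional<bool> pin_memory) {
  // Look up the custom functorch::new_zeros_hack schema and call it,
  // re-entering the dispatcher so functorch's layers see the call.
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("functorch::new_zeros_hack", "")
      .typed<decltype(new_zeros_hack)>();
  return op.call(self, size, dtype, layout, device, pin_memory);
}

static Tensor new_zeros_hack_impl(
    const Tensor& self, IntArrayRef size,
    c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
    c10::optional<Device> device, c10::optional<bool> pin_memory) {
  auto layer = maybeCurrentDynamicLayer();
  if (!layer.has_value()) {
    // No functorch transform is active: use the ordinary factory method.
    return self.new_zeros(size, dtype, layout, device, pin_memory);
  }
  // Exclude the autograd and vmap keys before re-dispatching, so the
  // result is a plain dense tensor rather than a wrapped/batched one.
  AutoNonVariableTypeMode dispatch_after_grad_guard;
  c10::impl::ExcludeDispatchKeyGuard dispatch_after_vmap_guard(kBatchedKey);
  return new_zeros_hack(self, size, dtype, layout, device, pin_memory);
}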
@@ -30,9 +30,9 @@ def get_extensions():
     define_macros = []
 
     extra_link_args = []
-    # extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
-    # if int(os.environ.get("DEBUG", 0)):
-    if True:
+    extra_compile_args = {"cxx": ["-O3", "-g", "-std=c++14"]}
+    if int(os.environ.get("DEBUG", 0)):
+    # if True:
         extra_compile_args = {
             "cxx": ["-O0", "-fno-inline", "-g", "-std=c++14"]}
         extra_link_args = ["-O0", "-g"]
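Net effect of this hunk: the extension builds with -O3 -g by default again, and the unoptimized -O0 -fno-inline -g debug flags apply only when the DEBUG environment variable is set (e.g. DEBUG=1 python setup.py develop, assuming the usual setuptools workflow), rather than being forced on by the temporary if True:.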
@@ -254,14 +254,29 @@ class TestVmapOfGrad(TestCase):
         N = 3
         C = 5
 
-        def foo(x, y):
+        def foo(y, x):
             result = x.new_zeros((C,))
             result.copy_(y)
             return result.sum()
 
         x = torch.randn(N, device=device)
         y = torch.randn(N, C, device=device)
-        result = vmap(grad(foo))(x, y)
+        result = vmap(grad(foo))(y, x)
         self.assertEqual(result, torch.ones_like(y))
 
+    def test_new_empty_materializes_tensor(self, device):
+        N = 3
+        C = 5
+
+        def foo(y, x):
+            result = x.new_empty((C,))
+            result.copy_(y)
+            return result.sum()
+
+        x = torch.randn(N, device=device)
+        y = torch.randn(N, C, device=device)
+        result = vmap(grad(foo))(y, x)
+        self.assertEqual(result, torch.ones_like(y))
+
     def test_per_sample_grads_simple(self, device):
         def compute_loss(weight, x, t):