[functorch] _unsafe_view batch rule

Author: Richard Zou
Date: 2021-06-23 09:19:31 -07:00
Committed by: Jon Janzen
Parent: bea0df36c2
Commit: 1666d90161
3 changed files with 32 additions and 3 deletions


@@ -139,12 +139,25 @@ std::tuple<Tensor,optional<int64_t>> diag_batch_rule(
   }
 }
 
+std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
+    const Tensor& self,
+    optional<int64_t> self_bdim,
+    IntArrayRef size) {
+  if (!self_bdim) {
+    return std::make_tuple(at::_unsafe_view(self, size), nullopt);
+  }
+  VmapDimVector view_size(size);
+  view_size.insert(view_size.begin() + *self_bdim, self.size(*self_bdim));
+  return std::make_tuple(at::_unsafe_view(self, view_size), self_bdim);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
   VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
   VMAP_SUPPORT("repeat", repeat_batch_rule);
   VMAP_SUPPORT("diag", diag_batch_rule);
+  VMAP_SUPPORT("_unsafe_view", _unsafe_view_batch_rule);
 }
 
 }}
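For intuition, the added rule mirrors a per-example reshape: when the input carries a batch dimension, the rule re-inserts the batch size into the requested view shape at the batch dimension's position, applies the view, and returns the batch dimension unchanged. A minimal Python sketch of the same logic (the helper name is hypothetical, and torch.reshape stands in for the non-public at::_unsafe_view):

    import torch

    def unsafe_view_batch_rule_sketch(self, self_bdim, size):
        # Unbatched input: defer to the plain op.
        if self_bdim is None:
            return torch.reshape(self, size), None
        # Re-insert the batch size into the requested shape at the
        # batch dimension's position, then view as usual.
        view_size = list(size)
        view_size.insert(self_bdim, self.size(self_bdim))
        return torch.reshape(self, view_size), self_bdim

    # A batch of four 2x3 matrices, each viewed as a 6-element vector:
    x = torch.randn(4, 2, 3)
    out, bdim = unsafe_view_batch_rule_sketch(x, 0, [6])
    assert out.shape == (4, 6) and bdim == 0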


@@ -254,9 +254,11 @@ static void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
 void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   auto& dynamicLayerStack = dynamicLayerStackAccessor();
-  // if (c10::show_dispatch_trace_enabled()) {
-  //   std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
-  // }
+#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
+  if (c10::show_dispatch_trace_enabled()) {
+    std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
+  }
+#endif
   if (dynamicLayerStack.size() == 0) {
     sanityCheckStack(op, stack);
     c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
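The HAS_TORCH_SHOW_DISPATCH_TRACE macro keeps this call out of builds against a libtorch that doesn't provide c10::show_dispatch_trace_enabled(), while letting the previously commented-out debug print compile where it is available. A rough sketch of how the trace might be observed (it assumes a build that defines the macro, and that PyTorch's TORCH_SHOW_DISPATCH_TRACE environment variable is what controls the flag; both are assumptions, not confirmed by this commit):

    import os
    os.environ["TORCH_SHOW_DISPATCH_TRACE"] = "1"  # set before torch is imported
    import torch
    from functorch import vmap

    # In a build compiled with HAS_TORCH_SHOW_DISPATCH_TRACE, each dispatched
    # op is traced, including the "DLS size" line printed by the fallback above.
    vmap(torch.sin)(torch.randn(3))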


@@ -23,6 +23,7 @@ from common_utils import (
 )
 import types
+import functorch
 from functorch import vmap, functional_init_with_buffers
 from functorch._C import reshape_dim_into, reshape_dim_outof
@@ -1448,6 +1449,19 @@ class TestVmapOperators(Namespace.TestVmapBase):
         test(vmap(get_op(0), in_dims=(0, 0)),
              (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))
 
+    def test_unsafe_view(self):
+        # Unsafe view isn't exposed, so we get at it via
+        # vmap(grad(matmul))
+        test = functools.partial(self._vmap_test, check_propagates_grad=False)
+        B = 2
+        x = torch.randn(B, 2, 3, 3)
+        y = torch.randn(B, 3, 3)
+
+        def baz(x, y):
+            return (x @ y).sum()
+
+        test(functorch.grad(baz), (x, y))
+
     def test_conj(self):
         op = torch.conj
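
Outside the test harness, the same code path can be exercised directly: as the test's comment notes, _unsafe_view isn't exposed, but matmul's backward pass calls at::_unsafe_view internally, so vmapping grad hits the new batch rule. A standalone sketch using only public functorch APIs:

    import torch
    from functorch import grad, vmap

    def baz(x, y):
        return (x @ y).sum()

    x = torch.randn(2, 2, 3, 3)
    y = torch.randn(2, 3, 3)

    # grad(baz) computes d(baz)/dx per example; the matmul backward
    # calls _unsafe_view internally, exercising the batch rule.
    out = vmap(grad(baz))(x, y)
    print(out.shape)  # torch.Size([2, 2, 3, 3])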