mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] _unsafe_view batch rule
This commit is contained in:
@ -139,12 +139,25 @@ std::tuple<Tensor,optional<int64_t>> diag_batch_rule(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
|
||||
const Tensor& self,
|
||||
optional<int64_t> self_bdim,
|
||||
IntArrayRef size) {
|
||||
if (!self_bdim) {
|
||||
return std::make_tuple(at::_unsafe_view(self, size), nullopt);
|
||||
}
|
||||
VmapDimVector view_size(size);
|
||||
view_size.insert(view_size.begin() + *self_bdim, self.size(*self_bdim));
|
||||
|
||||
return std::make_tuple(at::_unsafe_view(self, view_size), self_bdim);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
|
||||
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
|
||||
VMAP_SUPPORT("repeat", repeat_batch_rule);
|
||||
VMAP_SUPPORT("diag", diag_batch_rule);
|
||||
VMAP_SUPPORT("_unsafe_view", _unsafe_view_batch_rule);
|
||||
}
|
||||
|
||||
}}
|
||||
|
@ -254,9 +254,11 @@ static void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* s
|
||||
|
||||
void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
|
||||
auto& dynamicLayerStack = dynamicLayerStackAccessor();
|
||||
// if (c10::show_dispatch_trace_enabled()) {
|
||||
// std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
|
||||
// }
|
||||
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
|
||||
if (c10::show_dispatch_trace_enabled()) {
|
||||
std::cout << "DLS size: " << dynamicLayerStack.size() << std::endl;
|
||||
}
|
||||
#endif
|
||||
if (dynamicLayerStack.size() == 0) {
|
||||
sanityCheckStack(op, stack);
|
||||
c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
|
||||
|
@ -23,6 +23,7 @@ from common_utils import (
|
||||
)
|
||||
import types
|
||||
|
||||
import functorch
|
||||
from functorch import vmap, functional_init_with_buffers
|
||||
from functorch._C import reshape_dim_into, reshape_dim_outof
|
||||
|
||||
@ -1448,6 +1449,19 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
||||
test(vmap(get_op(0), in_dims=(0, 0)),
|
||||
(torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))
|
||||
|
||||
def test_unsafe_view(self):
|
||||
# Unsafe view isn't exposed, so we get at it via
|
||||
# vmap(grad(matmul))
|
||||
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
||||
B = 2
|
||||
x = torch.randn(B, 2, 3, 3)
|
||||
y = torch.randn(B, 3, 3)
|
||||
|
||||
def baz(x, y):
|
||||
return (x @ y).sum()
|
||||
|
||||
test(functorch.grad(baz), (x, y))
|
||||
|
||||
def test_conj(self):
|
||||
op = torch.conj
|
||||
|
||||
|
Reference in New Issue
Block a user