[functorch] Added triu/tril

This commit is contained in:
Horace He
2021-06-26 16:15:35 -07:00
committed by Jon Janzen
parent 0aedd9e8c1
commit b815b5bc6b
4 changed files with 20 additions and 2 deletions

View File

@ -38,6 +38,12 @@ std::tuple<Tensor,optional<int64_t>> basic_unary_batch_rule(
return std::make_tuple(Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim);
}
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> variadic_bdims_batch_rule(const Tensor& self, optional<int64_t> self_bdim, ExtraArgs... extra_args) {
auto self_ = moveBatchDimToFront(self, self_bdim);
return std::make_tuple(Func(self_, std::forward<ExtraArgs>(extra_args)...), self_bdim.has_value() ? optional<int64_t>{0} : nullopt);
}
#define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));

View File

@ -149,6 +149,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
VMAP_SUPPORT("repeat", repeat_batch_rule);
VMAP_SUPPORT("diag", diag_batch_rule);
VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
VMAP_SUPPORT("_unsafe_view", _unsafe_view_batch_rule);
}

View File

@ -82,8 +82,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
m.impl("random_.to", unsupportedRandomOp_<Tensor&, int64_t, optional<Generator>>);
m.impl("random_", unsupportedRandomOp_<Tensor&, optional<Generator>>);
m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
// m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
// m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, optional<MemoryFormat>>);
m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, optional<MemoryFormat>>);

View File

@ -2901,6 +2901,16 @@ class TestVmapOperatorsOpInfo(TestCase):
x[x > 0] = float('inf')
test(self, op, (x,), in_dims=(0))
def test_foo_like(self, device):
test = functools.partial(_vmap_test, check_propagates_grad=False)
B, N, C, H, W = 2, 3, 24, 5, 7
for op in [torch.ones_like, torch.zeros_like, torch.randn_like, torch.rand_like]:
x = torch.randn(B, N, C, H, W)
# todo(chilli): test these better
# Not testing correctness, just that they run
vmap(op, in_dims=(0,))(x,)
@unittest.expectedFailure
def test_einsum(self, device):
test = functools.partial(_vmap_test, check_propagates_grad=False)