[functorch] Added triu/tril
@@ -38,6 +38,12 @@ std::tuple<Tensor,optional<int64_t>> basic_unary_batch_rule(
   return std::make_tuple(Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim);
 }
+
+template <typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor,optional<int64_t>> variadic_bdims_batch_rule(const Tensor& self, optional<int64_t> self_bdim, ExtraArgs... extra_args) {
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  return std::make_tuple(Func(self_, std::forward<ExtraArgs>(extra_args)...), self_bdim.has_value() ? optional<int64_t>{0} : nullopt);
+}
 
 #define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
 #define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
 #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
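The new helper works because triu/tril only operate on the last two dimensions of their input: moving the vmapped dimension to the front turns it into an ordinary leading batch dimension that the op already handles. A minimal Python sketch of the same idea (the name variadic_bdims_rule and the torch-level framing are illustrative, not part of the commit):

import torch

def variadic_bdims_rule(func, self, self_bdim, *extra_args):
    # Move the vmapped dimension to the front; ops like triu/tril only look
    # at the trailing dimensions, so dim 0 acts as a plain batch dimension.
    if self_bdim is not None:
        self = self.movedim(self_bdim, 0)
    result = func(self, *extra_args)
    # The output carries its batch dimension at the front, if it had one.
    return result, (0 if self_bdim is not None else None)

# Example: a stack of three 4x4 matrices, vmapped over dim 0, diagonal=1.
out, bdim = variadic_bdims_rule(torch.triu, torch.randn(3, 4, 4), 0, 1)
assert out.shape == (3, 4, 4) and bdim == 0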
@@ -149,6 +149,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
   VMAP_SUPPORT("repeat", repeat_batch_rule);
   VMAP_SUPPORT("diag", diag_batch_rule);
+  VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
+  VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
   VMAP_SUPPORT("_unsafe_view", _unsafe_view_batch_rule);
 }
 
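With the rule registered under VMAP_SUPPORT, vmap(torch.triu) should now agree with a per-example loop. A quick check along those lines (assumes a functorch install; this snippet is not part of the commit):

import torch
from functorch import vmap

x = torch.randn(5, 3, 3)                      # a batch of five 3x3 matrices
batched = vmap(torch.triu, in_dims=(0,))(x)   # dispatches to the new triu batch rule
looped = torch.stack([torch.triu(m) for m in x])
assert torch.allclose(batched, looped)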
@@ -82,8 +82,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
   m.impl("random_.to", unsupportedRandomOp_<Tensor&, int64_t, optional<Generator>>);
   m.impl("random_", unsupportedRandomOp_<Tensor&, optional<Generator>>);
 
-  m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
-  m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
+  // m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
+  // m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>);
 
   m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, optional<MemoryFormat>>);
   m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, optional<MemoryFormat>>);
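Commenting out the rand_like/randn_like registrations removes the explicit unsupported-random-op error in FuncTorchVmapMode, so these calls now fall through and execute under vmap. Note that the test below only checks that they run, not that each batch element gets sensible randomness. A sketch of what this unblocks (again assuming functorch; not code from the commit):

import torch
from functorch import vmap

x = torch.randn(2, 3, 4)
out = vmap(torch.rand_like, in_dims=(0,))(x)  # previously rejected by FuncTorchVmapMode, now runs
assert out.shape == x.shape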
@@ -2901,6 +2901,16 @@ class TestVmapOperatorsOpInfo(TestCase):
             x[x > 0] = float('inf')
             test(self, op, (x,), in_dims=(0))
 
+    def test_foo_like(self, device):
+        test = functools.partial(_vmap_test, check_propagates_grad=False)
+
+        B, N, C, H, W = 2, 3, 24, 5, 7
+        for op in [torch.ones_like, torch.zeros_like, torch.randn_like, torch.rand_like]:
+            x = torch.randn(B, N, C, H, W)
+            # todo(chilli): test these better
+            # Not testing correctness, just that they run
+            vmap(op, in_dims=(0,))(x,)
+
     @unittest.expectedFailure
     def test_einsum(self, device):
         test = functools.partial(_vmap_test, check_propagates_grad=False)