[functorch] Added batching rule for logsumexp

This commit is contained in:
Horace He
2021-06-27 01:03:39 -07:00
committed by Jon Janzen
parent 9c138786b7
commit b20e4decc4

View File

@ -309,6 +309,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("cumprod", SINGLE_ARG(reduction_dim_batch_rule<decltype(&ATEN_FN(cumprod)), &at::cumprod, optional<ScalarType>>));
VMAP_SUPPORT("cumsum", SINGLE_ARG(reduction_dim_batch_rule<decltype(&ATEN_FN(cumsum)), &at::cumsum, optional<ScalarType>>));
VMAP_SUPPORT("log_softmax.int", SINGLE_ARG(reduction_dim_batch_rule<decltype(&ATEN_FN2(log_softmax, int)), &at::log_softmax, optional<ScalarType>>));
VMAP_SUPPORT("logsumexp", SINGLE_ARG(reduction_dimarray_batch_rule<decltype(&ATEN_FN(logsumexp)), &at::logsumexp, bool>));
VMAP_SUPPORT("nansum", nansum_batch_rule);
VMAP_SUPPORT("nansum.dim_IntList", nansum_dim_batch_rule);
VMAP_SUPPORT("max", SINGLE_ARG(reduction_no_dim_batch_rule<decltype(&ATEN_FN(max)), &at::max, decltype(&max_dim_batch_rule), &max_dim_batch_rule>));