mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Added batching rule for logsumexp
This commit is contained in:
@ -309,6 +309,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("cumprod", SINGLE_ARG(reduction_dim_batch_rule<decltype(&ATEN_FN(cumprod)), &at::cumprod, optional<ScalarType>>));
|
||||
VMAP_SUPPORT("cumsum", SINGLE_ARG(reduction_dim_batch_rule<decltype(&ATEN_FN(cumsum)), &at::cumsum, optional<ScalarType>>));
|
||||
VMAP_SUPPORT("log_softmax.int", SINGLE_ARG(reduction_dim_batch_rule<decltype(&ATEN_FN2(log_softmax, int)), &at::log_softmax, optional<ScalarType>>));
|
||||
VMAP_SUPPORT("logsumexp", SINGLE_ARG(reduction_dimarray_batch_rule<decltype(&ATEN_FN(logsumexp)), &at::logsumexp, bool>));
|
||||
VMAP_SUPPORT("nansum", nansum_batch_rule);
|
||||
VMAP_SUPPORT("nansum.dim_IntList", nansum_dim_batch_rule);
|
||||
VMAP_SUPPORT("max", SINGLE_ARG(reduction_no_dim_batch_rule<decltype(&ATEN_FN(max)), &at::max, decltype(&max_dim_batch_rule), &max_dim_batch_rule>));
|
||||
|
Reference in New Issue
Block a user