[functorch] Some more batch rules for pointwise ops

This commit is contained in:
Richard Zou
2021-07-20 17:30:59 -07:00
committed by Jon Janzen
parent 983a43cfc9
commit d506951937
2 changed files with 19 additions and 0 deletions

View File

@ -212,6 +212,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
VMAP_SUPPORT("fmin", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmin)), &at::fmin>));
VMAP_SUPPORT("fmax", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmax)), &at::fmax>));
OP_DECOMPOSE2(max, other);
OP_DECOMPOSE2(min, other);

View File

@ -83,16 +83,33 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
UNARY_POINTWISE_ALL(tan);
UNARY_POINTWISE_ALL(trunc);
// special-related
UNARY_POINTWISE_ALL(i0);
UNARY_POINTWISE_ALL(erfc);
UNARY_POINTWISE_ALL(erfinv);
UNARY_POINTWISE_ALL(exp2);
// torch.special.* functions
UNARY_POINTWISE(special_entr);
UNARY_POINTWISE(special_erf);
UNARY_POINTWISE(special_erfc);
UNARY_POINTWISE(special_erfcx);
UNARY_POINTWISE(special_erfinv);
UNARY_POINTWISE(special_expit);
UNARY_POINTWISE(special_expm1);
UNARY_POINTWISE(special_digamma);
UNARY_POINTWISE(special_psi);
UNARY_POINTWISE(special_exp2);
UNARY_POINTWISE(special_gammaln);
UNARY_POINTWISE(special_i0);
UNARY_POINTWISE(special_i0e);
UNARY_POINTWISE(special_i1);
UNARY_POINTWISE(special_i1e);
UNARY_POINTWISE(special_log1p);
UNARY_POINTWISE(special_ndtr);
UNARY_POINTWISE(special_ndtri);
UNARY_POINTWISE(special_round);
UNARY_POINTWISE(special_sinc);
// Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
UNARY_POINTWISE_SCALAR_SCALAR_SCALAR(elu);