[functorch] batch rules for torch.special unary ops

Richard Zou
2021-07-20 17:16:39 -07:00
committed by Jon Janzen
parent 1b78cae7b6
commit 983a43cfc9


@@ -83,6 +83,17 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
UNARY_POINTWISE_ALL(tan);
UNARY_POINTWISE_ALL(trunc);
// torch.special.* functions
UNARY_POINTWISE(special_entr);
UNARY_POINTWISE(special_erf);
UNARY_POINTWISE(special_erfc);
UNARY_POINTWISE(special_erfinv);
UNARY_POINTWISE(special_expit);
UNARY_POINTWISE(special_expm1);
UNARY_POINTWISE(special_exp2);
UNARY_POINTWISE(special_gammaln);
UNARY_POINTWISE(special_i0e);
// Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
UNARY_POINTWISE_SCALAR_SCALAR_SCALAR(elu);
UNARY_POINTWISE_SCALAR(hardshrink);
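
For context, every torch.special op registered above through UNARY_POINTWISE is elementwise, so its batch rule can simply apply the op to the batched (physical) tensor and propagate the batch dimension unchanged. Below is a minimal sketch of such a rule, assuming the (tensor, optional batch dim) convention used by functorch's batch rules; the helper name and template are illustrative only, not the actual macro expansion.

#include <ATen/ATen.h>
#include <tuple>

// Illustrative sketch of a unary pointwise batch rule: because the op acts
// elementwise, the physical (batched) tensor can be passed straight through
// and the batch dimension does not move.
template <typename F>
std::tuple<at::Tensor, c10::optional<int64_t>> unary_pointwise_batch_rule(
    F op, const at::Tensor& self, c10::optional<int64_t> self_bdim) {
  return std::make_tuple(op(self), self_bdim);
}

// Example (hypothetical): the rule for special_erf would reduce to
//   unary_pointwise_batch_rule(at::special_erf, self, self_bdim);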