mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] batch rules for torch.special unary ops
This commit is contained in:
@ -83,6 +83,17 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
UNARY_POINTWISE_ALL(tan);
|
||||
UNARY_POINTWISE_ALL(trunc);
|
||||
|
||||
// torch.special.* functions
|
||||
UNARY_POINTWISE(special_entr);
|
||||
UNARY_POINTWISE(special_erf);
|
||||
UNARY_POINTWISE(special_erfc);
|
||||
UNARY_POINTWISE(special_erfinv);
|
||||
UNARY_POINTWISE(special_expit);
|
||||
UNARY_POINTWISE(special_expm1);
|
||||
UNARY_POINTWISE(special_exp2);
|
||||
UNARY_POINTWISE(special_gammaln);
|
||||
UNARY_POINTWISE(special_i0e);
|
||||
|
||||
// Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
|
||||
UNARY_POINTWISE_SCALAR_SCALAR_SCALAR(elu);
|
||||
UNARY_POINTWISE_SCALAR(hardshrink);
|
||||
|
Reference in New Issue
Block a user