mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Some more batch rules for pointwise ops
This commit is contained in:
@ -212,6 +212,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
|
||||
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
|
||||
|
||||
VMAP_SUPPORT("fmin", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmin)), &at::fmin>));
|
||||
VMAP_SUPPORT("fmax", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(fmax)), &at::fmax>));
|
||||
|
||||
OP_DECOMPOSE2(max, other);
|
||||
OP_DECOMPOSE2(min, other);
|
||||
|
@ -83,16 +83,33 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
UNARY_POINTWISE_ALL(tan);
|
||||
UNARY_POINTWISE_ALL(trunc);
|
||||
|
||||
// special-related
|
||||
UNARY_POINTWISE_ALL(i0);
|
||||
UNARY_POINTWISE_ALL(erfc);
|
||||
UNARY_POINTWISE_ALL(erfinv);
|
||||
UNARY_POINTWISE_ALL(exp2);
|
||||
|
||||
// torch.special.* functions
|
||||
UNARY_POINTWISE(special_entr);
|
||||
UNARY_POINTWISE(special_erf);
|
||||
UNARY_POINTWISE(special_erfc);
|
||||
UNARY_POINTWISE(special_erfcx);
|
||||
UNARY_POINTWISE(special_erfinv);
|
||||
UNARY_POINTWISE(special_expit);
|
||||
UNARY_POINTWISE(special_expm1);
|
||||
UNARY_POINTWISE(special_digamma);
|
||||
UNARY_POINTWISE(special_psi);
|
||||
UNARY_POINTWISE(special_exp2);
|
||||
UNARY_POINTWISE(special_gammaln);
|
||||
UNARY_POINTWISE(special_i0);
|
||||
UNARY_POINTWISE(special_i0e);
|
||||
UNARY_POINTWISE(special_i1);
|
||||
UNARY_POINTWISE(special_i1e);
|
||||
UNARY_POINTWISE(special_log1p);
|
||||
UNARY_POINTWISE(special_ndtr);
|
||||
UNARY_POINTWISE(special_ndtri);
|
||||
UNARY_POINTWISE(special_round);
|
||||
UNARY_POINTWISE(special_sinc);
|
||||
|
||||
// Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
|
||||
UNARY_POINTWISE_SCALAR_SCALAR_SCALAR(elu);
|
||||
|
Reference in New Issue
Block a user