mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[special] Alias for sigmoid and logit & follow-up (#54759)
Summary: Reference: https://github.com/pytorch/pytorch/issues/50345 Chages: * Alias for sigmoid and logit * Adds out variant for C++ API * Updates docs to link back to `special` documentation Pull Request resolved: https://github.com/pytorch/pytorch/pull/54759 Reviewed By: mrshenli Differential Revision: D27615208 Pulled By: mruberry fbshipit-source-id: 8bba908d1bea246e4aa9dbadb6951339af353556
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f4967d68f5
commit
902bf0bbbe
@ -435,7 +435,6 @@ _(aten, _log_softmax) \
|
||||
_(aten, _log_softmax_backward_data) \
|
||||
_(aten, logcumsumexp) \
|
||||
_(aten, logdet) \
|
||||
_(aten, logit) \
|
||||
_(aten, logspace) \
|
||||
_(aten, logsumexp) \
|
||||
_(aten, xlogy) \
|
||||
@ -620,7 +619,6 @@ _(aten, segment_reduce) \
|
||||
_(aten, select) \
|
||||
_(aten, selu) \
|
||||
_(aten, set) \
|
||||
_(aten, sigmoid) \
|
||||
_(aten, sign) \
|
||||
_(aten, signbit) \
|
||||
_(aten, silu) \
|
||||
|
@ -321,6 +321,10 @@ namespace c10 {
|
||||
_(aten, special_erfc) \
|
||||
_(aten, erfinv) \
|
||||
_(aten, special_erfinv) \
|
||||
_(aten, logit) \
|
||||
_(aten, special_logit) \
|
||||
_(aten, sigmoid) \
|
||||
_(aten, special_expit) \
|
||||
_(aten, expm1) \
|
||||
_(aten, special_expm1) \
|
||||
_(aten, exp2) \
|
||||
|
@ -474,6 +474,21 @@ Tensor& logit_(Tensor& self, c10::optional<double> eps) {
|
||||
return at::logit_out(self, self, eps);
|
||||
}
|
||||
|
||||
Tensor& special_logit_out(const Tensor& self, c10::optional<double> eps, Tensor& result) {
|
||||
return at::logit_out(result, self, eps);
|
||||
}
|
||||
Tensor special_logit(const Tensor& self, c10::optional<double> eps) {
|
||||
return self.logit(eps);
|
||||
}
|
||||
|
||||
// special_expit, alias for sigmoid
|
||||
Tensor& special_expit_out(const Tensor& self, Tensor& result) {
|
||||
return at::sigmoid_out(result, self);
|
||||
}
|
||||
Tensor special_expit(const Tensor& self) {
|
||||
return self.sigmoid();
|
||||
}
|
||||
|
||||
Tensor& nan_to_num_out(const Tensor& self,
|
||||
c10::optional<double> nan,
|
||||
c10::optional<double> pos_inf,
|
||||
|
@ -8382,6 +8382,21 @@
|
||||
- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
|
||||
- func: special_logit(Tensor self, float? eps=None) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
|
||||
- func: special_expit(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
## Functions related to the fast Fourier transform and the torch.fft namespace
|
||||
# Note [FFT namespace binding]
|
||||
# Functions in the fft python module should have their names start with
|
||||
|
@ -22,6 +22,8 @@ Functions
|
||||
.. autofunction:: erf
|
||||
.. autofunction:: erfc
|
||||
.. autofunction:: erfinv
|
||||
.. autofunction:: expit
|
||||
.. autofunction:: expm1
|
||||
.. autofunction:: exp2
|
||||
.. autofunction:: gammaln
|
||||
.. autofunction:: logit
|
||||
|
@ -2983,75 +2983,22 @@ add_docstr(torch.erf,
|
||||
r"""
|
||||
erf(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the error function of each element. The error function is defined as follows:
|
||||
|
||||
.. math::
|
||||
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt
|
||||
""" + r"""
|
||||
|
||||
.. note:: Alias for :func:`torch.special.erf`.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.erf(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 0.0000, -0.8427, 1.0000])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.erf`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.erfc,
|
||||
r"""
|
||||
erfc(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the complementary error function of each element of :attr:`input`.
|
||||
The complementary error function is defined as follows:
|
||||
|
||||
.. math::
|
||||
\mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt
|
||||
""" + r"""
|
||||
|
||||
.. note:: Alias for :func:`torch.special.erfc`.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.erfc(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 1.0000, 1.8427, 0.0000])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.erfc`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.erfinv,
|
||||
r"""
|
||||
erfinv(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the inverse error function of each element of :attr:`input`.
|
||||
The inverse error function is defined in the range :math:`(-1, 1)` as:
|
||||
|
||||
.. math::
|
||||
\mathrm{erfinv}(\mathrm{erf}(x)) = x
|
||||
""" + r"""
|
||||
|
||||
.. note:: Alias for :func:`torch.special.erfinv`.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
|
||||
tensor([ 0.0000, 0.4769, -inf])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.erfinv`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.exp,
|
||||
r"""
|
||||
@ -3079,52 +3026,15 @@ add_docstr(torch.exp2,
|
||||
r"""
|
||||
exp2(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the base two exponential function of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
y_{i} = 2^{x_{i}}
|
||||
|
||||
.. note:: Alias for :func:`torch.special.exp2`.
|
||||
""" + r"""
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
|
||||
tensor([ 1., 2., 8., 16.])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.exp2`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.expm1,
|
||||
r"""
|
||||
expm1(input, *, out=None) -> Tensor
|
||||
|
||||
Returns a new tensor with the exponential of the elements minus 1
|
||||
of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
y_{i} = e^{x_{i}} - 1
|
||||
|
||||
.. note:: This function provides greater precision than exp(x) - 1 for small values of x.
|
||||
|
||||
.. note:: Alias for :func:`torch.special.expm1`.
|
||||
""" + r"""
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.expm1(torch.tensor([0, math.log(2.)]))
|
||||
tensor([ 0., 1.])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.expm1`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.eye,
|
||||
r"""
|
||||
@ -7678,58 +7588,15 @@ Sets the number of threads used for interop parallelism
|
||||
add_docstr(torch.sigmoid, r"""
|
||||
sigmoid(input, *, out=None) -> Tensor
|
||||
|
||||
Returns a new tensor with the sigmoid of the elements of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}}
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.randn(4)
|
||||
>>> a
|
||||
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
|
||||
>>> torch.sigmoid(a)
|
||||
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.expit`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.logit,
|
||||
r"""
|
||||
logit(input, eps=None, *, out=None) -> Tensor
|
||||
|
||||
Returns a new tensor with the logit of the elements of :attr:`input`.
|
||||
:attr:`input` is clamped to [eps, 1 - eps] when eps is not None.
|
||||
When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN.
|
||||
|
||||
.. math::
|
||||
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
|
||||
z_{i} = \begin{cases}
|
||||
x_{i} & \text{if eps is None} \\
|
||||
\text{eps} & \text{if } x_{i} < \text{eps} \\
|
||||
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
|
||||
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
|
||||
\end{cases}
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
eps (float, optional): the epsilon for input clamp bound. Default: ``None``
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.rand(5)
|
||||
>>> a
|
||||
tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
|
||||
>>> torch.logit(a, eps=1e-6)
|
||||
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.logit`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.sign,
|
||||
r"""
|
||||
|
@ -17,6 +17,10 @@ inline Tensor gammaln(const Tensor& self) {
|
||||
return torch::special_gammaln(self);
|
||||
}
|
||||
|
||||
inline Tensor& gammaln_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_gammaln_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes entropy of input, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.entr.
|
||||
///
|
||||
@ -29,6 +33,10 @@ inline Tensor entr(const Tensor& self) {
|
||||
return torch::special_entr(self);
|
||||
}
|
||||
|
||||
inline Tensor& entr_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_entr_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the error function
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.erf.
|
||||
///
|
||||
@ -41,6 +49,10 @@ inline Tensor erf(const Tensor& self) {
|
||||
return torch::special_erf(self);
|
||||
}
|
||||
|
||||
inline Tensor& erf_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_erf_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the complementary error function
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.erfc.
|
||||
///
|
||||
@ -53,6 +65,10 @@ inline Tensor erfc(const Tensor& self) {
|
||||
return torch::special_erfc(self);
|
||||
}
|
||||
|
||||
inline Tensor& erfc_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_erfc_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the inverse error function
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.erfinv.
|
||||
///
|
||||
@ -65,6 +81,42 @@ inline Tensor erfinv(const Tensor& self) {
|
||||
return torch::special_erfinv(self);
|
||||
}
|
||||
|
||||
inline Tensor& erfinv_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_erfinv_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the logit of input, elementwise.
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.logit.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::logit(t);
|
||||
/// ```
|
||||
inline Tensor logit(const Tensor& self) {
|
||||
return torch::special_logit(self);
|
||||
}
|
||||
|
||||
inline Tensor& logit_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_logit_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the expit (also known as the logistic sigmoid function) of input, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.expit.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::expit(t);
|
||||
/// ```
|
||||
inline Tensor expit(const Tensor& self) {
|
||||
return torch::special_expit(self);
|
||||
}
|
||||
|
||||
inline Tensor& expit_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_expit_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the base two exponential function of :attr:`input`, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.exp2.
|
||||
///
|
||||
@ -77,6 +129,10 @@ inline Tensor exp2(const Tensor& self) {
|
||||
return torch::special_exp2(self);
|
||||
}
|
||||
|
||||
inline Tensor& exp2_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_exp2_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the exponential of the elements minus 1, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.expm1.
|
||||
///
|
||||
@ -89,4 +145,8 @@ inline Tensor expm1(const Tensor& self) {
|
||||
return torch::special_expm1(self);
|
||||
}
|
||||
|
||||
inline Tensor& expm1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_expm1_out(result, self);
|
||||
}
|
||||
|
||||
}} // torch::special
|
||||
|
@ -90,8 +90,10 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
|
||||
{aten::special_erf, aten::erf},
|
||||
{aten::special_erfc, aten::erfc},
|
||||
{aten::special_erfinv, aten::erfinv},
|
||||
{aten::special_expit, aten::sigmoid},
|
||||
{aten::special_exp2, aten::exp2},
|
||||
{aten::special_expm1, aten::expm1},
|
||||
{aten::special_logit, aten::logit},
|
||||
{aten::orgqr, aten::linalg_householder_product},
|
||||
{aten::special_gammaln, aten::lgamma}};
|
||||
return alias_map;
|
||||
|
@ -844,7 +844,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.special.erfinv: lambda input: -1,
|
||||
torch.special.exp2: lambda input: -1,
|
||||
torch.special.expm1: lambda input: -1,
|
||||
torch.special.expit: lambda input: -1,
|
||||
torch.special.gammaln: lambda input: -1,
|
||||
torch.special.logit: lambda input: -1,
|
||||
torch.t: lambda input: -1,
|
||||
torch.take: lambda input, index: -1,
|
||||
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
|
||||
|
@ -73,7 +73,7 @@ Keyword args:
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.erf(torch.tensor([0, -1., 10.]))
|
||||
>>> torch.special.erf(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 0.0000, -0.8427, 1.0000])
|
||||
""".format(**common_args))
|
||||
|
||||
@ -95,7 +95,7 @@ Keyword args:
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.erfc(torch.tensor([0, -1., 10.]))
|
||||
>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
|
||||
tensor([ 1.0000, 1.8427, 0.0000])
|
||||
""".format(**common_args))
|
||||
|
||||
@ -118,10 +118,67 @@ Keyword args:
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
|
||||
>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
|
||||
tensor([ 0.0000, 0.4769, -inf])
|
||||
""".format(**common_args))
|
||||
|
||||
logit = _add_docstr(_special.special_logit,
|
||||
r"""
|
||||
logit(input, eps=None, *, out=None) -> Tensor
|
||||
|
||||
Returns a new tensor with the logit of the elements of :attr:`input`.
|
||||
:attr:`input` is clamped to [eps, 1 - eps] when eps is not None.
|
||||
When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN.
|
||||
|
||||
.. math::
|
||||
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
|
||||
z_{i} = \begin{cases}
|
||||
x_{i} & \text{if eps is None} \\
|
||||
\text{eps} & \text{if } x_{i} < \text{eps} \\
|
||||
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
|
||||
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
|
||||
\end{cases}
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
eps (float, optional): the epsilon for input clamp bound. Default: ``None``
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.rand(5)
|
||||
>>> a
|
||||
tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
|
||||
>>> torch.special.logit(a, eps=1e-6)
|
||||
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
|
||||
""".format(**common_args))
|
||||
|
||||
expit = _add_docstr(_special.special_expit,
|
||||
r"""
|
||||
expit(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the expit (also known as the logistic sigmoid function) of the elements of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}}
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> t = torch.randn(4)
|
||||
>>> t
|
||||
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
|
||||
>>> torch.special.expit(t)
|
||||
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
|
||||
""".format(**common_args))
|
||||
|
||||
exp2 = _add_docstr(_special.special_exp2,
|
||||
r"""
|
||||
exp2(input, *, out=None) -> Tensor
|
||||
@ -151,8 +208,6 @@ expm1(input, *, out=None) -> Tensor
|
||||
Computes the exponential of the elements minus 1
|
||||
of :attr:`input`.
|
||||
|
||||
..
|
||||
|
||||
.. math::
|
||||
y_{i} = e^{x_{i}} - 1
|
||||
|
||||
@ -167,6 +222,6 @@ Keyword args:
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.expm1(torch.tensor([0, math.log(2.)]))
|
||||
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
|
||||
tensor([ 0., 1.])
|
||||
""".format(**common_args))
|
||||
|
@ -3905,6 +3905,7 @@ op_db: List[OpInfo] = [
|
||||
supports_out=False,
|
||||
sample_inputs_func=sample_inputs_trace),
|
||||
UnaryUfuncInfo('sigmoid',
|
||||
aliases=('special.expit', ),
|
||||
ref=reference_sigmoid if TEST_SCIPY else _NOTHING,
|
||||
decorators=(precisionOverride({torch.float16: 1e-2,
|
||||
torch.bfloat16: 1e-2}),),
|
||||
@ -4022,6 +4023,7 @@ op_db: List[OpInfo] = [
|
||||
UnaryUfuncInfo('logit',
|
||||
ref=scipy.special.logit if TEST_SCIPY else _NOTHING,
|
||||
domain=(0, 1),
|
||||
aliases=('special.logit', ),
|
||||
decorators=(precisionOverride({torch.bfloat16: 5e-1,
|
||||
torch.float16: 5e-1}),),
|
||||
dtypes=all_types_and(torch.half),
|
||||
|
Reference in New Issue
Block a user