[special] Alias for sigmoid and logit & follow-up (#54759)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

Chages:
* Alias for sigmoid and logit
* Adds out variant for C++ API
* Updates docs to link back to `special` documentation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54759

Reviewed By: mrshenli

Differential Revision: D27615208

Pulled By: mruberry

fbshipit-source-id: 8bba908d1bea246e4aa9dbadb6951339af353556
This commit is contained in:
kshitij12345
2021-04-08 00:55:39 -07:00
committed by Facebook GitHub Bot
parent f4967d68f5
commit 902bf0bbbe
11 changed files with 177 additions and 155 deletions

View File

@ -435,7 +435,6 @@ _(aten, _log_softmax) \
_(aten, _log_softmax_backward_data) \
_(aten, logcumsumexp) \
_(aten, logdet) \
_(aten, logit) \
_(aten, logspace) \
_(aten, logsumexp) \
_(aten, xlogy) \
@ -620,7 +619,6 @@ _(aten, segment_reduce) \
_(aten, select) \
_(aten, selu) \
_(aten, set) \
_(aten, sigmoid) \
_(aten, sign) \
_(aten, signbit) \
_(aten, silu) \

View File

@ -321,6 +321,10 @@ namespace c10 {
_(aten, special_erfc) \
_(aten, erfinv) \
_(aten, special_erfinv) \
_(aten, logit) \
_(aten, special_logit) \
_(aten, sigmoid) \
_(aten, special_expit) \
_(aten, expm1) \
_(aten, special_expm1) \
_(aten, exp2) \

View File

@ -474,6 +474,21 @@ Tensor& logit_(Tensor& self, c10::optional<double> eps) {
return at::logit_out(self, self, eps);
}
Tensor& special_logit_out(const Tensor& self, c10::optional<double> eps, Tensor& result) {
return at::logit_out(result, self, eps);
}
Tensor special_logit(const Tensor& self, c10::optional<double> eps) {
return self.logit(eps);
}
// special_expit, alias for sigmoid
Tensor& special_expit_out(const Tensor& self, Tensor& result) {
return at::sigmoid_out(result, self);
}
Tensor special_expit(const Tensor& self) {
return self.sigmoid();
}
Tensor& nan_to_num_out(const Tensor& self,
c10::optional<double> nan,
c10::optional<double> pos_inf,

View File

@ -8382,6 +8382,21 @@
- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
- func: special_logit(Tensor self, float? eps=None) -> Tensor
python_module: special
variants: function
- func: special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
- func: special_expit(Tensor self) -> Tensor
python_module: special
variants: function
- func: special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
variants: function
## Functions related to the fast Fourier transform and the torch.fft namespace
# Note [FFT namespace binding]
# Functions in the fft python module should have their names start with

View File

@ -22,6 +22,8 @@ Functions
.. autofunction:: erf
.. autofunction:: erfc
.. autofunction:: erfinv
.. autofunction:: expit
.. autofunction:: expm1
.. autofunction:: exp2
.. autofunction:: gammaln
.. autofunction:: logit

View File

@ -2983,75 +2983,22 @@ add_docstr(torch.erf,
r"""
erf(input, *, out=None) -> Tensor
Computes the error function of each element. The error function is defined as follows:
.. math::
\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt
""" + r"""
.. note:: Alias for :func:`torch.special.erf`.
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.erf(torch.tensor([0, -1., 10.]))
tensor([ 0.0000, -0.8427, 1.0000])
""".format(**common_args))
Alias for :func:`torch.special.erf`.
""")
add_docstr(torch.erfc,
r"""
erfc(input, *, out=None) -> Tensor
Computes the complementary error function of each element of :attr:`input`.
The complementary error function is defined as follows:
.. math::
\mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt
""" + r"""
.. note:: Alias for :func:`torch.special.erfc`.
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.erfc(torch.tensor([0, -1., 10.]))
tensor([ 1.0000, 1.8427, 0.0000])
""".format(**common_args))
Alias for :func:`torch.special.erfc`.
""")
add_docstr(torch.erfinv,
r"""
erfinv(input, *, out=None) -> Tensor
Computes the inverse error function of each element of :attr:`input`.
The inverse error function is defined in the range :math:`(-1, 1)` as:
.. math::
\mathrm{erfinv}(\mathrm{erf}(x)) = x
""" + r"""
.. note:: Alias for :func:`torch.special.erfinv`.
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([ 0.0000, 0.4769, -inf])
""".format(**common_args))
Alias for :func:`torch.special.erfinv`.
""")
add_docstr(torch.exp,
r"""
@ -3079,52 +3026,15 @@ add_docstr(torch.exp2,
r"""
exp2(input, *, out=None) -> Tensor
Computes the base two exponential function of :attr:`input`.
.. math::
y_{i} = 2^{x_{i}}
.. note:: Alias for :func:`torch.special.exp2`.
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
tensor([ 1., 2., 8., 16.])
""".format(**common_args))
Alias for :func:`torch.special.exp2`.
""")
add_docstr(torch.expm1,
r"""
expm1(input, *, out=None) -> Tensor
Returns a new tensor with the exponential of the elements minus 1
of :attr:`input`.
.. math::
y_{i} = e^{x_{i}} - 1
.. note:: This function provides greater precision than exp(x) - 1 for small values of x.
.. note:: Alias for :func:`torch.special.expm1`.
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> torch.expm1(torch.tensor([0, math.log(2.)]))
tensor([ 0., 1.])
""".format(**common_args))
Alias for :func:`torch.special.expm1`.
""")
add_docstr(torch.eye,
r"""
@ -7678,58 +7588,15 @@ Sets the number of threads used for interop parallelism
add_docstr(torch.sigmoid, r"""
sigmoid(input, *, out=None) -> Tensor
Returns a new tensor with the sigmoid of the elements of :attr:`input`.
.. math::
\text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}}
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> a = torch.randn(4)
>>> a
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.sigmoid(a)
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
""".format(**common_args))
Alias for :func:`torch.special.expit`.
""")
add_docstr(torch.logit,
r"""
logit(input, eps=None, *, out=None) -> Tensor
Returns a new tensor with the logit of the elements of :attr:`input`.
:attr:`input` is clamped to [eps, 1 - eps] when eps is not None.
When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN.
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} & \text{if eps is None} \\
\text{eps} & \text{if } x_{i} < \text{eps} \\
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
\end{cases}
""" + r"""
Args:
{input}
eps (float, optional): the epsilon for input clamp bound. Default: ``None``
Keyword args:
{out}
Example::
>>> a = torch.rand(5)
>>> a
tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
>>> torch.logit(a, eps=1e-6)
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
""".format(**common_args))
Alias for :func:`torch.special.logit`.
""")
add_docstr(torch.sign,
r"""

View File

@ -17,6 +17,10 @@ inline Tensor gammaln(const Tensor& self) {
return torch::special_gammaln(self);
}
inline Tensor& gammaln_out(Tensor& result, const Tensor& self) {
return torch::special_gammaln_out(result, self);
}
/// Computes entropy of input, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.entr.
///
@ -29,6 +33,10 @@ inline Tensor entr(const Tensor& self) {
return torch::special_entr(self);
}
inline Tensor& entr_out(Tensor& result, const Tensor& self) {
return torch::special_entr_out(result, self);
}
/// Computes the error function
/// See https://pytorch.org/docs/master/special.html#torch.special.erf.
///
@ -41,6 +49,10 @@ inline Tensor erf(const Tensor& self) {
return torch::special_erf(self);
}
inline Tensor& erf_out(Tensor& result, const Tensor& self) {
return torch::special_erf_out(result, self);
}
/// Computes the complementary error function
/// See https://pytorch.org/docs/master/special.html#torch.special.erfc.
///
@ -53,6 +65,10 @@ inline Tensor erfc(const Tensor& self) {
return torch::special_erfc(self);
}
inline Tensor& erfc_out(Tensor& result, const Tensor& self) {
return torch::special_erfc_out(result, self);
}
/// Computes the inverse error function
/// See https://pytorch.org/docs/master/special.html#torch.special.erfinv.
///
@ -65,6 +81,42 @@ inline Tensor erfinv(const Tensor& self) {
return torch::special_erfinv(self);
}
inline Tensor& erfinv_out(Tensor& result, const Tensor& self) {
return torch::special_erfinv_out(result, self);
}
/// Computes the logit of input, elementwise.
/// See https://pytorch.org/docs/master/special.html#torch.special.logit.
///
/// Example:
/// ```
/// auto t = torch::randn(128, dtype=kDouble);
/// torch::special::logit(t);
/// ```
inline Tensor logit(const Tensor& self) {
return torch::special_logit(self);
}
inline Tensor& logit_out(Tensor& result, const Tensor& self) {
return torch::special_logit_out(result, self);
}
/// Computes the expit (also known as the logistic sigmoid function) of input, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.expit.
///
/// Example:
/// ```
/// auto t = torch::randn(128, dtype=kDouble);
/// torch::special::expit(t);
/// ```
inline Tensor expit(const Tensor& self) {
return torch::special_expit(self);
}
inline Tensor& expit_out(Tensor& result, const Tensor& self) {
return torch::special_expit_out(result, self);
}
/// Computes the base two exponential function of :attr:`input`, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.exp2.
///
@ -77,6 +129,10 @@ inline Tensor exp2(const Tensor& self) {
return torch::special_exp2(self);
}
inline Tensor& exp2_out(Tensor& result, const Tensor& self) {
return torch::special_exp2_out(result, self);
}
/// Computes the exponential of the elements minus 1, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.expm1.
///
@ -89,4 +145,8 @@ inline Tensor expm1(const Tensor& self) {
return torch::special_expm1(self);
}
inline Tensor& expm1_out(Tensor& result, const Tensor& self) {
return torch::special_expm1_out(result, self);
}
}} // torch::special

View File

@ -90,8 +90,10 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
{aten::special_erf, aten::erf},
{aten::special_erfc, aten::erfc},
{aten::special_erfinv, aten::erfinv},
{aten::special_expit, aten::sigmoid},
{aten::special_exp2, aten::exp2},
{aten::special_expm1, aten::expm1},
{aten::special_logit, aten::logit},
{aten::orgqr, aten::linalg_householder_product},
{aten::special_gammaln, aten::lgamma}};
return alias_map;

View File

@ -844,7 +844,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.erfinv: lambda input: -1,
torch.special.exp2: lambda input: -1,
torch.special.expm1: lambda input: -1,
torch.special.expit: lambda input: -1,
torch.special.gammaln: lambda input: -1,
torch.special.logit: lambda input: -1,
torch.t: lambda input: -1,
torch.take: lambda input, index: -1,
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,

View File

@ -73,7 +73,7 @@ Keyword args:
Example::
>>> torch.erf(torch.tensor([0, -1., 10.]))
>>> torch.special.erf(torch.tensor([0, -1., 10.]))
tensor([ 0.0000, -0.8427, 1.0000])
""".format(**common_args))
@ -95,7 +95,7 @@ Keyword args:
Example::
>>> torch.erfc(torch.tensor([0, -1., 10.]))
>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
tensor([ 1.0000, 1.8427, 0.0000])
""".format(**common_args))
@ -118,10 +118,67 @@ Keyword args:
Example::
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([ 0.0000, 0.4769, -inf])
""".format(**common_args))
logit = _add_docstr(_special.special_logit,
r"""
logit(input, eps=None, *, out=None) -> Tensor
Returns a new tensor with the logit of the elements of :attr:`input`.
:attr:`input` is clamped to [eps, 1 - eps] when eps is not None.
When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN.
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} & \text{if eps is None} \\
\text{eps} & \text{if } x_{i} < \text{eps} \\
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
\end{cases}
""" + r"""
Args:
{input}
eps (float, optional): the epsilon for input clamp bound. Default: ``None``
Keyword args:
{out}
Example::
>>> a = torch.rand(5)
>>> a
tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
>>> torch.special.logit(a, eps=1e-6)
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
""".format(**common_args))
expit = _add_docstr(_special.special_expit,
r"""
expit(input, *, out=None) -> Tensor
Computes the expit (also known as the logistic sigmoid function) of the elements of :attr:`input`.
.. math::
\text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}}
""" + r"""
Args:
{input}
Keyword args:
{out}
Example::
>>> t = torch.randn(4)
>>> t
tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.special.expit(t)
tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
""".format(**common_args))
exp2 = _add_docstr(_special.special_exp2,
r"""
exp2(input, *, out=None) -> Tensor
@ -151,8 +208,6 @@ expm1(input, *, out=None) -> Tensor
Computes the exponential of the elements minus 1
of :attr:`input`.
..
.. math::
y_{i} = e^{x_{i}} - 1
@ -167,6 +222,6 @@ Keyword args:
Example::
>>> torch.expm1(torch.tensor([0, math.log(2.)]))
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
tensor([ 0., 1.])
""".format(**common_args))

View File

@ -3905,6 +3905,7 @@ op_db: List[OpInfo] = [
supports_out=False,
sample_inputs_func=sample_inputs_trace),
UnaryUfuncInfo('sigmoid',
aliases=('special.expit', ),
ref=reference_sigmoid if TEST_SCIPY else _NOTHING,
decorators=(precisionOverride({torch.float16: 1e-2,
torch.bfloat16: 1e-2}),),
@ -4022,6 +4023,7 @@ op_db: List[OpInfo] = [
UnaryUfuncInfo('logit',
ref=scipy.special.logit if TEST_SCIPY else _NOTHING,
domain=(0, 1),
aliases=('special.logit', ),
decorators=(precisionOverride({torch.bfloat16: 5e-1,
torch.float16: 5e-1}),),
dtypes=all_types_and(torch.half),