mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[special] migrate log1p, sinc, round to special namespace (#55878)
Summary: Reference : https://github.com/pytorch/pytorch/issues/50345 Pull Request resolved: https://github.com/pytorch/pytorch/pull/55878 Reviewed By: zou3519, janeyx99 Differential Revision: D29160593 Pulled By: mruberry fbshipit-source-id: f3ca9c541382bab33fb85d7817ce8ddc117c6826
This commit is contained in:
committed by
Facebook GitHub Bot
parent
769c299dcf
commit
01e0296eb7
@ -431,7 +431,6 @@ _(aten, linear) \
|
||||
_(aten, linspace) \
|
||||
_(aten, log) \
|
||||
_(aten, log10) \
|
||||
_(aten, log1p) \
|
||||
_(aten, log2) \
|
||||
_(aten, log_normal) \
|
||||
_(aten, log_sigmoid) \
|
||||
@ -621,7 +620,6 @@ _(aten, rnn_relu_cell) \
|
||||
_(aten, rnn_tanh) \
|
||||
_(aten, rnn_tanh_cell) \
|
||||
_(aten, rot90) \
|
||||
_(aten, round) \
|
||||
_(aten, rrelu) \
|
||||
_(aten, rrelu_with_noise) \
|
||||
_(aten, rrelu_with_noise_backward) \
|
||||
@ -638,7 +636,6 @@ _(aten, signbit) \
|
||||
_(aten, silu) \
|
||||
_(aten, sgn) \
|
||||
_(aten, sin) \
|
||||
_(aten, sinc) \
|
||||
_(aten, sinh) \
|
||||
_(aten, size) \
|
||||
_(aten, sizes) \
|
||||
|
@ -339,6 +339,12 @@ namespace c10 {
|
||||
_(aten, special_expm1) \
|
||||
_(aten, exp2) \
|
||||
_(aten, special_exp2) \
|
||||
_(aten, log1p) \
|
||||
_(aten, special_log1p) \
|
||||
_(aten, round) \
|
||||
_(aten, special_round) \
|
||||
_(aten, sinc) \
|
||||
_(aten, special_sinc) \
|
||||
_(aten, i0) \
|
||||
_(aten, special_i0) \
|
||||
_(aten, special_i0e) \
|
||||
|
@ -454,6 +454,18 @@ Tensor special_digamma(const Tensor& self) { return self.digamma(); }
|
||||
Tensor& special_i0_out(const Tensor& self, Tensor& result) { return at::i0_out(result, self); }
|
||||
Tensor special_i0(const Tensor& self) { return self.i0(); }
|
||||
|
||||
// special_log1p, alias for log1p
|
||||
Tensor& special_log1p_out(const Tensor& self, Tensor& result) { return at::log1p_out(result, self); }
|
||||
Tensor special_log1p(const Tensor& self) { return self.log1p(); }
|
||||
|
||||
// special_round, alias for round
|
||||
Tensor& special_round_out(const Tensor& self, Tensor& result) { return at::round_out(result, self); }
|
||||
Tensor special_round(const Tensor& self) { return self.round(); }
|
||||
|
||||
// special_sinc, alias for sinc
|
||||
Tensor& special_sinc_out(const Tensor& self, Tensor& result) { return at::sinc_out(result, self); }
|
||||
Tensor special_sinc(const Tensor& self) { return self.sinc(); }
|
||||
|
||||
namespace {
|
||||
|
||||
inline Tensor calc_ndtr(const Tensor& self) {
|
||||
|
@ -9652,6 +9652,30 @@
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_sinc(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_round(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_log1p(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
## Functions related to the fast Fourier transform and the torch.fft namespace
|
||||
# Note [FFT namespace binding]
|
||||
# Functions in the fft python module should have their names start with
|
||||
|
@ -33,6 +33,9 @@ Functions
|
||||
.. autofunction:: i1
|
||||
.. autofunction:: i1e
|
||||
.. autofunction:: logit
|
||||
.. autofunction:: log1p
|
||||
.. autofunction:: ndtr
|
||||
.. autofunction:: ndtri
|
||||
.. autofunction:: round
|
||||
.. autofunction:: sinc
|
||||
.. autofunction:: xlog1py
|
||||
|
@ -8072,29 +8072,8 @@ add_docstr(torch.sinc,
|
||||
r"""
|
||||
sinc(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the normalized sinc of :attr:`input.`
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} =
|
||||
\begin{cases}
|
||||
1, & \text{if}\ \text{input}_{i}=0 \\
|
||||
\sin(\pi \text{input}_{i}) / (\pi \text{input}_{i}), & \text{otherwise}
|
||||
\end{cases}
|
||||
""" + r"""
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.randn(4)
|
||||
>>> a
|
||||
tensor([ 0.2252, -0.2948, 1.0267, -1.1566])
|
||||
>>> torch.sinc(a)
|
||||
tensor([ 0.9186, 0.8631, -0.0259, -0.1300])
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.sinc`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.sinh,
|
||||
r"""
|
||||
|
@ -303,4 +303,52 @@ inline Tensor& i1e_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_i1e_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the sinc of input, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.sinc.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::sinc(t);
|
||||
/// ```
|
||||
inline Tensor sinc(const Tensor& self) {
|
||||
return torch::special_sinc(self);
|
||||
}
|
||||
|
||||
inline Tensor& sinc_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_sinc_out(result, self);
|
||||
}
|
||||
|
||||
/// Rounds the elements of the input
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.round.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::round(t);
|
||||
/// ```
|
||||
inline Tensor round(const Tensor& self) {
|
||||
return torch::special_round(self);
|
||||
}
|
||||
|
||||
inline Tensor& round_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_round_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes log(1 + x) of the input, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.log1p.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::log1p(t);
|
||||
/// ```
|
||||
inline Tensor log1p(const Tensor& self) {
|
||||
return torch::special_log1p(self);
|
||||
}
|
||||
|
||||
inline Tensor& log1p_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_log1p_out(result, self);
|
||||
}
|
||||
|
||||
}} // torch::special
|
||||
|
@ -117,6 +117,9 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
|
||||
{aten::special_exp2, aten::exp2},
|
||||
{aten::special_expm1, aten::expm1},
|
||||
{aten::special_logit, aten::logit},
|
||||
{aten::special_round, aten::round},
|
||||
{aten::special_log1p, aten::log1p},
|
||||
{aten::special_sinc, aten::sinc},
|
||||
{aten::special_digamma, aten::digamma},
|
||||
{aten::special_psi, aten::digamma},
|
||||
{aten::special_i0, aten::i0},
|
||||
|
@ -882,6 +882,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.special.i1: lambda input: -1,
|
||||
torch.special.i1e: lambda input: -1,
|
||||
torch.special.logit: lambda input: -1,
|
||||
torch.special.log1p: lambda input: -1,
|
||||
torch.special.round: lambda input: -1,
|
||||
torch.special.sinc: lambda input: -1,
|
||||
torch.special.ndtri: lambda input: -1,
|
||||
torch.special.ndtr: lambda input: -1,
|
||||
torch.special.xlog1py: lambda input, other, out=None: -1,
|
||||
|
@ -438,3 +438,45 @@ Example::
|
||||
>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
|
||||
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
|
||||
""".format(**common_args))
|
||||
|
||||
log1p = _add_docstr(_special.special_log1p,
|
||||
r"""
|
||||
log1p(input, *, out=None) -> Tensor
|
||||
|
||||
Alias for :func:`torch.log1p`.
|
||||
""")
|
||||
|
||||
sinc = _add_docstr(_special.special_sinc,
|
||||
r"""
|
||||
sinc(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the normalized sinc of :attr:`input.`
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} =
|
||||
\begin{cases}
|
||||
1, & \text{if}\ \text{input}_{i}=0 \\
|
||||
\sin(\pi \text{input}_{i}) / (\pi \text{input}_{i}), & \text{otherwise}
|
||||
\end{cases}
|
||||
""" + r"""
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
>>> t = torch.randn(4)
|
||||
>>> t
|
||||
tensor([ 0.2252, -0.2948, 1.0267, -1.1566])
|
||||
>>> torch.special.sinc(t)
|
||||
tensor([ 0.9186, 0.8631, -0.0259, -0.1300])
|
||||
""".format(**common_args))
|
||||
|
||||
round = _add_docstr(_special.special_round,
|
||||
r"""
|
||||
round(input, *, out=None) -> Tensor
|
||||
|
||||
Alias for :func:`torch.round`.
|
||||
""")
|
||||
|
@ -5819,6 +5819,7 @@ op_db: List[OpInfo] = [
|
||||
)),
|
||||
UnaryUfuncInfo('log1p',
|
||||
ref=np.log1p,
|
||||
aliases=('special.log1p',),
|
||||
domain=(-1, float('inf')),
|
||||
dtypes=all_types_and(torch.bool, torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
|
||||
@ -6290,6 +6291,7 @@ op_db: List[OpInfo] = [
|
||||
sample_inputs_func=sample_inputs_rot90),
|
||||
UnaryUfuncInfo('round',
|
||||
ref=np.round,
|
||||
aliases=('special.round',),
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
|
||||
assert_autodiffed=True,),
|
||||
@ -6304,6 +6306,7 @@ op_db: List[OpInfo] = [
|
||||
decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
|
||||
UnaryUfuncInfo('sinc',
|
||||
ref=np_sinc_with_fp16_as_fp32,
|
||||
aliases=('special.sinc',),
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
|
||||
handles_large_floats=False,
|
||||
|
Reference in New Issue
Block a user