[special] Alias igamma, igammac to special.gammainc, special.gammaincc (#61902)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/50345

Also added the relevant OpInfo entries.

TODO:
* [x] Check rendered docs gammainc : https://docs-preview.pytorch.org/61902/special.html#torch.special.gammainc
* [x] Check rendered docs gammaincc: https://docs-preview.pytorch.org/61902/special.html#torch.special.gammaincc
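As a quick sanity check of the new aliases at the Python level, a minimal sketch (assuming the names registered in this PR; the aliases dispatch to the same kernels, so results match exactly):

import torch

a = torch.tensor([4.0])
x = torch.tensor([3.0, 4.0, 5.0])

# special.gammainc / special.gammaincc are pure aliases of igamma / igammac.
assert torch.equal(torch.special.gammainc(a, x), torch.igamma(a, x))
assert torch.equal(torch.special.gammaincc(a, x), torch.igammac(a, x))

# The regularized lower and upper incomplete gamma functions sum to 1.
print(torch.special.gammainc(a, x) + torch.special.gammaincc(a, x))  # tensor([1., 1., 1.])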

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61902

Reviewed By: ngimel

Differential Revision: D30761428

Pulled By: mruberry

fbshipit-source-id: 06a16432873357958d53364f12a4e91c29779d26
Author: kshitij12345
Date: 2021-09-07 15:24:08 -07:00
Committed by: Facebook GitHub Bot
Parent: b01d2d1d3e
Commit: 2c351c76e0
14 changed files with 247 additions and 102 deletions

View File

@@ -382,10 +382,6 @@ _(aten, hsplit) \
_(aten, hstack) \
_(aten, hypot) \
_(aten, i0_) \
_(aten, igamma) \
_(aten, igamma_) \
_(aten, igammac) \
_(aten, igammac_) \
_(aten, ifft) \
_(aten, index) \
_(aten, index_add) \

View File

@@ -371,6 +371,12 @@ namespace c10 {
_(aten, log_softmax) \
_(aten, special_log_softmax) \
_(aten, special_zeta) \
_(aten, igamma) \
_(aten, igamma_) \
_(aten, special_gammainc) \
_(aten, igammac) \
_(aten, igammac_) \
_(aten, special_gammaincc) \
_(aten, mvlgamma) \
_(aten, special_multigammaln) \
_(aten, has_torch_function) \

View File

@@ -387,6 +387,22 @@ Tensor& special_zeta_out(const Tensor& self, const Scalar& other, Tensor& result
return at::special_zeta_out(result, self, wrapped_scalar_tensor(other));
}
Tensor& special_gammainc_out(const Tensor& self, const Tensor& other, Tensor& result) {
return at::igamma_out(result, self, other);
}
Tensor special_gammainc(const Tensor& self, const Tensor& other) {
return at::igamma(self, other);
}
Tensor& special_gammaincc_out(const Tensor& self, const Tensor& other, Tensor& result) {
return at::igammac_out(result, self, other);
}
Tensor special_gammaincc(const Tensor& self, const Tensor& other) {
return at::igammac(self, other);
}
TORCH_IMPL_FUNC(atan2_out) (const Tensor& self, const Tensor& other, const Tensor& result) {
atan2_stub(device_type(), *this);
}

View File

@@ -10054,6 +10054,22 @@
python_module: special
variants: function
- func: special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
variants: function
- func: special_gammainc(Tensor self, Tensor other) -> Tensor
python_module: special
variants: function
- func: special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
variants: function
- func: special_gammaincc(Tensor self, Tensor other) -> Tensor
python_module: special
variants: function
- func: special_multigammaln(Tensor self, int p) -> Tensor
python_module: special
variants: function
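A brief usage sketch of the out= overloads declared above (a hedged example; the values correspond to the docs example, with the (1,)-shaped first argument broadcasting against the second):

import torch

a = torch.tensor([4.0])
x = torch.tensor([3.0, 4.0, 5.0])
out = torch.empty(3)

# special_gammainc.out writes into `out` and also returns it.
torch.special.gammainc(a, x, out=out)
print(out)  # tensor([0.3528, 0.5665, 0.7350])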

View File

@@ -23,6 +23,8 @@ Functions
.. autofunction:: expm1
.. autofunction:: exp2
.. autofunction:: gammaln
.. autofunction:: gammainc
.. autofunction:: gammaincc
.. autofunction:: polygamma
.. autofunction:: digamma
.. autofunction:: psi

View File

@@ -2912,22 +2912,6 @@ class TestAutograd(TestCase):
requires_grad=True)
gradcheck(torch.sinc, a)
def test_igamma(self):
# 1e-3 offset to avoid zeros
# NOTE: derivative for s is not implemented
s = (torch.rand(100, dtype=torch.double) + 1e-3)
x = (torch.rand(100, dtype=torch.double) + 1e-3).requires_grad_()
gradcheck(torch.igamma, (s, x))
gradgradcheck(torch.igamma, (s, x))
def test_igammac(self):
# 1e-3 offset to avoid zeros in s
# NOTE: derivative for s is not implemented
s = (torch.rand(100, dtype=torch.double) + 1e-3)
x = (torch.rand(100, dtype=torch.double)).requires_grad_()
gradcheck(torch.igammac, (s, x))
gradgradcheck(torch.igammac, (s, x))
def test_profiler(self):
x = torch.randn(10, 10)

View File

@@ -3062,6 +3062,8 @@ class TestOperatorSignatures(JitTestCase):
'expand_as',
'fill_',
'hstack',
'igamma',
'igammac',
'linalg.multi_dot',
'lu',
'norm',

View File

@@ -1465,6 +1465,8 @@ class TestNormalizeOperators(JitTestCase):
"expand_as",
"fill_",
"gradient",
"igamma",
"igammac",
"index_put",
"polygamma",
"special.polygamma",

View File

@@ -4336,92 +4336,15 @@ add_docstr(torch.igamma,
r"""
igamma(input, other, *, out=None) -> Tensor
Computes the regularized lower incomplete gamma function:
.. math::
\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt
where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
and at least one is strictly positive.
If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`.
:math:`\Gamma(\cdot)` in the equation above is the gamma function,
.. math::
\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
See :func:`torch.igammac` and :func:`torch.lgamma` for related functions.
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
and float inputs.
.. note::
The backward pass with respect to :attr:`input` is not yet supported.
Please open an issue on PyTorch's Github to request it.
""" + r"""
Args:
input (Tensor): the first non-negative input tensor
other (Tensor): the second non-negative input tensor
Keyword args:
{out}
Example::
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.igamma(a1, a2)
tensor([0.3528, 0.5665, 0.7350])
>>> b = torch.igamma(a1, a2) + torch.igammac(a1, a2)
tensor([1., 1., 1.])
""".format(**common_args))
Alias for :func:`torch.special.gammainc`.
""")
add_docstr(torch.igammac,
r"""
igammac(input, other, *, out=None) -> Tensor
Computes the regularized upper incomplete gamma function:
.. math::
\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt
where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
and at least one is strictly positive.
If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`.
:math:`\Gamma(\cdot)` in the equation above is the gamma function,
.. math::
\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
See :func:`torch.igamma` and :func:`torch.lgamma` for related functions.
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
and float inputs.
.. note::
The backward pass with respect to :attr:`input` is not yet supported.
Please open an issue on PyTorch's Github to request it.
""" + r"""
Args:
input (Tensor): the first non-negative input tensor
other (Tensor): the second non-negative input tensor
Keyword args:
{out}
Example::
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.igammac(a1, a2)
tensor([0.6472, 0.4335, 0.2650])
>>> b = torch.igamma(a1, a2) + torch.igammac(a1, a2)
tensor([1., 1., 1.])
""".format(**common_args))
Alias for :func:`torch.special.gammaincc`.
""")
add_docstr(torch.index_select,
r"""

View File

@@ -21,6 +21,40 @@ inline Tensor& gammaln_out(Tensor& result, const Tensor& self) {
return torch::special_gammaln_out(result, self);
}
/// Computes the regularized lower incomplete gamma function
/// See https://pytorch.org/docs/master/special.html#torch.special.gammainc.
///
/// Example:
/// ```
/// auto t = torch::randn(128, torch::kDouble);
/// auto s = torch::randn(128, torch::kDouble);
/// torch::special::gammainc(s, t);
/// ```
inline Tensor gammainc(const Tensor& self, const Tensor& other) {
return torch::special_gammainc(self, other);
}
inline Tensor& gammainc_out(Tensor& result, const Tensor& self, const Tensor& other) {
return torch::special_gammainc_out(result, self, other);
}
/// Computes the regularized upper incomplete gamma function
/// See https://pytorch.org/docs/master/special.html#torch.special.gammaincc.
///
/// Example:
/// ```
/// auto t = torch::randn(128, torch::kDouble);
/// auto s = torch::randn(128, torch::kDouble);
/// torch::special::gammaincc(s, t);
/// ```
inline Tensor gammaincc(const Tensor& self, const Tensor& other) {
return torch::special_gammaincc(self, other);
}
inline Tensor& gammaincc_out(Tensor& result, const Tensor& self, const Tensor& other) {
return torch::special_gammaincc_out(result, self, other);
}
/// Computes the multivariate log-gamma function with dimension `p`, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.multigammaln.
///

View File

@@ -131,6 +131,8 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
{aten::orgqr, aten::linalg_householder_product},
{aten::special_multigammaln, aten::mvlgamma},
{aten::special_polygamma, aten::polygamma},
{aten::special_gammainc, aten::igamma},
{aten::special_gammaincc, aten::igammac},
{aten::special_gammaln, aten::lgamma}};
return alias_map;
}
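A hedged sketch of how this alias map is exercised from TorchScript; whether the printed graph shows the canonical name depends on which normalization passes have already run:

import torch

@torch.jit.script
def f(a: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return torch.special.gammainc(a, x)

a = torch.rand(3) + 1e-3
x = torch.rand(3) + 1e-3
f(a, x)

# If the normalize-ops pass has run as part of the JIT's cleanup passes, the
# call shows up as the canonical aten::igamma rather than aten::special_gammainc.
print(f.graph)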

View File

@@ -921,6 +921,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.polygamma: lambda input, n, out=None: -1,
torch.special.digamma: lambda input: -1,
torch.special.psi: lambda input: -1,
torch.special.gammainc: lambda input, other, out=None: -1,
torch.special.gammaincc: lambda input, other, out=None: -1,
torch.special.gammaln: lambda input: -1,
torch.special.i0: lambda input: -1,
torch.special.i0e: lambda input: -1,
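The entries above tell torch.overrides the signatures of the new functions so they are tracked as overridable; a minimal sketch of the __torch_function__ dispatch this corresponds to (the class name here is illustrative only):

import torch

class TensorLike:
    # Not a torch.Tensor; defining __torch_function__ is enough for overridable
    # torch functions, including torch.special.gammainc, to dispatch here.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return "intercepted"

print(torch.special.gammainc(TensorLike(), TensorLike()))  # intercepted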

View File

@@ -5,7 +5,7 @@ from torch._torch_docs import common_args, multi_dim_common
__all__ = ['entr', 'psi', 'digamma', 'gammaln', 'polygamma', 'erf', 'erfc', 'erfinv',
'erfcx', 'logit', 'logsumexp', 'expit', 'exp2', 'expm1', 'xlog1py', 'xlogy',
'i0', 'i0e', 'i1', 'i1e', 'ndtr', 'ndtri', 'log1p', 'sinc', 'round', 'log_softmax',
'zeta', 'multigammaln']
'zeta', 'multigammaln', 'gammainc', 'gammaincc']
Tensor = torch.Tensor
@@ -679,3 +679,94 @@ Example::
tensor([[0.3928, 0.4007, 0.7586],
[1.0311, 0.3901, 0.5049]])
""".format(**common_args))
gammainc = _add_docstr(_special.special_gammainc,
r"""
gammainc(input, other, *, out=None) -> Tensor
Computes the regularized lower incomplete gamma function:
.. math::
\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt
where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
and at least one is strictly positive.
If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`.
:math:`\Gamma(\cdot)` in the equation above is the gamma function,
.. math::
\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
See :func:`torch.special.gammaincc` and :func:`torch.special.gammaln` for related functions.
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
and float inputs.
.. note::
The backward pass with respect to :attr:`input` is not yet supported.
Please open an issue on PyTorch's GitHub to request it.
""" + r"""
Args:
input (Tensor): the first non-negative input tensor
other (Tensor): the second non-negative input tensor
Keyword args:
{out}
Example::
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammainc(a1, a2)
tensor([0.3528, 0.5665, 0.7350])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
""".format(**common_args))
gammaincc = _add_docstr(_special.special_gammaincc,
r"""
gammaincc(input, other, *, out=None) -> Tensor
Computes the regularized upper incomplete gamma function:
.. math::
\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt
where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
and at least one is strictly positive.
If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`.
:math:`\Gamma(\cdot)` in the equation above is the gamma function,
.. math::
\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
See :func:`torch.special.gammainc` and :func:`torch.special.gammaln` for related functions.
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
and float inputs.
.. note::
The backward pass with respect to :attr:`input` is not yet supported.
Please open an issue on PyTorch's GitHub to request it.
""" + r"""
Args:
input (Tensor): the first non-negative input tensor
other (Tensor): the second non-negative input tensor
Keyword args:
{out}
Example::
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.6472, 0.4335, 0.2650])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
""".format(**common_args))

View File

@@ -2823,6 +2823,23 @@ def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs):
inputs.append(SampleInput(arg_a, args=(arg_b,)))
return inputs
def sample_inputs_igamma_igammac(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, low=1e-3)
cases = (((S, S), (S, S), False),
((S, S), (S, ), False),
((S, ), (S, S), True),
((), (), False))
def generator():
for shape, other_shape, broadcasts_input in cases:
yield SampleInput(make_arg(shape, requires_grad=requires_grad),
args=(make_arg(other_shape, requires_grad=False),),
broadcasts_input=broadcasts_input)
return list(generator())
def sample_inputs_dist(op_info, device, dtype, requires_grad):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S))
@@ -7577,6 +7594,58 @@ op_db: List[OpInfo] = [
# Topk is not raising a warning when the out is resized
SkipInfo('TestCommon', 'test_out'),
)),
# We have to add two OpInfo entries for each of `igamma` and `igammac`: the first is the
# standard entry, the second runs the gradcheck tests on the second argument.
OpInfo('igamma',
dtypes=floating_types_and(torch.bfloat16, torch.float16),
aliases=('torch.special.gammainc',),
dtypesIfCUDA=floating_types(),
supports_autograd=False,
sample_inputs_func=sample_inputs_igamma_igammac),
OpInfo('igamma',
variant_test_name='grad_other',
# Since the autograd formula is implemented only for `other`, and gradcheck
# only verifies the gradient of the SampleInput's `input`,
# we permute the arguments.
op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs),
dtypes=floating_types_and(torch.bfloat16, torch.float16),
backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types(),
backward_dtypesIfCUDA=floating_types(),
supports_inplace_autograd=False,
skips=(
# test does not work with passing lambda for op
SkipInfo('TestJit', 'test_variant_consistency_jit'),
# The test fails because we permute the arguments for the function variant,
# but not for the inplace or method variants.
SkipInfo('TestCommon', 'test_variant_consistency_eager'),
),
sample_inputs_func=sample_inputs_igamma_igammac),
OpInfo('igammac',
dtypes=floating_types_and(torch.bfloat16, torch.float16),
aliases=('torch.special.gammaincc',),
dtypesIfCUDA=floating_types(),
supports_autograd=False,
sample_inputs_func=sample_inputs_igamma_igammac),
OpInfo('igammac',
variant_test_name='grad_other',
# Since the autograd formula is implemented only for `other`, and gradcheck
# only verifies the gradient of the SampleInput's `input`,
# we permute the arguments.
op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs),
dtypes=floating_types_and(torch.bfloat16, torch.float16),
backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types(),
backward_dtypesIfCUDA=floating_types(),
supports_inplace_autograd=False,
skips=(
# test does not work with passing lambda for op
SkipInfo('TestJit', 'test_variant_consistency_jit'),
# The test fails because we permute the arguments for the function variant,
# but not for the inplace or method variants.
SkipInfo('TestCommon', 'test_variant_consistency_eager'),
),
sample_inputs_func=sample_inputs_igamma_igammac),
OpInfo('nn.functional.hardshrink',
aten_name="hardshrink",
dtypes=floating_types(),
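To make the grad_other rationale above concrete, a standalone sketch of the check those variants perform, mirroring the TestAutograd tests removed earlier (only the derivative with respect to the second argument is exercised, since the formula for the first is not implemented):

import torch
from torch.autograd import gradcheck, gradgradcheck

# The 1e-3 offset keeps the inputs strictly positive.
s = torch.rand(10, dtype=torch.double) + 1e-3                     # no grad: d/ds is not implemented
x = (torch.rand(10, dtype=torch.double) + 1e-3).requires_grad_()  # gradient checked w.r.t. x only

gradcheck(torch.igamma, (s, x))
gradgradcheck(torch.igamma, (s, x))
gradcheck(torch.igammac, (s, x))
gradgradcheck(torch.igammac, (s, x))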