From d17b144e6564f10f55af639fbf2dd82eaacdfa3e Mon Sep 17 00:00:00 2001
From: lezcano
Date: Tue, 20 Sep 2022 12:40:28 +0000
Subject: [PATCH] Adding multigammaln ref and fix arange (#85153)

Partially based on https://github.com/pytorch/pytorch/pull/83662.
I'll help land this one, as Rob does not work in the PyTorch project anymore.

I removed the data-dependent check for the args, as data dependencies are bad
for many reasons (and it was failing when the input has NaNs).

It also registers arange as a decomposition and fixes the naming of its args.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85153
Approved by: https://github.com/mruberry, https://github.com/ngimel
---
 aten/src/ATen/native/UnaryOps.cpp             |   3 +-
 test/test_meta.py                             |   6 -
 test/test_proxy_tensor.py                     |   3 -
 test/test_unary_ufuncs.py                     |  10 --
 torch/_decomp/decompositions.py               |   2 +-
 torch/_refs/__init__.py                       |  81 ++++--------
 torch/_refs/special/__init__.py               |  14 +++
 torch/special/__init__.py                     |   4 +-
 .../_internal/common_methods_invocations.py   | 116 ++++++++++--------
 9 files changed, 106 insertions(+), 133 deletions(-)

diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 160955a01350..07c158fc7348 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -723,8 +723,7 @@ constexpr double QUARTER = 0.25;
 }
 
 static inline void mvlgamma_check(const Tensor& self, int64_t p) {
-  TORCH_CHECK((self > HALF * (p - 1)).all().item(),
-              "All elements must be greater than (p-1)/2");
+  TORCH_CHECK(self.scalar_type() != kBool, "The input tensor may not be a boolean tensor.");
   TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
 }
 
diff --git a/test/test_meta.py b/test/test_meta.py
index cd912f67a22b..f4c5137e8968 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -423,7 +423,6 @@ meta_function_expected_failures = {
     torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.multinomial : {f64, bf16, f32},
-    torch.mvlgamma : {f64, i32, i64, u8, i16, bf16, i8, f32},
     torch.nn.functional.ctc_loss : {f64, f32},
     torch.nn.functional.gaussian_nll_loss : {f64, bf16, f32},
     torch.nn.functional.max_pool3d : {f64, f32},
@@ -543,7 +542,6 @@ meta_function_device_expected_failures['cuda'] = {
     torch.matrix_exp: {f16},  # aten::linalg_matrix_exp
     torch.median: {f16},  # aten::median, aten::median.dim_values
     torch.multinomial: {f16},  # aten::multinomial, aten::multinomial.out
-    torch.mvlgamma: {f16},  # aten::_local_scalar_dense, aten::mvlgamma.out
     torch.nn.functional.gaussian_nll_loss: {f16},  # aten::_local_scalar_dense
     torch.nn.functional.max_pool3d: {bf16, f16},  # aten::max_pool3d_with_indices
     torch.nn.functional.max_pool3d_with_indices: {bf16, f16},  # aten::max_pool3d_with_indices
@@ -687,8 +685,6 @@ meta_dispatch_expected_failures = {
     aten.multilabel_margin_loss_forward.default : {f32, f64},
     aten.multinomial.default : {bf16, f32, f64},
     aten.multinomial.out : {bf16, f32, f64},
-    aten.mvlgamma.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
-    aten.mvlgamma.out : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.nll_loss2d_forward.default : {bf16, f32, f64},
     aten.polar.default : {f32, f64},
     aten.rrelu_with_noise.default : {bf16, f32, f64},
@@ -745,8 +741,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
     aten.multilabel_margin_loss_forward.default: {bf16, f16},  # aten::multilabel_margin_loss_forward
     aten.multinomial.default: {f16},  # aten::multinomial
     aten.multinomial.out: {f16},  # aten::multinomial.out
-    aten.mvlgamma.default: {f16},  # aten::_local_scalar_dense
-    aten.mvlgamma.out: {f16},  # aten::mvlgamma.out
     aten.native_group_norm.default: {bf16, f16},
     aten.nll_loss2d_forward.default: {f16},  # aten::nll_loss2d_forward
     aten.ormqr.default: {f32, f64},  # aten::ormqr
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index bc837738aa9c..3e421514f1de 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -997,9 +997,6 @@ fake_tensor_failures = {
     # FakeTensor fallback doesn't work
     xfail('segment_reduce', 'lengths'),
     xfail('multinomial'),
-    xfail('mvlgamma', 'mvlgamma_p_1'),
-    xfail('mvlgamma', 'mvlgamma_p_3'),
-    xfail('mvlgamma', 'mvlgamma_p_5'),
     xfail('cholesky'),
     xfail('cholesky_inverse'),
     # ASAN failures due to divide by 0
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 43330b953a81..5a9bdb53ab6b 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -624,16 +624,6 @@ class TestUnaryUfuncs(TestCase):
         ):
             torch.frexp(input, out=(mantissa, exponent))
 
-    def test_mvlgamma_argcheck(self, device):
-        def run_test(d):
-            input = torch.linspace((d - 2) / 2, 10, 10, device=device)
-            torch.mvlgamma(input, d)
-
-        with self.assertRaisesRegex(
-            RuntimeError, r"All elements must be greater than \(p-1\)/2"
-        ):
-            run_test(3)
-
     def test_polygamma_neg(self, device):
         with self.assertRaisesRegex(
             RuntimeError, r"polygamma\(n, x\) does not support negative n\."
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 5adf402967d9..1d09c0b9adf1 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1579,7 +1579,7 @@ def index_add_(
         utils.is_weakly_lesser_type(type(alpha), python_type),
         lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
     )
-    tensor = torch._prims.mul(tensor, alpha)
+    tensor = tensor * alpha
     idx = (slice(None),) * dim + (index,)
     torch.ops.aten.index_put_(x, idx, tensor, accumulate=True)
     return x
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index ff4cf53903f1..a9b7b1de3756 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -610,6 +610,10 @@ def lgamma(a):
     return prims.lgamma(a)
 
 
+# alias
+mvlgamma = torch.special.multigammaln  # type: ignore[has-type]
+
+
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
 def log(a):
     return prims.log(a)
@@ -3707,37 +3711,6 @@ def empty_like(
     )
 
 
-@overload
-def arange(
-    end: NumberType,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    layout: torch.layout = torch.strided,
-    pin_memory: bool = False,
-    requires_grad: bool = False,
-) -> TensorLikeType:
-    pass
-
-
-@overload
-def arange(
-    start: NumberType,
-    end: NumberType,
-    step: NumberType = 1,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    layout: torch.layout = torch.strided,
-    pin_memory: bool = False,
-    requires_grad: bool = False,
-) -> TensorLikeType:
-    pass
-
-
-# See https://github.com/pytorch/pytorch/issues/82364
-# @register_decomposition(torch.ops.aten.arange)
-# @out_wrapper()
 @register_decomposition(
     [
         torch.ops.aten.arange.default,
         torch.ops.aten.arange.start,
         torch.ops.aten.arange.start_step,
     ]
 )
+@out_wrapper()
 def arange(
-    a: Optional[NumberType] = None,
-    b: Optional[NumberType] = None,
+    start: NumberType = 0,
+    end: Optional[NumberType] = None,
     step: NumberType = 1,
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
     layout: torch.layout = torch.strided,
     pin_memory: bool = False,
     requires_grad: bool = False,
 ) -> TensorLikeType:
-    assert (a is not None and b is not None) or (a is not None and b is None)
-    if a is not None and b is not None:
-        return prims.arange(
-            a,
-            b,
-            step,
-            dtype=dtype,
-            device=device,
-            # layout=layout,
-            # pin_memory=pin_memory,
-            requires_grad=requires_grad,
-        )
-    elif a is not None and b is None:
-        return prims.arange(
-            0,
-            a,
-            step,
-            dtype=dtype,
-            device=device,
-            # layout=layout,
-            # pin_memory=pin_memory,
-            requires_grad=requires_grad,
-        )
-    else:
-        raise AssertionError()
+    assert not pin_memory
+    assert layout == torch.strided
+    # Case: torch.arange(5)
+    if end is None:
+        end = start
+        start = 0
+    return prims.arange(
+        start,
+        end,
+        step,
+        dtype=dtype,
+        device=device,
+        # layout=layout,
+        # pin_memory=pin_memory,
+        requires_grad=requires_grad,
+    )
 
 
 @register_decomposition(torch.ops.aten.linspace)
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index 7eec32a17a5f..8aee704f2272 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -1,3 +1,4 @@
+import math
 from typing import Optional
 
 import torch
@@ -20,6 +21,7 @@ __all__ = [
     "i1",
     "i1e",
     "logit",
+    "multigammaln",
     "zeta",
 ]
 
@@ -60,6 +62,18 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
     return torch.log(torch.true_divide(self, torch.sub(1, self)))
 
 
+@register_decomposition(torch.ops.aten.mvlgamma)
+@out_wrapper()
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("a",),
+    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)
+def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
+    c = 0.25 * p * (p - 1) * math.log(math.pi)
+    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
+    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
+
+
 zeta = _make_elementwise_binary_reference(
     prims.zeta,  # type: ignore[has-type]
     type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
diff --git a/torch/special/__init__.py b/torch/special/__init__.py
index 224e262c1ef6..a25f0f7c0368 100644
--- a/torch/special/__init__.py
+++ b/torch/special/__init__.py
@@ -761,9 +761,9 @@ Computes the `multivariate log-gamma function
 .. math::
     \log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)
 
-where :math:`C = \log(\pi) \times \frac{p (p - 1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
+where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
 
-All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise an error would be thrown.
+All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefined.
""" + """ Args: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b285a8d248a8..4337dd3c5297 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -740,7 +740,10 @@ def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs): for start, end, step in samples: if start is None: assert step is None + # Pass end as positional arg yield SampleInput(end, kwargs={"dtype": dtype, "device": device}) + # (Similar to) calling torch.arange(end=3) + yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device}) elif step is None: yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device}) else: @@ -5670,25 +5673,20 @@ def skips_mvlgamma(skip_redundant=False): # To test reference numerics against multiple values of argument `p`, # we make multiple OpInfo entries with each entry corresponding to different value of p. # We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing. -# Class `MvlGammaInfo` already contains the basic information related to the operator, -# it only takes arguments like `domain`, `skips` and `sample_kwargs`, which -# differ between the entries. -class MvlGammaInfo(UnaryUfuncInfo): - def __init__(self, variant_test_name, domain, skips, sample_kwargs): - super(MvlGammaInfo, self).__init__( - 'mvlgamma', - ref=reference_mvlgamma if TEST_SCIPY else None, - aliases=('special.multigammaln',), - variant_test_name=variant_test_name, - domain=domain, - decorators=(precisionOverride({torch.float16: 5e-2}),), - dtypes=all_types_and(torch.bfloat16), - dtypesIfCUDA=all_types_and(torch.half), - sample_inputs_func=sample_inputs_mvlgamma, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - skips=skips, - sample_kwargs=sample_kwargs) +def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs): + return UnaryUfuncInfo('mvlgamma', + ref=reference_mvlgamma if TEST_SCIPY else None, + aliases=('special.multigammaln',), + variant_test_name=variant_test_name, + domain=domain, + decorators=(precisionOverride({torch.float16: 5e-2}),), + dtypes=all_types_and(torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.float16), + sample_inputs_func=sample_inputs_mvlgamma, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + skips=skips, + sample_kwargs=sample_kwargs) def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs): @@ -12133,35 +12131,36 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), ), sample_inputs_func=sample_inputs_mode,), - MvlGammaInfo(variant_test_name='mvlgamma_p_1', - domain=(1, None), - skips=skips_mvlgamma() + \ - (DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', - dtypes=(torch.float16, torch.int8)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', - dtypes=(torch.int8,)),), - sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), - MvlGammaInfo(variant_test_name='mvlgamma_p_3', - domain=(2, None), - skips=skips_mvlgamma(skip_redundant=True) + ( - DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', - 
dtypes=(torch.float16, torch.int8)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', - dtypes=(torch.int8,)), - ), - sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), - MvlGammaInfo(variant_test_name='mvlgamma_p_5', - domain=(3, None), - skips=skips_mvlgamma(skip_redundant=True) + ( - DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', - dtypes=(torch.float16, torch.int8)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', - dtypes=(torch.int8,)), - ), - sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1', + domain=(1, None), + skips=skips_mvlgamma() + ( + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.float16, torch.int8)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.int8,)), + ), + sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3', + domain=(2, None), + skips=skips_mvlgamma() + ( + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.float16, torch.int8)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.int8,)), + ), + sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), + make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5', + domain=(3, None), + skips=skips_mvlgamma() + ( + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.float16, torch.int8)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', + dtypes=(torch.int8,)), + ), + sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), BinaryUfuncInfo('ne', ref=np.not_equal, aliases=('not_equal',), @@ -16242,9 +16241,6 @@ python_ref_db = [ DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), - # See https://github.com/pytorch/pytorch/issues/82364 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), # Prims arange does not follow aten DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', @@ -16486,6 +16482,24 @@ python_ref_db = [ "_refs.lgamma", torch_opinfo_name="lgamma", ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_1", + supports_nvfuser=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + torch_opinfo_variant_name="mvlgamma_p_3", + supports_nvfuser=False, + ), + ElementwiseUnaryPythonRefInfo( + "_refs.special.multigammaln", + torch_opinfo_name="mvlgamma", + 
torch_opinfo_variant_name="mvlgamma_p_5", + supports_nvfuser=False, + ), ElementwiseUnaryPythonRefInfo( "_refs.log", torch_opinfo_name="log",
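Note on the new reference: _refs.special.multigammaln (added above in torch/_refs/special/__init__.py) relies on the standard identity log Γ_p(a) = (p (p - 1) / 4) · log(π) + Σ_{i=1}^{p} log Γ(a - (i - 1)/2). The snippet below is a minimal standalone sketch of that identity, checked against torch.mvlgamma; the helper name multigammaln_ref and the sample values are illustrative only and are not part of the patch.

# Illustrative sanity check (not part of the patch): verify the identity used by
# the multigammaln reference decomposition against the existing ATen kernel.
import math

import torch


def multigammaln_ref(a: torch.Tensor, p: int) -> torch.Tensor:
    # C = log(pi) * p * (p - 1) / 4
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    # Offsets (1 - p)/2, (2 - p)/2, ..., 0, i.e. -(i - 1)/2 for i = p, ..., 1
    b = 0.5 * torch.arange(1 - p, 1, dtype=a.dtype, device=a.device)
    return torch.lgamma(a.unsqueeze(-1) + b).sum(dim=-1) + c


p = 3
x = torch.tensor([1.5, 2.0, 7.25], dtype=torch.float64)  # all elements > (p - 1)/2
print(torch.allclose(multigammaln_ref(x, p), torch.mvlgamma(x, p)))  # expected: True

The same check passes for the other variants exercised by the new OpInfo entries (p = 1 and p = 5), as long as every element of the input stays above (p - 1)/2; below that threshold the docs now state the behavior is undefined rather than promising an error.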