Files
pytorch/torch/testing/_internal/opinfo/definitions/special.py
Deng, Daisy 4fd70d4e7b [1/N]Enable some tests in test_ops.TestCommon on Intel GPU (#159944)
For https://github.com/pytorch/pytorch/issues/114850, we will port ATen unit tests to Intel GPU. This PR works on some test cases in test/test_ops.py. We enable Intel GPU with the following methods, trying our best to keep the original code style:

1. Extended XPUTestBase.get_all_devices to support multiple devices
2. Added skipXPU decorator
3. Extended onlyOn to support device list
4. Enabled 'xpu' for some test paths
5. Added allow_xpu=True for supported test classes.
6. Replaced onlyCUDA with onlyOn(['cuda', 'xpu']) for supported tests
7. Used skipIfXpu and skipXPU to disable unsupported tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159944
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/albanD
2025-09-29 09:08:04 +00:00

786 lines
26 KiB
Python

# mypy: ignore-errors
import unittest
from functools import partial
from itertools import product
import numpy as np
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
precisionOverride,
tol,
toleranceOverride,
)
from torch.testing._internal.common_dtype import all_types_and, floating_types
from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict
from torch.testing._internal.opinfo.core import (
BinaryUfuncInfo,
DecorateInfo,
L,
NumericsFilter,
OpInfo,
S,
SampleInput,
UnaryUfuncInfo,
)
from torch.testing._internal.opinfo.refs import (
ElementwiseBinaryPythonRefInfo,
ElementwiseUnaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
np_unary_ufunc_integer_promotion_wrapper,
)
if TEST_SCIPY:
import scipy.special
# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`,
# supports `exclude` argument.
# For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
exclude_zero = requires_grad and op_info.op == torch.special.i0e
make_arg = partial(
make_tensor,
dtype=dtype,
device=device,
requires_grad=requires_grad,
exclude_zero=exclude_zero,
)
yield SampleInput(make_arg((S,)))
yield SampleInput(make_arg(()))
if requires_grad and not exclude_zero:
# Special Case for gradient
# Sample with `0` in the input
t = make_arg((S,))
t[0] = 0
yield SampleInput(t)
def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(
make_tensor,
device=device,
# TODO: eliminate low after gh-106692 is fixed:
low=(1 if dtype in {torch.int32, torch.int64} else None),
dtype=dtype,
requires_grad=requires_grad,
)
tensor_shapes = ((S, S), ())
ns = (1, 2, 3, 4, 5)
for shape, n in product(tensor_shapes, ns):
yield SampleInput(make_arg(shape), args=(n,))
def reference_polygamma(x, n):
# WEIRD `scipy.special.polygamma` behavior
# >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
# dtype('float64')
# >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
# dtype('float32')
#
# Thus we cast output to the default torch dtype or preserve double
result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
if x.dtype == np.double:
result_dtype = np.double
return scipy.special.polygamma(n, x).astype(result_dtype)
def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
low, _ = op_info.domain
if requires_grad:
low = 0 + op_info._domain_eps
make_arg = partial(
make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad
)
yield SampleInput(make_arg((L,)))
yield SampleInput(make_arg(()))
def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
for shape in ((L,), (1, 0, 3), ()):
yield SampleInput(
make_tensor(
shape,
device=device,
dtype=dtype,
low=-5,
requires_grad=requires_grad,
),
)
# OpInfo entries for the ``torch.special`` namespace.
#
# Convention for ``ref``: it is the SciPy reference when ``TEST_SCIPY`` is
# true and ``None`` otherwise. The conditional must wrap the whole reference
# (including any lambda), so that ``ref`` itself is ``None`` without SciPy.
# Consumers presumably treat ``ref=None`` as "no reference available".
op_db: list[OpInfo] = [
    UnaryUfuncInfo(
        "special.i0e",
        aten_name="special_i0e",
        ref=scipy.special.i0e if TEST_SCIPY else None,
        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        sample_inputs_func=sample_inputs_i0_i1,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.i1",
        aten_name="special_i1",
        ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
        if TEST_SCIPY
        else None,
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        backward_dtypes=floating_types(),
        sample_inputs_func=sample_inputs_i0_i1,
        decorators=(
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float32: tol(atol=1e-4, rtol=0),
                        torch.bool: tol(atol=1e-4, rtol=0),
                    }
                )
            ),
        ),
        skips=(
            DecorateInfo(
                unittest.skip("Incorrect result!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=(torch.int8,),
            ),
        ),
        supports_fwgrad_bwgrad=True,
        supports_forward_ad=True,
    ),
    UnaryUfuncInfo(
        "special.i1e",
        aten_name="special_i1e",
        ref=scipy.special.i1e if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        backward_dtypes=floating_types(),
        sample_inputs_func=sample_inputs_i0_i1,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.ndtr",
        aten_name="special_ndtr",
        decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
        ref=scipy.special.ndtr if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # Dispatch stub: unsupported device typemeta
            DecorateInfo(
                unittest.expectedFailure,
                "TestFwdGradients",
                "test_fn_fwgrad_bwgrad",
                device_type="meta",
            ),
        ),
    ),
    # A separate OpInfo entry for special.polygamma is needed to reorder the arguments
    # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
    UnaryUfuncInfo(
        "special.polygamma",
        op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
        variant_test_name="special_polygamma_n_0",
        ref=reference_polygamma if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_polygamma,
        skips=(
            # lambda impl
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
        ),
        sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
        # polygamma functions have multiple singularities at x having non-positive integer value
        reference_numerics_filter=NumericsFilter(
            condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
        ),
    ),
    BinaryUfuncInfo(
        "special.xlog1py",
        aten_name="special_xlog1py",
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        promotes_int_to_float=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_one_python_scalar=True,
        # We don't test -1 as the gradient will be NaN and it'll break
        rhs_make_tensor_kwargs=dict(low=-0.99),
    ),
    BinaryUfuncInfo(
        "special.zeta",
        aten_name="special_zeta",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        supports_autograd=False,
        supports_one_python_scalar=True,
        skips=(
            # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
    # TODO: FIXME
    # OpInfo entry to verify the gradient formula of `other`/`q`
    # BinaryUfuncInfo('special.zeta',
    #                 op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
    #                 aten_name='special_zeta',
    #                 variant_test_name='grad',
    #                 dtypes=all_types_and(torch.bool),
    #                 promotes_int_to_float=True,
    #                 supports_autograd=True,
    #                 supports_rhs_python_scalar=False,
    #                 decorators=[
    #                     # Derivative wrt first tensor not implemented
    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
    #                                  "test_floating_inputs_are_differentiable")
    #                 ],
    #                 skips=(
    #                     # Lambda doesn't work in JIT test
    #                     # AssertionError: JIT Test does not execute any logic
    #                     DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
    #                 )),
    UnaryUfuncInfo(
        "special.entr",
        ref=scipy.special.entr if TEST_SCIPY else None,
        aten_name="special_entr",
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=[torch.bfloat16, torch.float16],
            ),
        ),
        supports_inplace_autograd=False,
        sample_inputs_func=sample_inputs_entr,
    ),
    UnaryUfuncInfo(
        "special.ndtri",
        ref=scipy.special.ndtri if TEST_SCIPY else None,
        domain=(0, 1),
        aten_name="special_ndtri",
        dtypes=all_types_and(torch.bool),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.log_ndtr",
        aten_name="special_log_ndtr",
        ref=scipy.special.log_ndtr if TEST_SCIPY else None,
        dtypes=all_types_and(torch.bool),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
    ),
    UnaryUfuncInfo(
        "special.erfcx",
        ref=scipy.special.erfcx if TEST_SCIPY else None,
        aten_name="special_erfcx",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=0, rtol=4e-6),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_erfcx,
    ),
    UnaryUfuncInfo(
        "special.airy_ai",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        # The lambda is parenthesized so the TEST_SCIPY conditional selects
        # between the lambda and None, rather than living inside its body.
        ref=(lambda x: scipy.special.airy(x)[0]) if TEST_SCIPY else None,
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
            ),
        ),
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_j0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.j0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_j1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.j1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_y0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.y0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.bessel_y1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.y1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_t",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_u",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_v",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.chebyshev_polynomial_w",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.hermite_polynomial_h",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: inf
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
            # Too slow
            DecorateInfo(
                unittest.skip("Too slow"),
                "TestCommon",
                "test_compare_cpu",
                device_type="xpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.hermite_polynomial_he",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: inf
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.laguerre_polynomial_l",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
            # Too slow
            DecorateInfo(
                unittest.skip("Too slow"),
                "TestCommon",
                "test_compare_cpu",
                device_type="xpu",
            ),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.legendre_polynomial_p",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_i0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.i0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_i1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.i1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_k0",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k0 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.modified_bessel_k1",
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-03,
                    torch.float64: 1e-05,
                },
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k1 if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.scaled_modified_bessel_k0",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k0e if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.scaled_modified_bessel_k1",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        ref=scipy.special.k1e if TEST_SCIPY else None,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_t",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_u",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_v",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    BinaryUfuncInfo(
        "special.shifted_chebyshev_polynomial_w",
        dtypes=all_types_and(torch.bool),
        promotes_int_to_float=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
            # Greatest absolute difference: nan
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
        supports_one_python_scalar=True,
        supports_autograd=False,
    ),
    UnaryUfuncInfo(
        "special.spherical_bessel_j0",
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
        dtypes=all_types_and(torch.bool),
        # The lambda is parenthesized so the TEST_SCIPY conditional selects
        # between the lambda and None, rather than living inside its body.
        ref=(lambda x: scipy.special.spherical_jn(0, x)) if TEST_SCIPY else None,
        supports_autograd=False,
    ),
]
# Python-reference (``torch._refs.special``) OpInfos.
#
# Each entry names the ``op_db`` OpInfo it mirrors through
# ``torch_opinfo_name`` and passes ``op_db`` along (presumably so the
# constructor can resolve that name). Any ``decorators``/``skips`` given here
# are attached to the reference entry itself.
python_ref_db: list[OpInfo] = [
    # ------------------------------------------------------------------
    # Elementwise Unary Special OpInfos
    # ------------------------------------------------------------------
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.bessel_j0",
        torch_opinfo_name="special.bessel_j0",
        op_db=op_db,
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.bessel_j1",
        torch_opinfo_name="special.bessel_j1",
        op_db=op_db,
        decorators=(
            precisionOverride(
                {
                    torch.float32: 1e-04,
                    torch.float64: 1e-05,
                },
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.entr",
        torch_opinfo_name="special.entr",
        op_db=op_db,
        decorators=(
            precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),
        ),
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=[torch.bfloat16, torch.float16],
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.erfcx",
        torch_opinfo_name="special.erfcx",
        op_db=op_db,
        decorators=(
            toleranceOverride({torch.float32: tol(atol=0, rtol=4e-6)}),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i0e",
        torch_opinfo_name="special.i0e",
        op_db=op_db,
        decorators=(
            precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i1",
        torch_opinfo_name="special.i1",
        op_db=op_db,
        decorators=(
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float32: tol(atol=1e-4, rtol=0),
                        torch.bool: tol(atol=1e-4, rtol=0),
                    }
                )
            ),
        ),
        skips=(
            DecorateInfo(
                unittest.skip("Incorrect result!"),
                "TestUnaryUfuncs",
                "test_reference_numerics_large",
                dtypes=(torch.int8,),
            ),
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i1e",
        torch_opinfo_name="special.i1e",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.log_ndtr",
        torch_opinfo_name="special.log_ndtr",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.ndtr",
        torch_opinfo_name="special.ndtr",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.ndtri",
        torch_opinfo_name="special.ndtri",
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.spherical_bessel_j0",
        torch_opinfo_name="special.spherical_bessel_j0",
        op_db=op_db,
        decorators=(
            toleranceOverride(
                {
                    torch.float32: tol(atol=1e-03, rtol=1e-03),
                    torch.float64: tol(atol=1e-05, rtol=1e-03),
                }
            ),
        ),
    ),
    # ------------------------------------------------------------------
    # Elementwise Binary Special OpInfos
    # ------------------------------------------------------------------
    ElementwiseBinaryPythonRefInfo(
        "_refs.special.zeta",
        torch_opinfo_name="special.zeta",
        op_db=op_db,
        supports_one_python_scalar=True,
        skips=(
            # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
        ),
    ),
]