mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add OpInfo for polygamma and remove torch_op_tests Infra (#51966)
Summary: Reference: https://github.com/pytorch/pytorch/issues/42515 * OpInfo entry for Polygamma * Removes infra of torch_op_tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/51966 Reviewed By: bdhirsh Differential Revision: D27851858 Pulled By: mruberry fbshipit-source-id: 7f1d0273065e1df56a152f95a14513959af29a1b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a661e58731
commit
df8bb5a42b
@ -5034,9 +5034,6 @@
|
||||
structured_delegate: digamma.out
|
||||
variants: method
|
||||
|
||||
- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
|
||||
variants: method
|
||||
|
||||
- func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
|
||||
variants: method
|
||||
dispatch:
|
||||
@ -5804,6 +5801,11 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: polygamma
|
||||
|
||||
- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
|
||||
variants: method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: polygamma_
|
||||
|
||||
- func: erfinv(Tensor self) -> Tensor
|
||||
structured_delegate: erfinv.out
|
||||
variants: method, function
|
||||
|
@ -2320,7 +2320,8 @@ class TestOperatorSignatures(JitTestCase):
|
||||
@onlyCPU
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
|
||||
known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__', 'linalg.multi_dot'}
|
||||
known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__', 'linalg.multi_dot',
|
||||
'polygamma'}
|
||||
|
||||
try:
|
||||
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
|
@ -11,8 +11,7 @@ import unittest
|
||||
from torch._six import inf, nan
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
|
||||
suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy,
|
||||
gradcheck, IS_WINDOWS)
|
||||
suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS)
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
unary_ufuncs, _NOTHING)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
@ -20,7 +19,7 @@ from torch.testing._internal.common_device_type import (
|
||||
onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
|
||||
OpDTypes)
|
||||
from torch.testing import (
|
||||
floating_types_and, all_types_and_complex_and, floating_types, floating_and_complex_types_and)
|
||||
floating_types_and, all_types_and_complex_and, floating_and_complex_types_and)
|
||||
|
||||
if TEST_SCIPY:
|
||||
import scipy
|
||||
@ -664,7 +663,6 @@ class TestUnaryUfuncs(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, r"All elements must be greater than \(p-1\)/2"):
|
||||
run_test(3)
|
||||
|
||||
# TODO opinfo polygamma
|
||||
def test_polygamma_neg(self, device):
|
||||
with self.assertRaisesRegex(RuntimeError, r'polygamma\(n, x\) does not support negative n\.'):
|
||||
torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device))
|
||||
@ -1272,24 +1270,6 @@ class TestUnaryUfuncs(TestCase):
|
||||
t_copy.abs_()
|
||||
self.assertEqual(t, t_copy)
|
||||
|
||||
# Note: ROCm fails when using float tensors
|
||||
# TODO: update this test to just compare against NumPy
|
||||
@onlyCUDA
|
||||
@dtypes(torch.double)
|
||||
def test_polygamma(self, device, dtype):
|
||||
cpu_tensor = torch.randn(10, 10, 10, dtype=dtype)
|
||||
device_tensor = cpu_tensor.to(device)
|
||||
zeros = torch.zeros(10, 10, 10, dtype=dtype)
|
||||
for n in [0, 1, 2, 3, 4, 5]:
|
||||
cpu_out = cpu_tensor.polygamma(n)
|
||||
device_out = device_tensor.polygamma(n)
|
||||
norm_errors = (device_out - cpu_out.to(device)) / device_out
|
||||
self.assertEqual(norm_errors, zeros)
|
||||
|
||||
cpu_tensor.requires_grad = True
|
||||
for n in [0, 1, 2, 3, 4, 5]:
|
||||
gradcheck(lambda x: x.polygamma(n), cpu_tensor)
|
||||
|
||||
# TODO: update to compare against NumPy by rationalizing with OpInfo
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float, torch.double)
|
||||
@ -1565,251 +1545,7 @@ class TestUnaryUfuncs(TestCase):
|
||||
self.assertEqual(torch.isinf(sample.atanh()), inf_mask)
|
||||
|
||||
|
||||
def _generate_reference_input(dtype, device):
|
||||
input = []
|
||||
input.append(list(range(-5, 5)))
|
||||
input.append([0 for x in range(-5, 5)])
|
||||
input.append([x + 1e-6 for x in range(-5, 5)])
|
||||
# Some vectorized implementations don't support large values
|
||||
input.append([x + 1e10 for x in range(-5, 5)])
|
||||
input.append([x - 1e10 for x in range(-5, 5)])
|
||||
input.append([*torch.randn(7).tolist(), math.inf, -math.inf, math.nan])
|
||||
input.append((torch.randn(10) * 1e6).tolist())
|
||||
input.append([math.pi * (x / 2) for x in range(-5, 5)])
|
||||
return torch.tensor(input, dtype=dtype, device=device)
|
||||
|
||||
def _generate_gamma_input(dtype, device, test_poles=True):
|
||||
input = []
|
||||
input.append((torch.randn(10).abs() + 1e-4).tolist())
|
||||
input.append((torch.randn(10).abs() + 1e6).tolist())
|
||||
zeros = torch.linspace(-9.5, -0.5, 10)
|
||||
input.append(zeros.tolist())
|
||||
input.append((zeros - 0.49).tolist())
|
||||
input.append((zeros + 0.49).tolist())
|
||||
input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist())
|
||||
|
||||
if test_poles:
|
||||
input.append([-0.999999994, -1.999999994, -2.0000000111,
|
||||
-100.99999994, -1931.99999994, 0.000000111,
|
||||
-0.000000111, 0, -2, -329])
|
||||
return torch.tensor(input, dtype=dtype, device=device)
|
||||
|
||||
# this class contains information needed to generate tests for torch math functions
|
||||
# the generated tests compare torch implementation with the reference numpy/scipy implementation,
|
||||
# and also check proper behavior for contiguous/noncontiguous/inplace outputs.
|
||||
class _TorchMathTestMeta(object):
|
||||
def __init__(self,
|
||||
opstr,
|
||||
args=(),
|
||||
reffn=None,
|
||||
refargs=lambda x: (x.numpy(),),
|
||||
input_fn=_generate_reference_input,
|
||||
inputargs=(),
|
||||
substr='',
|
||||
make_inplace=True,
|
||||
decorators=None,
|
||||
ref_backend='numpy',
|
||||
rtol=None,
|
||||
atol=None,
|
||||
dtypes=floating_types(),
|
||||
replace_inf_with_nan=False):
|
||||
self.opstr = opstr
|
||||
self.args = args
|
||||
self.reffn = reffn # reffn is either callable or ref_backend attribute, set to opstr if not specified
|
||||
self.refargs = refargs
|
||||
self.input_fn = input_fn
|
||||
self.inputargs = inputargs
|
||||
self.substr = substr
|
||||
self.make_inplace = make_inplace
|
||||
assert ref_backend == 'numpy' or ref_backend == 'scipy'
|
||||
self.ref_backend = ref_backend
|
||||
if ref_backend == 'scipy':
|
||||
self.ref_decorator = [unittest.skipIf(not TEST_SCIPY, "Scipy not found")]
|
||||
else:
|
||||
self.ref_decorator = []
|
||||
self.decorators = decorators
|
||||
self.rtol = rtol
|
||||
self.atol = atol
|
||||
self.dtypes = dtypes
|
||||
self.replace_inf_with_nan = replace_inf_with_nan
|
||||
|
||||
# TODO: replace with make_tensor
|
||||
# Converts half/bfloat16 dtype to float when device is cpu
|
||||
def _convert_t(dtype, device):
|
||||
if device == 'cpu' and dtype in {torch.half, torch.bfloat16}:
|
||||
return torch.float
|
||||
return dtype
|
||||
|
||||
# TODO: replace with make_tensor
|
||||
# Returns a tensor of the requested shape, dtype, and device
|
||||
# Requesting a half CPU tensor returns a float CPU tensor with
|
||||
# values representable by a half.
|
||||
# Initialization uses randint for non-float types and randn for float types.
|
||||
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
|
||||
# Returns a tensor filled with ones
|
||||
if fill_ones:
|
||||
return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
|
||||
|
||||
# Returns a tensor with random integer values
|
||||
if not (dtype.is_floating_point or dtype.is_complex):
|
||||
t = torch.randint(0, 10, shape, device=device)
|
||||
if dtype != torch.uint8:
|
||||
t = t - 5 # generate negative values also
|
||||
return t.to(_convert_t(dtype, device))
|
||||
|
||||
# Populates the CPU tensor with floats representable as half/bfloat16
|
||||
if dtype == torch.half and device == 'cpu':
|
||||
return torch.randn(*shape, dtype=torch.float, device=device).half().float()
|
||||
if dtype == torch.bfloat16 and device == 'cpu':
|
||||
return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float()
|
||||
|
||||
# Default: returns a tensor with random float values
|
||||
return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype)
|
||||
|
||||
# TODO: replace with make_tensor
|
||||
def _medium_2d(dtype, device):
|
||||
return _make_tensor((50, 50), dtype, device)
|
||||
|
||||
# TODO: replace with opinfo
|
||||
_types_no_half = [
|
||||
torch.float, torch.double,
|
||||
torch.int8, torch.short, torch.int, torch.long,
|
||||
torch.uint8
|
||||
]
|
||||
|
||||
# TODO: all these should be replaced with OpInfos
|
||||
torch_op_tests = [
|
||||
_TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
|
||||
refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
|
||||
ref_backend='scipy'),
|
||||
_TorchMathTestMeta('polygamma', args=[1], substr='_1', reffn='polygamma',
|
||||
refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
|
||||
ref_backend='scipy', rtol=0.0008, atol=1e-5),
|
||||
_TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma',
|
||||
refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
|
||||
ref_backend='scipy', rtol=0.0008, atol=1e-5)]
|
||||
|
||||
def generate_torch_test_functions(cls, testmeta, inplace):
|
||||
opstr = testmeta.opstr if not inplace else testmeta.opstr + "_"
|
||||
|
||||
def torchfn(x):
|
||||
return getattr(x, opstr)(*testmeta.args)
|
||||
|
||||
def fn_check_reference(self, device, dtype):
|
||||
def reffn(x):
|
||||
backend = np if testmeta.ref_backend == 'numpy' else scipy.special
|
||||
opstr = None
|
||||
if testmeta.reffn is None:
|
||||
opstr = testmeta.opstr
|
||||
elif isinstance(testmeta.reffn, str):
|
||||
opstr = testmeta.reffn
|
||||
if callable(testmeta.reffn):
|
||||
fn = testmeta.reffn
|
||||
else:
|
||||
assert opstr is not None, "invalid reffn"
|
||||
fn = getattr(backend, opstr)
|
||||
return fn(*testmeta.refargs(x))
|
||||
|
||||
inp = testmeta.input_fn(dtype, device, *testmeta.inputargs)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
expected = torch.from_numpy(reffn(inp))
|
||||
actual = torchfn(inp)
|
||||
if testmeta.replace_inf_with_nan:
|
||||
actual[(actual == -inf) | (actual == inf)] = nan
|
||||
expected[(expected == -inf) | (expected == inf)] = nan
|
||||
|
||||
torch.testing.assert_allclose(actual, expected, rtol=testmeta.rtol, atol=testmeta.atol)
|
||||
|
||||
def fn_non_contig(self, device, dtype) -> None:
|
||||
shapes = [(5, 7), (1024,)]
|
||||
for shape in shapes:
|
||||
contig = _make_tensor(shape, dtype=dtype, device=device)
|
||||
non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0]
|
||||
non_contig.copy_(contig)
|
||||
self.assertFalse(non_contig.is_contiguous())
|
||||
self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous')
|
||||
|
||||
def fn_non_contig_index(self, device, dtype):
|
||||
contig = _make_tensor((2, 2, 1, 2), dtype=dtype, device=device)
|
||||
non_contig = contig[:, 1, ...]
|
||||
contig = non_contig.clone()
|
||||
self.assertFalse(non_contig.is_contiguous())
|
||||
self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous index')
|
||||
|
||||
def fn_non_contig_expand(self, device, dtype):
|
||||
shapes = [(1, 3), (1, 7), (5, 7)]
|
||||
for shape in shapes:
|
||||
contig = _make_tensor(shape, dtype=dtype, device=device)
|
||||
non_contig = contig.clone().expand(3, -1, -1)
|
||||
self.assertFalse(non_contig.is_contiguous())
|
||||
contig = torchfn(contig)
|
||||
non_contig = torchfn(non_contig)
|
||||
for i in range(3):
|
||||
self.assertEqual(contig, non_contig[i], msg='non-contiguous expand[' + str(i) + ']')
|
||||
|
||||
def fn_contig_size1(self, device, dtype):
|
||||
contig = _make_tensor((5, 100), dtype=dtype, device=device)
|
||||
contig = contig[:1, :50]
|
||||
contig2 = torch.empty(contig.size(), dtype=dtype)
|
||||
contig2.copy_(contig)
|
||||
self.assertTrue(contig.is_contiguous())
|
||||
self.assertTrue(contig2.is_contiguous())
|
||||
self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1')
|
||||
|
||||
def fn_contig_size1_large_dim(self, device, dtype):
|
||||
contig = _make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype, device=device)
|
||||
contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
|
||||
contig2 = torch.empty(contig.size(), dtype=dtype)
|
||||
contig2.copy_(contig)
|
||||
self.assertTrue(contig.is_contiguous())
|
||||
self.assertTrue(contig2.is_contiguous())
|
||||
self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1')
|
||||
|
||||
def fn_large(self, device, dtype):
|
||||
input = _make_tensor((1024, 512), dtype=dtype, device=device)
|
||||
# clone input to properly test inplace functions
|
||||
actual = torchfn(input.clone())
|
||||
expected = torch.stack([torchfn(slice) for slice in input])
|
||||
self.assertEqual(actual, expected, msg='large')
|
||||
|
||||
test_functions = {"test_reference_": fn_check_reference,
|
||||
"test_non_contig_": fn_non_contig,
|
||||
"test_non_contig_index_": fn_non_contig_index,
|
||||
"test_non_contig_expand_": fn_non_contig_expand,
|
||||
"test_contig_size1_": fn_contig_size1,
|
||||
"test_check_contig_size1_large_dim_": fn_contig_size1_large_dim,
|
||||
"test_large_": fn_large}
|
||||
for name in test_functions:
|
||||
if inplace and 'expand' in name:
|
||||
continue
|
||||
test_name = name + testmeta.opstr + testmeta.substr
|
||||
if inplace:
|
||||
test_name += "_inplace"
|
||||
assert not hasattr(cls, test_name), "{0} already in TestUnaryUfuncMathOps".format(test_name)
|
||||
|
||||
decorators = [] if testmeta.decorators is None else testmeta.decorators
|
||||
if 'reference' in name:
|
||||
decorators = decorators + testmeta.ref_decorator
|
||||
decorators = decorators + [dtypes(*testmeta.dtypes)]
|
||||
fn_test = test_functions[name]
|
||||
for dec in decorators:
|
||||
fn_test = dec(fn_test)
|
||||
setattr(cls, test_name, fn_test)
|
||||
|
||||
class TestUnaryUfuncMathOps(TestCase):
|
||||
exact_dtype = True
|
||||
|
||||
def generate_torch_op_tests(cls):
|
||||
for t in torch_op_tests:
|
||||
generate_torch_test_functions(cls, t, False)
|
||||
if t.make_inplace:
|
||||
generate_torch_test_functions(cls, t, True)
|
||||
|
||||
|
||||
generate_torch_op_tests(TestUnaryUfuncMathOps)
|
||||
instantiate_device_type_tests(TestUnaryUfuncs, globals())
|
||||
instantiate_device_type_tests(TestUnaryUfuncMathOps, globals(), only_for='cpu')
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -658,6 +658,9 @@
|
||||
- name: polygamma(int n, Tensor self) -> Tensor
|
||||
self: grad * polygamma(n + 1, self)
|
||||
|
||||
- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
|
||||
self: grad * polygamma(n + 1, self)
|
||||
|
||||
- name: log(Tensor self) -> Tensor
|
||||
self: grad.div(self.conj())
|
||||
|
||||
|
@ -2475,6 +2475,18 @@ def sample_inputs_polar(op_info, device, dtype, requires_grad, **kwargs):
|
||||
return samples
|
||||
|
||||
|
||||
def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
tensor_shapes = ((S, S), ())
|
||||
ns = (1, 2, 3, 4, 5)
|
||||
|
||||
def generator():
|
||||
for shape, n in product(tensor_shapes, ns):
|
||||
yield SampleInput(make_arg(shape), args=(n,))
|
||||
|
||||
return list(generator())
|
||||
|
||||
|
||||
def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
|
||||
low, _ = op_info.domain
|
||||
|
||||
@ -2883,6 +2895,16 @@ def reference_lgamma(x):
|
||||
|
||||
return out
|
||||
|
||||
def reference_polygamma(x, n):
|
||||
# WEIRD `scipy.special.polygamma` behavior
|
||||
# >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
|
||||
# dtype('float64')
|
||||
# >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
|
||||
# dtype('float32')
|
||||
#
|
||||
# Thus we cast output to the default torch dtype.
|
||||
np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
|
||||
return scipy.special.polygamma(n, x).astype(np_dtype)
|
||||
|
||||
def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs):
|
||||
"""Gradcheck wrapper for functions that take Hermitian matrices as input.
|
||||
@ -4484,6 +4506,114 @@ op_db: List[OpInfo] = [
|
||||
OpInfo('polar',
|
||||
dtypes=floating_types(),
|
||||
sample_inputs_func=sample_inputs_polar),
|
||||
# To test reference numerics against multiple values of argument `n`,
|
||||
# we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4).
|
||||
# We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing.
|
||||
UnaryUfuncInfo('polygamma',
|
||||
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
|
||||
variant_test_name='polygamma_n_0',
|
||||
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
|
||||
dtypes=floating_types(),
|
||||
dtypesIfCPU=floating_types(),
|
||||
dtypesIfCUDA=floating_types_and(torch.half),
|
||||
sample_inputs_func=sample_inputs_polygamma,
|
||||
skips=(
|
||||
# Probably related to the way the function is
|
||||
# scripted for JIT tests (or maybe not).
|
||||
# RuntimeError:
|
||||
# Arguments for call are not valid.
|
||||
# The following variants are available:
|
||||
# aten::polygamma(int n, Tensor self) -> (Tensor):
|
||||
# Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
|
||||
# aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> (Tensor(a!)):
|
||||
# Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
|
||||
# The original call is:
|
||||
# File "<string>", line 3
|
||||
# def the_method(i0):
|
||||
# return torch.polygamma(i0, 1)
|
||||
# ~~~~~~~~~~~~~~~ <--- HERE
|
||||
SkipInfo('TestCommon', 'test_variant_consistency_jit'),),
|
||||
sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
|
||||
UnaryUfuncInfo('polygamma',
|
||||
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
|
||||
variant_test_name='polygamma_n_1',
|
||||
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
|
||||
dtypes=floating_types(),
|
||||
dtypesIfCPU=floating_types(),
|
||||
dtypesIfCUDA=floating_types_and(torch.half),
|
||||
sample_inputs_func=sample_inputs_polygamma,
|
||||
skips=(
|
||||
# Redundant tests
|
||||
SkipInfo('TestGradients'),
|
||||
SkipInfo('TestOpInfo'),
|
||||
SkipInfo('TestCommon'),
|
||||
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard'),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal'),
|
||||
),
|
||||
sample_kwargs=lambda device, dtype, input: ({'n': 1}, {'n': 1})),
|
||||
UnaryUfuncInfo('polygamma',
|
||||
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
|
||||
variant_test_name='polygamma_n_2',
|
||||
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
|
||||
dtypes=floating_types(),
|
||||
dtypesIfCPU=floating_types(),
|
||||
dtypesIfCUDA=floating_types_and(torch.half),
|
||||
sample_inputs_func=sample_inputs_polygamma,
|
||||
skips=(
|
||||
# Redundant tests
|
||||
SkipInfo('TestGradients'),
|
||||
SkipInfo('TestOpInfo'),
|
||||
SkipInfo('TestCommon'),
|
||||
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
|
||||
active_if=TEST_WITH_ROCM),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
|
||||
active_if=TEST_WITH_ROCM),),
|
||||
sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2})),
|
||||
UnaryUfuncInfo('polygamma',
|
||||
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
|
||||
variant_test_name='polygamma_n_3',
|
||||
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
|
||||
dtypes=floating_types(),
|
||||
dtypesIfCPU=floating_types(),
|
||||
dtypesIfCUDA=floating_types_and(torch.half),
|
||||
sample_inputs_func=sample_inputs_polygamma,
|
||||
skips=(
|
||||
# Redundant tests
|
||||
SkipInfo('TestGradients'),
|
||||
SkipInfo('TestOpInfo'),
|
||||
SkipInfo('TestCommon'),
|
||||
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
|
||||
active_if=TEST_WITH_ROCM),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
|
||||
active_if=TEST_WITH_ROCM),),
|
||||
sample_kwargs=lambda device, dtype, input: ({'n': 3}, {'n': 3})),
|
||||
UnaryUfuncInfo('polygamma',
|
||||
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
|
||||
variant_test_name='polygamma_n_4',
|
||||
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
|
||||
decorators=(precisionOverride({torch.float16: 5e-4, torch.float32: 5e-4}),),
|
||||
dtypes=floating_types(),
|
||||
dtypesIfCPU=floating_types(),
|
||||
dtypesIfCUDA=floating_types_and(torch.half),
|
||||
sample_inputs_func=sample_inputs_polygamma,
|
||||
skips=(
|
||||
# Redundant tests
|
||||
SkipInfo('TestGradients'),
|
||||
SkipInfo('TestOpInfo'),
|
||||
SkipInfo('TestCommon'),
|
||||
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
|
||||
active_if=TEST_WITH_ROCM),
|
||||
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
|
||||
active_if=TEST_WITH_ROCM),),
|
||||
sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4})),
|
||||
OpInfo('pinverse',
|
||||
op=torch.pinverse,
|
||||
dtypes=floating_and_complex_types(),
|
||||
|
Reference in New Issue
Block a user