Add OpInfo for polygamma and remove torch_op_tests Infra (#51966)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/42515

* Adds an OpInfo entry for polygamma (one UnaryUfuncInfo variant per tested value of n)
* Removes the torch_op_tests test-generation infrastructure that it replaces (a short sketch of the new OpInfo-driven flow follows below)
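
For context, the gist of the new flow: each UnaryUfuncInfo('polygamma', ...) entry added in this PR supplies sample inputs (sample_inputs_polygamma) and a scipy-based reference (reference_polygamma), and the generic OpInfo tests (for example the TestUnaryUfuncs reference-numerics tests and the tests in test_ops.py) compare torch.polygamma against that reference for each n. The following is a minimal, self-contained sketch of that comparison, not code from this PR; the helper name check_polygamma_against_scipy and the shape/dtype choices are illustrative only.

import scipy.special
import torch

def check_polygamma_against_scipy(n, shape=(5, 5), dtype=torch.float64):
    # stay on positive inputs, away from the poles of the polygamma functions
    x = torch.rand(shape, dtype=dtype) + 0.5
    actual = torch.polygamma(n, x)
    # reference_polygamma in this PR additionally casts the scipy result to
    # torch's default dtype (scipy promotes 0-d float32 inputs to float64);
    # a float64-only sketch can compare directly
    expected = torch.from_numpy(scipy.special.polygamma(n, x.numpy()))
    torch.testing.assert_allclose(actual, expected)

for n in range(5):  # the new OpInfo variants cover n = 0 through 4
    check_polygamma_against_scipy(n)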

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51966

Reviewed By: bdhirsh

Differential Revision: D27851858

Pulled By: mruberry

fbshipit-source-id: 7f1d0273065e1df56a152f95a14513959af29a1b
Authored by: kshitij12345
Committed by: Facebook GitHub Bot
Date: 2021-04-20 01:01:07 -07:00
Parent: a661e58731
Commit: df8bb5a42b
5 changed files with 142 additions and 270 deletions

View File

@@ -5034,9 +5034,6 @@
structured_delegate: digamma.out
variants: method
- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
variants: method
- func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
variants: method
dispatch:
@@ -5804,6 +5801,11 @@
dispatch:
CompositeExplicitAutograd: polygamma
- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
variants: method
dispatch:
CompositeExplicitAutograd: polygamma_
- func: erfinv(Tensor self) -> Tensor
structured_delegate: erfinv.out
variants: method, function
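
For context (not part of the diff): with the polygamma_ schema above, the op is exposed as an in-place Tensor method, mirroring the functional torch.polygamma(n, input). A minimal usage sketch with illustrative values:

import torch

t = torch.tensor([1.0, 2.0, 3.0])
out_of_place = torch.polygamma(2, t)  # functional form: polygamma(n, input)
t.polygamma_(2)                       # in-place method form; mutates t
assert torch.allclose(t, out_of_place)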

View File

@@ -2320,7 +2320,8 @@ class TestOperatorSignatures(JitTestCase):
@onlyCPU
@ops(op_db, allowed_dtypes=(torch.float,))
def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__', 'linalg.multi_dot'}
known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__', 'linalg.multi_dot',
'polygamma'}
try:
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

View File

@@ -11,8 +11,7 @@ import unittest
from torch._six import inf, nan
from torch.testing._internal.common_utils import (
TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy,
gradcheck, IS_WINDOWS)
suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS)
from torch.testing._internal.common_methods_invocations import (
unary_ufuncs, _NOTHING)
from torch.testing._internal.common_device_type import (
@@ -20,7 +19,7 @@ from torch.testing._internal.common_device_type import (
onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
OpDTypes)
from torch.testing import (
floating_types_and, all_types_and_complex_and, floating_types, floating_and_complex_types_and)
floating_types_and, all_types_and_complex_and, floating_and_complex_types_and)
if TEST_SCIPY:
import scipy
@@ -664,7 +663,6 @@ class TestUnaryUfuncs(TestCase):
with self.assertRaisesRegex(RuntimeError, r"All elements must be greater than \(p-1\)/2"):
run_test(3)
# TODO opinfo polygamma
def test_polygamma_neg(self, device):
with self.assertRaisesRegex(RuntimeError, r'polygamma\(n, x\) does not support negative n\.'):
torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device))
@@ -1272,24 +1270,6 @@
t_copy.abs_()
self.assertEqual(t, t_copy)
# Note: ROCm fails when using float tensors
# TODO: update this test to just compare against NumPy
@onlyCUDA
@dtypes(torch.double)
def test_polygamma(self, device, dtype):
cpu_tensor = torch.randn(10, 10, 10, dtype=dtype)
device_tensor = cpu_tensor.to(device)
zeros = torch.zeros(10, 10, 10, dtype=dtype)
for n in [0, 1, 2, 3, 4, 5]:
cpu_out = cpu_tensor.polygamma(n)
device_out = device_tensor.polygamma(n)
norm_errors = (device_out - cpu_out.to(device)) / device_out
self.assertEqual(norm_errors, zeros)
cpu_tensor.requires_grad = True
for n in [0, 1, 2, 3, 4, 5]:
gradcheck(lambda x: x.polygamma(n), cpu_tensor)
# TODO: update to compare against NumPy by rationalizing with OpInfo
@onlyCUDA
@dtypes(torch.float, torch.double)
@@ -1565,251 +1545,7 @@ class TestUnaryUfuncs(TestCase):
self.assertEqual(torch.isinf(sample.atanh()), inf_mask)
def _generate_reference_input(dtype, device):
input = []
input.append(list(range(-5, 5)))
input.append([0 for x in range(-5, 5)])
input.append([x + 1e-6 for x in range(-5, 5)])
# Some vectorized implementations don't support large values
input.append([x + 1e10 for x in range(-5, 5)])
input.append([x - 1e10 for x in range(-5, 5)])
input.append([*torch.randn(7).tolist(), math.inf, -math.inf, math.nan])
input.append((torch.randn(10) * 1e6).tolist())
input.append([math.pi * (x / 2) for x in range(-5, 5)])
return torch.tensor(input, dtype=dtype, device=device)
def _generate_gamma_input(dtype, device, test_poles=True):
input = []
input.append((torch.randn(10).abs() + 1e-4).tolist())
input.append((torch.randn(10).abs() + 1e6).tolist())
zeros = torch.linspace(-9.5, -0.5, 10)
input.append(zeros.tolist())
input.append((zeros - 0.49).tolist())
input.append((zeros + 0.49).tolist())
input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist())
if test_poles:
input.append([-0.999999994, -1.999999994, -2.0000000111,
-100.99999994, -1931.99999994, 0.000000111,
-0.000000111, 0, -2, -329])
return torch.tensor(input, dtype=dtype, device=device)
# this class contains information needed to generate tests for torch math functions
# the generated tests compare torch implementation with the reference numpy/scipy implementation,
# and also check proper behavior for contiguous/noncontiguous/inplace outputs.
class _TorchMathTestMeta(object):
def __init__(self,
opstr,
args=(),
reffn=None,
refargs=lambda x: (x.numpy(),),
input_fn=_generate_reference_input,
inputargs=(),
substr='',
make_inplace=True,
decorators=None,
ref_backend='numpy',
rtol=None,
atol=None,
dtypes=floating_types(),
replace_inf_with_nan=False):
self.opstr = opstr
self.args = args
self.reffn = reffn # reffn is either callable or ref_backend attribute, set to opstr if not specified
self.refargs = refargs
self.input_fn = input_fn
self.inputargs = inputargs
self.substr = substr
self.make_inplace = make_inplace
assert ref_backend == 'numpy' or ref_backend == 'scipy'
self.ref_backend = ref_backend
if ref_backend == 'scipy':
self.ref_decorator = [unittest.skipIf(not TEST_SCIPY, "Scipy not found")]
else:
self.ref_decorator = []
self.decorators = decorators
self.rtol = rtol
self.atol = atol
self.dtypes = dtypes
self.replace_inf_with_nan = replace_inf_with_nan
# TODO: replace with make_tensor
# Converts half/bfloat16 dtype to float when device is cpu
def _convert_t(dtype, device):
if device == 'cpu' and dtype in {torch.half, torch.bfloat16}:
return torch.float
return dtype
# TODO: replace with make_tensor
# Returns a tensor of the requested shape, dtype, and device
# Requesting a half CPU tensor returns a float CPU tensor with
# values representable by a half.
# Initialization uses randint for non-float types and randn for float types.
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
# Returns a tensor filled with ones
if fill_ones:
return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)
# Returns a tensor with random integer values
if not (dtype.is_floating_point or dtype.is_complex):
t = torch.randint(0, 10, shape, device=device)
if dtype != torch.uint8:
t = t - 5 # generate negative values also
return t.to(_convert_t(dtype, device))
# Populates the CPU tensor with floats representable as half/bfloat16
if dtype == torch.half and device == 'cpu':
return torch.randn(*shape, dtype=torch.float, device=device).half().float()
if dtype == torch.bfloat16 and device == 'cpu':
return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float()
# Default: returns a tensor with random float values
return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype)
# TODO: replace with make_tensor
def _medium_2d(dtype, device):
return _make_tensor((50, 50), dtype, device)
# TODO: replace with opinfo
_types_no_half = [
torch.float, torch.double,
torch.int8, torch.short, torch.int, torch.long,
torch.uint8
]
# TODO: all these should be replaced with OpInfos
torch_op_tests = [
_TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
ref_backend='scipy'),
_TorchMathTestMeta('polygamma', args=[1], substr='_1', reffn='polygamma',
refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
ref_backend='scipy', rtol=0.0008, atol=1e-5),
_TorchMathTestMeta('polygamma', args=[2], substr='_2', reffn='polygamma',
refargs=lambda x: (2, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
ref_backend='scipy', rtol=0.0008, atol=1e-5)]
def generate_torch_test_functions(cls, testmeta, inplace):
opstr = testmeta.opstr if not inplace else testmeta.opstr + "_"
def torchfn(x):
return getattr(x, opstr)(*testmeta.args)
def fn_check_reference(self, device, dtype):
def reffn(x):
backend = np if testmeta.ref_backend == 'numpy' else scipy.special
opstr = None
if testmeta.reffn is None:
opstr = testmeta.opstr
elif isinstance(testmeta.reffn, str):
opstr = testmeta.reffn
if callable(testmeta.reffn):
fn = testmeta.reffn
else:
assert opstr is not None, "invalid reffn"
fn = getattr(backend, opstr)
return fn(*testmeta.refargs(x))
inp = testmeta.input_fn(dtype, device, *testmeta.inputargs)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
expected = torch.from_numpy(reffn(inp))
actual = torchfn(inp)
if testmeta.replace_inf_with_nan:
actual[(actual == -inf) | (actual == inf)] = nan
expected[(expected == -inf) | (expected == inf)] = nan
torch.testing.assert_allclose(actual, expected, rtol=testmeta.rtol, atol=testmeta.atol)
def fn_non_contig(self, device, dtype) -> None:
shapes = [(5, 7), (1024,)]
for shape in shapes:
contig = _make_tensor(shape, dtype=dtype, device=device)
non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0]
non_contig.copy_(contig)
self.assertFalse(non_contig.is_contiguous())
self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous')
def fn_non_contig_index(self, device, dtype):
contig = _make_tensor((2, 2, 1, 2), dtype=dtype, device=device)
non_contig = contig[:, 1, ...]
contig = non_contig.clone()
self.assertFalse(non_contig.is_contiguous())
self.assertEqual(torchfn(contig), torchfn(non_contig), msg='non-contiguous index')
def fn_non_contig_expand(self, device, dtype):
shapes = [(1, 3), (1, 7), (5, 7)]
for shape in shapes:
contig = _make_tensor(shape, dtype=dtype, device=device)
non_contig = contig.clone().expand(3, -1, -1)
self.assertFalse(non_contig.is_contiguous())
contig = torchfn(contig)
non_contig = torchfn(non_contig)
for i in range(3):
self.assertEqual(contig, non_contig[i], msg='non-contiguous expand[' + str(i) + ']')
def fn_contig_size1(self, device, dtype):
contig = _make_tensor((5, 100), dtype=dtype, device=device)
contig = contig[:1, :50]
contig2 = torch.empty(contig.size(), dtype=dtype)
contig2.copy_(contig)
self.assertTrue(contig.is_contiguous())
self.assertTrue(contig2.is_contiguous())
self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1')
def fn_contig_size1_large_dim(self, device, dtype):
contig = _make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype, device=device)
contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
contig2 = torch.empty(contig.size(), dtype=dtype)
contig2.copy_(contig)
self.assertTrue(contig.is_contiguous())
self.assertTrue(contig2.is_contiguous())
self.assertEqual(torchfn(contig), torchfn(contig2), msg='contiguous size1')
def fn_large(self, device, dtype):
input = _make_tensor((1024, 512), dtype=dtype, device=device)
# clone input to properly test inplace functions
actual = torchfn(input.clone())
expected = torch.stack([torchfn(slice) for slice in input])
self.assertEqual(actual, expected, msg='large')
test_functions = {"test_reference_": fn_check_reference,
"test_non_contig_": fn_non_contig,
"test_non_contig_index_": fn_non_contig_index,
"test_non_contig_expand_": fn_non_contig_expand,
"test_contig_size1_": fn_contig_size1,
"test_check_contig_size1_large_dim_": fn_contig_size1_large_dim,
"test_large_": fn_large}
for name in test_functions:
if inplace and 'expand' in name:
continue
test_name = name + testmeta.opstr + testmeta.substr
if inplace:
test_name += "_inplace"
assert not hasattr(cls, test_name), "{0} already in TestUnaryUfuncMathOps".format(test_name)
decorators = [] if testmeta.decorators is None else testmeta.decorators
if 'reference' in name:
decorators = decorators + testmeta.ref_decorator
decorators = decorators + [dtypes(*testmeta.dtypes)]
fn_test = test_functions[name]
for dec in decorators:
fn_test = dec(fn_test)
setattr(cls, test_name, fn_test)
class TestUnaryUfuncMathOps(TestCase):
exact_dtype = True
def generate_torch_op_tests(cls):
for t in torch_op_tests:
generate_torch_test_functions(cls, t, False)
if t.make_inplace:
generate_torch_test_functions(cls, t, True)
generate_torch_op_tests(TestUnaryUfuncMathOps)
instantiate_device_type_tests(TestUnaryUfuncs, globals())
instantiate_device_type_tests(TestUnaryUfuncMathOps, globals(), only_for='cpu')
if __name__ == '__main__':
run_tests()

View File

@@ -658,6 +658,9 @@
- name: polygamma(int n, Tensor self) -> Tensor
self: grad * polygamma(n + 1, self)
- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
self: grad * polygamma(n + 1, self)
- name: log(Tensor self) -> Tensor
self: grad.div(self.conj())
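
As a side note (not part of the diff), both polygamma backward entries above rely on the standard identity that differentiating the n-th polygamma function yields the (n+1)-th:

\frac{d}{dx}\,\psi^{(n)}(x) = \psi^{(n+1)}(x)

which is why the gradient formula is grad * polygamma(n + 1, self) for both the out-of-place and in-place variants.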

View File

@@ -2475,6 +2475,18 @@ def sample_inputs_polar(op_info, device, dtype, requires_grad, **kwargs):
return samples
def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
tensor_shapes = ((S, S), ())
ns = (1, 2, 3, 4, 5)
def generator():
for shape, n in product(tensor_shapes, ns):
yield SampleInput(make_arg(shape), args=(n,))
return list(generator())
def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
low, _ = op_info.domain
@@ -2883,6 +2895,16 @@ def reference_lgamma(x):
return out
def reference_polygamma(x, n):
# WEIRD `scipy.special.polygamma` behavior
# >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
# dtype('float64')
# >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
# dtype('float32')
#
# Thus we cast output to the default torch dtype.
np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
return scipy.special.polygamma(n, x).astype(np_dtype)
def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs):
"""Gradcheck wrapper for functions that take Hermitian matrices as input.
@@ -4484,6 +4506,114 @@ op_db: List[OpInfo] = [
OpInfo('polar',
dtypes=floating_types(),
sample_inputs_func=sample_inputs_polar),
# To test reference numerics against multiple values of argument `n`,
# we make multiple OpInfo entries, each corresponding to a different value of n (currently 0 to 4).
# We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing.
UnaryUfuncInfo('polygamma',
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
variant_test_name='polygamma_n_0',
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
dtypes=floating_types(),
dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=sample_inputs_polygamma,
skips=(
# Probably related to the way the function is
# scripted for JIT tests (or maybe not).
# RuntimeError:
# Arguments for call are not valid.
# The following variants are available:
# aten::polygamma(int n, Tensor self) -> (Tensor):
# Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
# aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> (Tensor(a!)):
# Expected a value of type 'Tensor' for argument 'self' but instead found type 'int'.
# The original call is:
# File "<string>", line 3
# def the_method(i0):
# return torch.polygamma(i0, 1)
# ~~~~~~~~~~~~~~~ <--- HERE
SkipInfo('TestCommon', 'test_variant_consistency_jit'),),
sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
UnaryUfuncInfo('polygamma',
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
variant_test_name='polygamma_n_1',
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
dtypes=floating_types(),
dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=sample_inputs_polygamma,
skips=(
# Redundant tests
SkipInfo('TestGradients'),
SkipInfo('TestOpInfo'),
SkipInfo('TestCommon'),
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard'),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal'),
),
sample_kwargs=lambda device, dtype, input: ({'n': 1}, {'n': 1})),
UnaryUfuncInfo('polygamma',
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
variant_test_name='polygamma_n_2',
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
dtypes=floating_types(),
dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=sample_inputs_polygamma,
skips=(
# Redundant tests
SkipInfo('TestGradients'),
SkipInfo('TestOpInfo'),
SkipInfo('TestCommon'),
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
active_if=TEST_WITH_ROCM),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
active_if=TEST_WITH_ROCM),),
sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2})),
UnaryUfuncInfo('polygamma',
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
variant_test_name='polygamma_n_3',
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
dtypes=floating_types(),
dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=sample_inputs_polygamma,
skips=(
# Redundant tests
SkipInfo('TestGradients'),
SkipInfo('TestOpInfo'),
SkipInfo('TestCommon'),
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
active_if=TEST_WITH_ROCM),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
active_if=TEST_WITH_ROCM),),
sample_kwargs=lambda device, dtype, input: ({'n': 3}, {'n': 3})),
UnaryUfuncInfo('polygamma',
op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
variant_test_name='polygamma_n_4',
ref=reference_polygamma if TEST_SCIPY else _NOTHING,
decorators=(precisionOverride({torch.float16: 5e-4, torch.float32: 5e-4}),),
dtypes=floating_types(),
dtypesIfCPU=floating_types(),
dtypesIfCUDA=floating_types_and(torch.half),
sample_inputs_func=sample_inputs_polygamma,
skips=(
# Redundant tests
SkipInfo('TestGradients'),
SkipInfo('TestOpInfo'),
SkipInfo('TestCommon'),
# Mismatch: https://github.com/pytorch/pytorch/issues/55357
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_extremal'),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard',
active_if=TEST_WITH_ROCM),
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_normal',
active_if=TEST_WITH_ROCM),),
sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4})),
OpInfo('pinverse',
op=torch.pinverse,
dtypes=floating_and_complex_types(),