[Testing] Adding reference tests to OpInfo class (#59369)

Summary:
This PR will ideally add `ref` argument to `OpInfo` base class. The idea is to add reference checks for all the ops _eligible_. For more discussion, please check https://github.com/pytorch/pytorch/issues/58294

* [x] Migrate (but not removing yet) and modify helper functions from `UnaryUfuncOpInfo` class to `OpInfo` base class.
* [x] Test the reference checks for multiple ops. (also decide a list of different and eligible ops for this)
* [x] Handle possible edge cases (for example: `uint64` isn't implemented in PyTorch but is there in NumPy, and this needs to be handled -- more on this later) -- _Update_: We decided that these reference tests should only test for values and not types.
* [x] Create a sample PR for a single (of all different categories?) on adding reference functions to the eligible ops. -- _Update_: This is being done in this PR only.
* [x] ~Remove reference tests from `test_unary_ufuncs.py` and test to make sure that nothing breaks.~ (*Update*: We won't be touching Unary Ufunc reference tests in this PR)
* [x] Add comments, remove unnecessary prints/comments (added for debugging).

Note: To keep the PR description short, examples of edge cases encountered have been mentioned in the comments below.

cc: mruberry pmeier kshitij12345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59369

Reviewed By: ngimel

Differential Revision: D29347252

Pulled By: mruberry

fbshipit-source-id: 69719deddb1d23c53db45287a7e66c1bfe7e65bb
This commit is contained in:
Kushashwa Ravi Shrimali
2021-06-23 19:21:17 -07:00
committed by Facebook GitHub Bot
parent 236d3afd82
commit 08020220f3
4 changed files with 72 additions and 28 deletions

View File

@ -9,10 +9,9 @@ import warnings
from torch._six import nan
from torch.testing._internal.common_utils import (
TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
from torch.testing._internal.common_methods_invocations import shape_funcs
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyOnCPUAndCUDA,
dtypesIfCPU, dtypesIfCUDA, largeTensorTest, ops)
dtypesIfCPU, dtypesIfCUDA, largeTensorTest)
# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
@ -673,21 +672,7 @@ class TestShapeOps(TestCase):
nz = x.nonzero()
self.assertFalse(nz.requires_grad)
class TestShapeFuncs(TestCase):
"""Test suite for Shape manipulating operators using the ShapeFuncInfo."""
@dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
@ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
def test_repeat_tile_vs_numpy(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=False)
for sample in samples:
assert isinstance(sample.input, torch.Tensor)
expected = op.ref(sample.input.cpu().numpy(), *sample.args, **sample.kwargs)
result = op(sample.input, *sample.args, **sample.kwargs).cpu().numpy()
self.assertEqual(expected, result)
instantiate_device_type_tests(TestShapeOps, globals())
instantiate_device_type_tests(TestShapeFuncs, globals())
if __name__ == '__main__':
run_tests()