mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Testing] Adding reference tests to OpInfo class (#59369)
Summary: This PR adds a `ref` argument to the `OpInfo` base class. The idea is to add reference checks for all _eligible_ ops. For more discussion, see https://github.com/pytorch/pytorch/issues/58294

* [x] Migrate (but do not yet remove) and modify the helper functions from the `UnaryUfuncOpInfo` class to the `OpInfo` base class.
* [x] Test the reference checks for multiple ops (and decide on a list of different, eligible ops for this).
* [x] Handle possible edge cases (for example, `uint64` is implemented in NumPy but not in PyTorch, and this needs to be handled -- more on this later). _Update_: We decided that these reference tests should only check values, not dtypes.
* [x] Create a sample PR for a single op (one of each of the different categories?) adding reference functions to the eligible ops. _Update_: This is being done in this PR itself.
* [x] ~Remove the reference tests from `test_unary_ufuncs.py` and verify that nothing breaks.~ _Update_: We won't be touching the unary ufunc reference tests in this PR.
* [x] Add comments and remove unnecessary prints/comments that were added for debugging.

Note: To keep the PR description short, examples of the edge cases encountered are given in the comments below.

cc: mruberry pmeier kshitij12345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59369

Reviewed By: ngimel

Differential Revision: D29347252

Pulled By: mruberry

fbshipit-source-id: 69719deddb1d23c53db45287a7e66c1bfe7e65bb
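To make the intent concrete, here is a minimal, self-contained sketch of what it means for an op entry to carry a NumPy reference and to be checked for values only. `RefOpEntry` is a hypothetical stand-in, not the real `OpInfo` constructor; only the idea of a `ref` field holding a NumPy callable comes from the PR description.

```python
# Illustrative only: an op entry that bundles the torch callable with a NumPy
# reference, plus a values-only comparison against that reference.
import numpy as np
import torch

class RefOpEntry:  # hypothetical stand-in for an OpInfo carrying `ref`
    def __init__(self, name, op, ref):
        self.name = name
        self.op = op    # torch callable under test
        self.ref = ref  # NumPy reference implementation (None if ineligible)

tile_entry = RefOpEntry('tile', torch.tile, np.tile)

x = torch.arange(6, dtype=torch.int64).reshape(2, 3)
expected = tile_entry.ref(x.numpy(), (2, 2))   # NumPy reference result
actual = tile_entry.op(x, (2, 2)).numpy()      # PyTorch result
# The planned checks compare values only, not dtypes (e.g. NumPy may choose
# uint64, which PyTorch does not implement).
assert np.array_equal(actual, expected)
```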
This commit is contained in:
committed by Facebook GitHub Bot
parent 236d3afd82
commit 08020220f3
@@ -9,10 +9,9 @@ import warnings
 from torch._six import nan
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
-from torch.testing._internal.common_methods_invocations import shape_funcs
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyOnCPUAndCUDA,
-    dtypesIfCPU, dtypesIfCUDA, largeTensorTest, ops)
+    dtypesIfCPU, dtypesIfCUDA, largeTensorTest)
 
 # TODO: replace with make_tensor
 def _generate_input(shape, dtype, device, with_extremal):
@@ -673,21 +672,7 @@ class TestShapeOps(TestCase):
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)
 
-class TestShapeFuncs(TestCase):
-    """Test suite for Shape manipulating operators using the ShapeFuncInfo."""
-
-    @dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
-    @ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
-    def test_repeat_tile_vs_numpy(self, device, dtype, op):
-        samples = op.sample_inputs(device, dtype, requires_grad=False)
-        for sample in samples:
-            assert isinstance(sample.input, torch.Tensor)
-            expected = op.ref(sample.input.cpu().numpy(), *sample.args, **sample.kwargs)
-            result = op(sample.input, *sample.args, **sample.kwargs).cpu().numpy()
-            self.assertEqual(expected, result)
-
 instantiate_device_type_tests(TestShapeOps, globals())
-instantiate_device_type_tests(TestShapeFuncs, globals())
 
 if __name__ == '__main__':
     run_tests()
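The removed `TestShapeFuncs.test_repeat_tile_vs_numpy` is the kind of per-suite check that a `ref`-carrying `OpInfo` is meant to generalize. Below is a hedged sketch of such a generic, values-only reference check; `op.ref`, `op.sample_inputs`, and the `op(...)` call follow the usage visible in the deleted test, while the test class name and the `exact_dtype` handling are assumptions rather than the merged implementation.

```python
# Sketch of a generic, values-only reference check driven by OpInfo entries.
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, ops)
from torch.testing._internal.common_methods_invocations import op_db


class TestReferenceValuesSketch(TestCase):
    # Only ops that actually carry a NumPy reference are eligible.
    @ops([op for op in op_db if getattr(op, 'ref', None) is not None])
    def test_reference_values(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            expected = op.ref(sample.input.cpu().numpy(), *sample.args, **sample.kwargs)
            actual = op(sample.input, *sample.args, **sample.kwargs).cpu().numpy()
            # Compare values only; dtype differences (e.g. NumPy promoting to
            # uint64, which PyTorch lacks) are deliberately ignored.
            self.assertEqual(actual, expected, exact_dtype=False)


instantiate_device_type_tests(TestReferenceValuesSketch, globals())

if __name__ == '__main__':
    run_tests()
```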