More flexible test parametrization with @reparametrize (#138369)

**Background:** The `@parametrize` decorator enjoys widespread usage as a convenient tool for ensuring extensive test coverage. One particular feature that makes this easy is the ability to stack such decorators, testing over the cross-product of inputs. Example:
```python
class MyTestClass(TestCase):
    @parametrize("x", range(3))
    @parametrize("y", [False, True])
    def test_foo(self, x, y):
        # Invoked with:
        # x=0, y=False
        # x=1, y=False
        # x=2, y=False
        # x=0, y=True
        # x=1, y=True
        # x=2, y=True
        ...
```

Note that the `@ops` and `@modules` decorators employ the same underlying machinery for parametrizing over `OpInfo` / `ModuleInfo` entries. These decorators also parametrize over op-specific `device` / `dtype` info *according to what is supported for each op*.
```python
class MyTestClass(TestCase):
    @ops(op_db)
    def test_foo(self, op, device, dtype):
        # Invoked each OpInfo in the db along with each device / dtype that corresponds
        # with this op according to the OpInfo entry.
        ...
```

Note that this in contrast to the naive cross product between ops and devices / dtypes, which would generate too many tests. Certain use cases benefit from a similar type of flexible parametrization that is more intelligent than simple cross-product composition. It is expensive to generate / run too many tests, even if the unneeded ones are skipped appropriately.

This PR attempts to generalize such flexible parametrization and satisfy these use cases through the introduction of a `@reparametrize` decorator, which operates on an existing parametrizer and allows for customized on-the-fly parametrization through the use of an `adapter_fn`. Examples:
```python
# adapter_fn that adds a new arg
 def include_is_even_arg(test_name, param_kwargs):
    x = param_kwargs["x"]
    is_even = x % 2 == 0
    new_param_kwargs = dict(param_kwargs)
    new_param_kwargs["is_even"] = is_even
    is_even_suffix = "_even" if is_even else "_odd"
    new_test_name = f"{test_name}{is_even_suffix}"
    yield (new_test_name, new_param_kwargs)

# adapter_fn that excludes certain values
def exclude_odds(test_name, param_kwargs):
    x = param_kwargs["x"]
    is_even = x % 2 == 0
    yield None if not is_even else (test_name, param_kwargs)

class MyTestClass(TestCase):
    @reparametrize(parametrize("x", range(5)), include_is_even_arg)
    def test_foo(self, x, is_even):
        # Invoked with both the x value and the new is_even arg
        ...

    @reparametrize(parametrize("x", range(5)), exclude_odds)
    def test_bar(self, x):
        # Only invoked with even x values
        ...
```

For a more real-world use case, imagine you want to write a set of OpInfo tests that parametrize over additional op-specific things beyond `device` / `dtype` (in NJT's case, this includes contiguity type, whether to operate over the batch / ragged / other dims, etc.). The `@reparametrize` decorator allows you to customize the `@ops` parametrization to add in these additional args as they make sense on a per-op basis.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138369
Approved by: https://github.com/janeyx99
This commit is contained in:
Joel Schlosser
2024-10-29 14:57:03 -04:00
committed by PyTorch MergeBot
parent ebaa774f96
commit 23d590e518
2 changed files with 100 additions and 3 deletions

View File

@ -17,9 +17,11 @@ from typing import Any, Callable, Iterator, List, Tuple
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
(IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf, skipIfRocm)
from torch.testing._internal.common_utils import (
IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
parametrize, reparametrize, subtest, instantiate_parametrized_tests, dtype_name,
TEST_WITH_ROCM, decorateIf, skipIfRocm
)
from torch.testing._internal.common_device_type import \
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
@ -1651,6 +1653,46 @@ class TestTestParametrization(TestCase):
test_names = _get_test_names_for_test_class(TestParametrized)
self.assertEqual(expected_test_names, test_names)
def test_reparametrize(self):
def include_is_even_arg(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
new_param_kwargs = dict(param_kwargs)
new_param_kwargs["is_even"] = is_even
is_even_suffix = "_even" if is_even else "_odd"
new_test_name = f"{test_name}{is_even_suffix}"
yield (new_test_name, new_param_kwargs)
def exclude_odds(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
yield None if not is_even else (test_name, param_kwargs)
class TestParametrized(TestCase):
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
def test_foo(self, x, is_even):
pass
@reparametrize(parametrize("x", range(5)), exclude_odds)
def test_bar(self, x):
pass
instantiate_parametrized_tests(TestParametrized)
expected_test_names = [
'TestParametrized.test_bar_x_0',
'TestParametrized.test_bar_x_2',
'TestParametrized.test_bar_x_4',
'TestParametrized.test_foo_x_0_even',
'TestParametrized.test_foo_x_1_odd',
'TestParametrized.test_foo_x_2_even',
'TestParametrized.test_foo_x_3_odd',
'TestParametrized.test_foo_x_4_even',
]
test_names = _get_test_names_for_test_class(TestParametrized)
self.assertEqual(expected_test_names, test_names)
def test_subtest_names(self):
class TestParametrized(TestCase):

View File

@ -702,6 +702,61 @@ class parametrize(_TestParametrizer):
'Note that this may result from reuse of a generator.')
class reparametrize(_TestParametrizer):
"""
Decorator for adjusting the way an existing parametrizer operates. This class runs
the given adapter_fn on each parametrization produced by the given parametrizer,
allowing for on-the-fly parametrization more flexible than the default,
product-based composition that occurs when stacking parametrization decorators.
If the adapter_fn returns None for a given test parametrization, that parametrization
will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of
modified parametrizations, with tweaked test names and parameter kwargs.
Examples::
def include_is_even_arg(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
new_param_kwargs = dict(param_kwargs)
new_param_kwargs["is_even"] = is_even
is_even_suffix = "_even" if is_even else "_odd"
new_test_name = f"{test_name}{is_even_suffix}"
yield (new_test_name, new_param_kwargs)
...
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
def test_foo(self, x, is_even):
...
def exclude_odds(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
yield None if not is_even else (test_name, param_kwargs)
...
@reparametrize(parametrize("x", range(5)), exclude_odds)
def test_bar(self, x):
...
"""
def __init__(self, parametrizer, adapter_fn):
self.parametrizer = parametrizer
self.adapter_fn = adapter_fn
def _parametrize_test(self, test, generic_cls, device_cls):
for (gen_test, test_name, param_kwargs, decorator_fn) in \
self.parametrizer._parametrize_test(test, generic_cls, device_cls):
adapted = self.adapter_fn(test_name, param_kwargs)
if adapted is not None:
for adapted_item in adapted:
if adapted_item is not None:
new_test_name, new_param_kwargs = adapted_item
yield (gen_test, new_test_name, new_param_kwargs, decorator_fn)
class decorateIf(_TestParametrizer):
"""
Decorator for applying parameter-specific conditional decoration.