diff --git a/test/test_testing.py b/test/test_testing.py index 56ce579374b9..ee10107d4d30 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -17,9 +17,11 @@ from typing import Any, Callable, Iterator, List, Tuple import torch from torch.testing import make_tensor -from torch.testing._internal.common_utils import \ - (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest, - parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf, skipIfRocm) +from torch.testing._internal.common_utils import ( + IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest, + parametrize, reparametrize, subtest, instantiate_parametrized_tests, dtype_name, + TEST_WITH_ROCM, decorateIf, skipIfRocm +) from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes, @@ -1651,6 +1653,46 @@ class TestTestParametrization(TestCase): test_names = _get_test_names_for_test_class(TestParametrized) self.assertEqual(expected_test_names, test_names) + def test_reparametrize(self): + + def include_is_even_arg(test_name, param_kwargs): + x = param_kwargs["x"] + is_even = x % 2 == 0 + new_param_kwargs = dict(param_kwargs) + new_param_kwargs["is_even"] = is_even + is_even_suffix = "_even" if is_even else "_odd" + new_test_name = f"{test_name}{is_even_suffix}" + yield (new_test_name, new_param_kwargs) + + def exclude_odds(test_name, param_kwargs): + x = param_kwargs["x"] + is_even = x % 2 == 0 + yield None if not is_even else (test_name, param_kwargs) + + class TestParametrized(TestCase): + @reparametrize(parametrize("x", range(5)), include_is_even_arg) + def test_foo(self, x, is_even): + pass + + @reparametrize(parametrize("x", range(5)), exclude_odds) + def test_bar(self, x): + pass + + instantiate_parametrized_tests(TestParametrized) + + expected_test_names = [ + 'TestParametrized.test_bar_x_0', + 'TestParametrized.test_bar_x_2', + 'TestParametrized.test_bar_x_4', + 'TestParametrized.test_foo_x_0_even', + 'TestParametrized.test_foo_x_1_odd', + 'TestParametrized.test_foo_x_2_even', + 'TestParametrized.test_foo_x_3_odd', + 'TestParametrized.test_foo_x_4_even', + ] + test_names = _get_test_names_for_test_class(TestParametrized) + self.assertEqual(expected_test_names, test_names) + def test_subtest_names(self): class TestParametrized(TestCase): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index b34aa68f1ece..deb46a2f337c 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -702,6 +702,61 @@ class parametrize(_TestParametrizer): 'Note that this may result from reuse of a generator.') +class reparametrize(_TestParametrizer): + """ + Decorator for adjusting the way an existing parametrizer operates. This class runs + the given adapter_fn on each parametrization produced by the given parametrizer, + allowing for on-the-fly parametrization more flexible than the default, + product-based composition that occurs when stacking parametrization decorators. + + If the adapter_fn returns None for a given test parametrization, that parametrization + will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of + modified parametrizations, with tweaked test names and parameter kwargs. + + Examples:: + + def include_is_even_arg(test_name, param_kwargs): + x = param_kwargs["x"] + is_even = x % 2 == 0 + new_param_kwargs = dict(param_kwargs) + new_param_kwargs["is_even"] = is_even + is_even_suffix = "_even" if is_even else "_odd" + new_test_name = f"{test_name}{is_even_suffix}" + yield (new_test_name, new_param_kwargs) + + ... + + @reparametrize(parametrize("x", range(5)), include_is_even_arg) + def test_foo(self, x, is_even): + ... + + def exclude_odds(test_name, param_kwargs): + x = param_kwargs["x"] + is_even = x % 2 == 0 + yield None if not is_even else (test_name, param_kwargs) + + ... + + @reparametrize(parametrize("x", range(5)), exclude_odds) + def test_bar(self, x): + ... + + """ + def __init__(self, parametrizer, adapter_fn): + self.parametrizer = parametrizer + self.adapter_fn = adapter_fn + + def _parametrize_test(self, test, generic_cls, device_cls): + for (gen_test, test_name, param_kwargs, decorator_fn) in \ + self.parametrizer._parametrize_test(test, generic_cls, device_cls): + adapted = self.adapter_fn(test_name, param_kwargs) + if adapted is not None: + for adapted_item in adapted: + if adapted_item is not None: + new_test_name, new_param_kwargs = adapted_item + yield (gen_test, new_test_name, new_param_kwargs, decorator_fn) + + class decorateIf(_TestParametrizer): """ Decorator for applying parameter-specific conditional decoration.