mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
More flexible test parametrization with @reparametrize (#138369)
**Background:** The `@parametrize` decorator enjoys widespread usage as a convenient tool for ensuring extensive test coverage. One particular feature that makes this easy is the ability to stack such decorators, testing over the cross-product of inputs. Example: ```python class MyTestClass(TestCase): @parametrize("x", range(3)) @parametrize("y", [False, True]) def test_foo(self, x, y): # Invoked with: # x=0, y=False # x=1, y=False # x=2, y=False # x=0, y=True # x=1, y=True # x=2, y=True ... ``` Note that the `@ops` and `@modules` decorators employ the same underlying machinery for parametrizing over `OpInfo` / `ModuleInfo` entries. These decorators also parametrize over op-specific `device` / `dtype` info *according to what is supported for each op*. ```python class MyTestClass(TestCase): @ops(op_db) def test_foo(self, op, device, dtype): # Invoked each OpInfo in the db along with each device / dtype that corresponds # with this op according to the OpInfo entry. ... ``` Note that this in contrast to the naive cross product between ops and devices / dtypes, which would generate too many tests. Certain use cases benefit from a similar type of flexible parametrization that is more intelligent than simple cross-product composition. It is expensive to generate / run too many tests, even if the unneeded ones are skipped appropriately. This PR attempts to generalize such flexible parametrization and satisfy these use cases through the introduction of a `@reparametrize` decorator, which operates on an existing parametrizer and allows for customized on-the-fly parametrization through the use of an `adapter_fn`. Examples: ```python # adapter_fn that adds a new arg def include_is_even_arg(test_name, param_kwargs): x = param_kwargs["x"] is_even = x % 2 == 0 new_param_kwargs = dict(param_kwargs) new_param_kwargs["is_even"] = is_even is_even_suffix = "_even" if is_even else "_odd" new_test_name = f"{test_name}{is_even_suffix}" yield (new_test_name, new_param_kwargs) # adapter_fn that excludes certain values def exclude_odds(test_name, param_kwargs): x = param_kwargs["x"] is_even = x % 2 == 0 yield None if not is_even else (test_name, param_kwargs) class MyTestClass(TestCase): @reparametrize(parametrize("x", range(5)), include_is_even_arg) def test_foo(self, x, is_even): # Invoked with both the x value and the new is_even arg ... @reparametrize(parametrize("x", range(5)), exclude_odds) def test_bar(self, x): # Only invoked with even x values ... ``` For a more real-world use case, imagine you want to write a set of OpInfo tests that parametrize over additional op-specific things beyond `device` / `dtype` (in NJT's case, this includes contiguity type, whether to operate over the batch / ragged / other dims, etc.). The `@reparametrize` decorator allows you to customize the `@ops` parametrization to add in these additional args as they make sense on a per-op basis. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138369 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
ebaa774f96
commit
23d590e518
@ -702,6 +702,61 @@ class parametrize(_TestParametrizer):
|
||||
'Note that this may result from reuse of a generator.')
|
||||
|
||||
|
||||
class reparametrize(_TestParametrizer):
|
||||
"""
|
||||
Decorator for adjusting the way an existing parametrizer operates. This class runs
|
||||
the given adapter_fn on each parametrization produced by the given parametrizer,
|
||||
allowing for on-the-fly parametrization more flexible than the default,
|
||||
product-based composition that occurs when stacking parametrization decorators.
|
||||
|
||||
If the adapter_fn returns None for a given test parametrization, that parametrization
|
||||
will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of
|
||||
modified parametrizations, with tweaked test names and parameter kwargs.
|
||||
|
||||
Examples::
|
||||
|
||||
def include_is_even_arg(test_name, param_kwargs):
|
||||
x = param_kwargs["x"]
|
||||
is_even = x % 2 == 0
|
||||
new_param_kwargs = dict(param_kwargs)
|
||||
new_param_kwargs["is_even"] = is_even
|
||||
is_even_suffix = "_even" if is_even else "_odd"
|
||||
new_test_name = f"{test_name}{is_even_suffix}"
|
||||
yield (new_test_name, new_param_kwargs)
|
||||
|
||||
...
|
||||
|
||||
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
|
||||
def test_foo(self, x, is_even):
|
||||
...
|
||||
|
||||
def exclude_odds(test_name, param_kwargs):
|
||||
x = param_kwargs["x"]
|
||||
is_even = x % 2 == 0
|
||||
yield None if not is_even else (test_name, param_kwargs)
|
||||
|
||||
...
|
||||
|
||||
@reparametrize(parametrize("x", range(5)), exclude_odds)
|
||||
def test_bar(self, x):
|
||||
...
|
||||
|
||||
"""
|
||||
def __init__(self, parametrizer, adapter_fn):
|
||||
self.parametrizer = parametrizer
|
||||
self.adapter_fn = adapter_fn
|
||||
|
||||
def _parametrize_test(self, test, generic_cls, device_cls):
|
||||
for (gen_test, test_name, param_kwargs, decorator_fn) in \
|
||||
self.parametrizer._parametrize_test(test, generic_cls, device_cls):
|
||||
adapted = self.adapter_fn(test_name, param_kwargs)
|
||||
if adapted is not None:
|
||||
for adapted_item in adapted:
|
||||
if adapted_item is not None:
|
||||
new_test_name, new_param_kwargs = adapted_item
|
||||
yield (gen_test, new_test_name, new_param_kwargs, decorator_fn)
|
||||
|
||||
|
||||
class decorateIf(_TestParametrizer):
|
||||
"""
|
||||
Decorator for applying parameter-specific conditional decoration.
|
||||
|
Reference in New Issue
Block a user