Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
General per-SampleInput xfail / skip system (#140443)
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality.

This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

### Design

#### Principles
* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.

#### Details
The core new concept is a `SampleRule`, which can be either an `XFailRule` or a `SkipRule`.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter (see the sketch after this list).
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through the test logic.
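To make the matching flow concrete, here is a minimal conceptual sketch of how a list of rules could be applied to a given op and sample. It assumes the `SampleRule` / `XFailRule` / `SkipRule` dataclasses above are in scope; it is not the PR's actual decorator / subtest plumbing, whose names and structure may differ.

```python
# Conceptual sketch only, assuming the SampleRule / XFailRule / SkipRule dataclasses above;
# the real plumbing lives in the @ops decorator and sample_inputs(use_subtests=True).
import re
import unittest


def filter_rules_for_op(rules, device_type, op):
    # Pre-filter via op_match_fn so only rules relevant to this op remain.
    return [r for r in rules if r.op_match_fn is not None and r.op_match_fn(device_type, op)]


def run_sample_with_rules(rules, device, sample, test_logic):
    # First matching rule wins, so order in the rule list can matter.
    rule = next(
        (r for r in rules if r.sample_match_fn is not None and r.sample_match_fn(device, sample)),
        None,
    )
    if isinstance(rule, SkipRule):
        raise unittest.SkipTest(f"skipped by rule: {rule.name}")
    if isinstance(rule, XFailRule):
        try:
            test_logic(sample)
        except rule.error_type as e:
            # xfail matched: the expected error type and message were observed
            assert re.search(rule.error_msg, str(e)), f"unexpected error message for {rule.name}"
            return
        raise AssertionError(f"unexpected success for xfail rule: {rule.name}")
    # No rule matched: run the test logic normally.
    test_logic(sample)
```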
### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic.

This PR lets you do this to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the
# @ops decorator, but it can be more readable to maintain these somewhere else. They are matched
# in order and the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]

class MyTestCase(TestCase):
    @ops(op_db)
    @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
    # NB: the @ops decorator automatically filters out any rules that don't apply to this op
    def test_foo(self, device, dtype, op):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails
            # are defined and use_subtests != True, an informative error will be thrown.
            device, dtype, requires_grad=False, use_subtests=True
        ):
            # NB: this subtest context manager runs each sample input as a "subtest" and
            # handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`:

```python
...
# pre-existing xfail
xfail("item"),
# new granular xfail using syntactic sugar over the general system
xfailIf(
    "to",
    lambda sample: (
        sample.kwargs["memory_format"] == torch.channels_last
    ),
),
...
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140443
Approved by: https://github.com/janeyx99, https://github.com/zou3519
ghstack dependencies: #140160, #138370
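As an aside on the `xfailIf` sugar shown above, such a helper could plausibly be a thin wrapper around `XFailRule`. The sketch below is only a guess at its shape; the actual helper in `test/functorch/test_vmap.py` may differ in signature and in how it integrates with the existing xfail machinery.

```python
# Hypothetical sketch only: build an XFailRule that matches a single op by name plus a
# per-sample predicate. The real xfailIf helper may look different.
def xfailIf(op_name, sample_predicate, *, name="", error_type=Exception, error_msg=".*"):
    return XFailRule(
        error_type=error_type,
        error_msg=error_msg,
        op_match_fn=lambda device, op: op.full_name == op_name,
        sample_match_fn=lambda device, sample: sample_predicate(sample),
        name=name or f"xfail_{op_name}",
    )
```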
Committed by: PyTorch MergeBot
Parent: cee3f8541e
Commit: 780c580d68
@@ -367,13 +367,30 @@ def clear_tracked_input():
 # Wraps an iterator and tracks the most recent value the iterator produces
 # for debugging purposes. Tracked values are stored on the test function.
 class TrackedInputIter:
-    def __init__(self, child_iter, input_type_desc,
-                 callback=lambda x: x, set_seed=True, restrict_to_index=None):
+    def __init__(
+        self,
+        child_iter,
+        input_type_desc,
+        item_callback=None,
+        track_callback=None,
+        set_seed=True,
+        restrict_to_index=None
+    ):
         self.child_iter = enumerate(child_iter)
         # Input type describes the things we're tracking (e.g. "sample input", "error input").
         self.input_type_desc = input_type_desc
-        # Callback is run on each iterated thing to get the thing to track.
-        self.callback = callback
+        # NB: The two types of callbacks below exist because the thing we want to track isn't
+        # always the same as the thing we want returned from the iterator. An example of this
+        # is ErrorInput, which we want returned from the iterator, but which contains a
+        # SampleInput that we want to track.
+        # Item callback is run on each (iterated thing, index) to get the thing to return.
+        self.item_callback = item_callback
+        if self.item_callback is None:
+            self.item_callback = lambda x, i: x
+        # Track callback is run on each iterated thing to get the thing to track.
+        self.track_callback = track_callback
+        if self.track_callback is None:
+            self.track_callback = lambda x: x
         self.test_fn = extract_test_fn()
         # Indicates whether the random seed should be set before each call to the iterator
         self.set_seed = set_seed
@@ -402,10 +419,10 @@ class TrackedInputIter:

         self._set_tracked_input(
             TrackedInput(
-                index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc
+                index=input_idx, val=self.track_callback(input_val), type_desc=self.input_type_desc
             )
         )
-        return input_val
+        return self.item_callback(input_val, input_idx)

     def _set_tracked_input(self, tracked_input: TrackedInput):
         if self.test_fn is None:
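As a hypothetical illustration of the `item_callback` / `track_callback` split introduced in this diff (the actual call sites live in the OpInfo test machinery and may differ), an `ErrorInput` iterator could be wrapped so that the `ErrorInput` itself is returned from iteration while the `SampleInput` it contains is what gets tracked for debugging:

```python
# Hypothetical usage sketch; actual call sites in the OpInfo machinery may differ.
def iterate_error_inputs(op, device):
    return TrackedInputIter(
        op.error_inputs(device),
        "error input",
        # return the full ErrorInput from the iterator (the index arg is unused here)
        item_callback=lambda err, idx: err,
        # but track the SampleInput wrapped inside it on the test function
        track_callback=lambda err: err.sample_input,
    )
```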