Print the index and summary of the SampleInput that failed an OpInfo test (#99444)

Related to the Reproducible Testing BE project. The goal is to print out the sample input that failed an OpInfo test.

Crazy idea: to avoid requiring widespread changes across tests that use OpInfo sample inputs, return a new special iterator type from `OpInfo.sample_inputs()`, etc. that tracks the most recent item seen. If a test fails later on, print out this info to identify the sample that failed the test.

This solves the problem that the test framework currently has no concept of which sample input is being operated on.
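
A minimal sketch of the tracking-iterator idea, for illustration only. `TrackingIter` and `LastSeen` are hypothetical names, not the PR's `TrackedInputIter` / `TrackedInput`, and the real classes may store different fields:

```python
from dataclasses import dataclass
from typing import Any, Iterable, Optional


@dataclass
class LastSeen:
    """Hypothetical record of the most recent item handed out."""
    index: int
    value: Any


class TrackingIter:
    """Wraps any iterable and remembers the last item it yielded."""

    def __init__(self, source: Iterable[Any]) -> None:
        self._it = enumerate(source)
        self.last: Optional[LastSeen] = None

    def __iter__(self) -> "TrackingIter":
        return self

    def __next__(self) -> Any:
        # StopIteration from the wrapped iterator propagates unchanged.
        index, value = next(self._it)
        self.last = LastSeen(index, value)
        return value


# Stand-in for a test body that fails partway through its samples.
samples = TrackingIter(["a", "b", "c"])
failed = None
try:
    for s in samples:
        if s == "b":
            raise AssertionError("Example failure")
except AssertionError:
    failed = samples.last  # the item being processed when the failure hit
```

Because the wrapper is a drop-in iterator, the loop in the test body does not change; only the code that handles the failure needs to look at `last`.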

This PR contains the following changes:
* New `TrackedInputIter` that wraps a sample inputs func iterator and tracks the most recent input seen in a `TrackedInput` structure
    * The information is stored in a dictionary on the test function itself, mapping `full test ID -> most recent TrackedInput`
* To determine the test function that is being run, we do some stack crawling hackery in `extract_test_fn_and_id()`
* The above applies only when one of the following is called: `OpInfo.sample_inputs()`, `OpInfo.error_inputs()`, `OpInfo.reference_inputs()`, or `OpInfo.conjugate_sample_inputs()`. This could easily be extended to `ModuleInfo`s and the sparse sample input funcs as well
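
The stack crawling and per-test storage described in the list above can be sketched roughly as follows. This is an assumption-laden illustration, not the body of `extract_test_fn_and_id()`: the helper names `find_test_fn_and_id` and `record_sample` and the `tracked` attribute are made up here, and the sketch assumes the test runs as a method of a `unittest.TestCase`:

```python
import inspect
import unittest
from typing import Any, Callable, Optional, Tuple


def find_test_fn_and_id() -> Optional[Tuple[Callable[..., Any], str]]:
    """Walk up the call stack to the nearest frame whose `self` is a TestCase."""
    frame = inspect.currentframe()
    while frame is not None:
        candidate = frame.f_locals.get("self")
        if isinstance(candidate, unittest.TestCase):
            bound = getattr(candidate, candidate._testMethodName)
            # Bound methods reject new attributes, so return the plain function.
            return bound.__func__, candidate.id()
        frame = frame.f_back
    return None


def record_sample(summary: str) -> None:
    """Hypothetical helper: store `summary` on the running test function."""
    found = find_test_fn_and_id()
    if found is None:
        return  # not called from inside a test; nothing to record
    test_fn, test_id = found
    if not hasattr(test_fn, "tracked"):
        test_fn.tracked = {}
    test_fn.tracked[test_id] = summary  # full test ID -> most recent summary


class DemoTest(unittest.TestCase):
    def test_example(self):
        record_sample("sample #0")


case = DemoTest("test_example")
case.test_example()
recorded = DemoTest.test_example.tracked
```

Keying the dictionary by the full test ID keeps entries from different instantiations of the same test function from overwriting each other.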

Example output when a sample input causes a failure:
```
======================================================================
ERROR: test_foo_add_cpu_uint8 (__main__.TestFakeTensorCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 911, in test_wrapper
    return test(*args, **kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 1097, in only_fn
    return fn(slf, *args, **kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/test/test_ops.py", line 2211, in test_foo
    self.fail('Example failure')
AssertionError: Example failure

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_utils.py", line 2436, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 414, in instantiated_test
    result = test(self, **param_kwargs)
  File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 917, in test_wrapper
    raise Exception(
Exception: Caused by sample input at index 2: SampleInput(input=Tensor[size=(5, 1), device="cpu", dtype=torch.uint8], args=TensorList[Tensor[size=(5,), device="cpu", dtype=torch.uint8]], kwargs={}, broadcasts_input=True, name='')

To execute this test, run the following from the base repo dir:
     python test/test_ops.py -k test_foo_add_cpu_uint8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
```

This notably doesn't print the actual `SampleInput` values, as that's hard without fully reproducible random sample generation. I went down this path for a while and it seems infeasible without adding an untenable amount of overhead to set the random seed per SampleInput (see https://github.com/pytorch/pytorch/issues/86694#issuecomment-1614943708 for more details). For now, I am settling for at least spitting out the index and some metadata of the `SampleInput`, as it seems better than nothing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99444
Approved by: https://github.com/janeyx99
Joel Schlosser
2023-11-20 15:41:31 -05:00
committed by PyTorch MergeBot
parent e4a88d9581
commit e7f12b1eb0
4 changed files with 129 additions and 14 deletions

@@ -12,7 +12,7 @@ import re
 import subprocess
 import sys
 import unittest.mock
-from typing import Any, Callable, Iterator, List, Tuple, Generator
+from typing import Any, Callable, Iterator, List, Tuple
 import torch
@@ -2397,19 +2397,19 @@ class TestOpInfoSampleFunctions(TestCase):
     def test_opinfo_sample_generators(self, device, dtype, op):
         # Test op.sample_inputs doesn't generate multiple samples when called
         samples = op.sample_inputs(device, dtype)
-        self.assertIsInstance(samples, Generator)
+        self.assertIsInstance(samples, Iterator)
     @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
     def test_opinfo_reference_generators(self, device, dtype, op):
         # Test op.reference_inputs doesn't generate multiple samples when called
         samples = op.reference_inputs(device, dtype)
-        self.assertIsInstance(samples, Generator)
+        self.assertIsInstance(samples, Iterator)
     @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
     def test_opinfo_error_generators(self, device, op):
         # Test op.error_inputs doesn't generate multiple inputs when called
         samples = op.error_inputs(device)
-        self.assertIsInstance(samples, Generator)
+        self.assertIsInstance(samples, Iterator)
 instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
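
The switch from `Generator` to `Iterator` in these assertions is consistent with the sample input functions now returning a wrapper object rather than a raw generator. A small standalone check of that distinction, where `Wrapper` is a hypothetical stand-in and not `TrackedInputIter` itself:

```python
from collections.abc import Generator, Iterator


def gen():
    yield 1


class Wrapper:
    """Class-based iterator: has __iter__/__next__ but no send/throw/close."""

    def __init__(self, it):
        self._it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)


g = gen()
w = Wrapper(gen())
checks = (
    isinstance(g, Generator),  # a generator object is a Generator
    isinstance(g, Iterator),   # and also an Iterator
    isinstance(w, Generator),  # the wrapper is not a Generator
    isinstance(w, Iterator),   # but it is still an Iterator
)
```

Asserting on `Iterator` still guards against a sample function returning an eagerly built list, since a list is iterable but not an iterator.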