mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Print the index and summary of the SampleInput that failed an OpInfo test (#99444)
Related to the Reproducible Testing BE project. Goal is to print out the sample input that failed an OpInfo test. Crazy idea: to avoid requiring widespread changes across tests that use OpInfo sample inputs, return a new special iterator type from `OpInfo.sample_inputs()`, etc. that tracks the most recent item seen. If a test fails later on, print out this info to identify the sample that failed the test. This solves the problem that the test framework currently has no concept of which sample input is being operated on. This PR contains the following changes: * New `TrackedInputIter` that wraps a sample inputs func iterator and tracks the most recent input seen in a `TrackedInput` structure * The information is stored in a dictionary on the test function itself, mapping `full test ID -> most recent TrackedInput` * To determine the test function that is being run, we do some stack crawling hackery in `extract_test_fn_and_id()` * Above applies only when one of the following is called: `OpInfo.sample_inputs()`, `OpInfo.error_inputs()`, `OpInfo.reference_inputs()`, and `OpInfo.conjugate_sample_inputs()`. This could easily be extended to `ModuleInfo`s and the sparse sample input funcs as well Example output when a sample input causes a failure: ``` ====================================================================== ERROR: test_foo_add_cpu_uint8 (__main__.TestFakeTensorCPU) ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 911, in test_wrapper return test(*args, **kwargs) File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 1097, in only_fn return fn(slf, *args, **kwargs) File "/home/jbschlosser/branches/reproducible_testing/test/test_ops.py", line 2211, in test_foo self.fail('Example failure') AssertionError: Example failure The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_utils.py", line 2436, in wrapper method(*args, **kwargs) File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 414, in instantiated_test result = test(self, **param_kwargs) File "/home/jbschlosser/branches/reproducible_testing/torch/testing/_internal/common_device_type.py", line 917, in test_wrapper raise Exception( Exception: Caused by sample input at index 2: SampleInput(input=Tensor[size=(5, 1), device="cpu", dtype=torch.uint8], args=TensorList[Tensor[size=(5,), device="cpu", dtype=torch.uint8]], kwargs={}, broadcasts_input=True, name='') To execute this test, run the following from the base repo dir: python test/test_ops.py -k test_foo_add_cpu_uint8 This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ---------------------------------------------------------------------- ``` This notably doesn't print the actual `SampleInput` values, as that's hard without fully reproducible random sample generation. I went down this path for a while and it seems infeasible without adding an untenable amount of overhead to set the random seed per SampleInput (see https://github.com/pytorch/pytorch/issues/86694#issuecomment-1614943708 for more details). For now, I am settling for at least spitting out the index and some metadata of the `SampleInput`, as it seems better than nothing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99444 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4a88d9581
commit
e7f12b1eb0
@ -12,7 +12,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest.mock
|
||||
from typing import Any, Callable, Iterator, List, Tuple, Generator
|
||||
from typing import Any, Callable, Iterator, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -2397,19 +2397,19 @@ class TestOpInfoSampleFunctions(TestCase):
|
||||
def test_opinfo_sample_generators(self, device, dtype, op):
|
||||
# Test op.sample_inputs doesn't generate multiple samples when called
|
||||
samples = op.sample_inputs(device, dtype)
|
||||
self.assertIsInstance(samples, Generator)
|
||||
self.assertIsInstance(samples, Iterator)
|
||||
|
||||
@ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one)
|
||||
def test_opinfo_reference_generators(self, device, dtype, op):
|
||||
# Test op.reference_inputs doesn't generate multiple samples when called
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
self.assertIsInstance(samples, Generator)
|
||||
self.assertIsInstance(samples, Iterator)
|
||||
|
||||
@ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
|
||||
def test_opinfo_error_generators(self, device, op):
|
||||
# Test op.error_inputs doesn't generate multiple inputs when called
|
||||
samples = op.error_inputs(device)
|
||||
self.assertIsInstance(samples, Generator)
|
||||
self.assertIsInstance(samples, Iterator)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
|
||||
|
Reference in New Issue
Block a user