[BE] Improve pytest summary display for OpInfo tests (#162961)

pytest summarizes test failures by printing a truncated first line of the test of the OUTERMOST wrapped exception.

Prior to this PR, it looked like this:

```
FAILED [0.0454s] test/distributed/tensor/test_dtensor_ops.py::TestLocalDTensorOpsCPU::test_dtensor_op_db_H_cpu_float32 - Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(12, 12), device="cpu", dtype=torch.float32], args=(), kwargs={}, ...
```

I argue this is not so useful.  If I have a lot of test failures, I look to the test summary to understand what /kind/ of errors I have, so I can assess which ones I should look at first.  In other words, this is better:

```
FAILED [0.1387s] test/distributed/tensor/test_dtensor_ops.py::TestLocalDTensorOpsCPU::test_dtensor_op_db__softmax_backward_data_cpu_float32 - Exception: Tensor-likes are not close!
```

Now I know specifically this is a numerics problem!

This PR does it by prepending the old exception text to the wrapped exception.  This is slightly redundant, as we are exception chaining, but it does the job.  Open to bikeshedding.

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162961
Approved by: https://github.com/malfet
This commit is contained in:
Edward Yang
2025-09-15 10:40:28 -04:00
committed by PyTorch MergeBot
parent de3a863cd8
commit 1247dde1f2
2 changed files with 2 additions and 2 deletions

View File

@ -636,7 +636,7 @@ class TestDTensorOps(DTensorOpTestBase):
) )
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"failed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})" f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
) from e ) from e
return rs return rs

View File

@ -1139,7 +1139,7 @@ class ops(_TestParametrizer):
tracked_input = get_tracked_input() tracked_input = get_tracked_input()
if PRINT_REPRO_ON_FAILURE and tracked_input is not None: if PRINT_REPRO_ON_FAILURE and tracked_input is not None:
e_tracked = Exception( # noqa: TRY002 e_tracked = Exception( # noqa: TRY002
f"Caused by {tracked_input.type_desc} " f"{str(e)}\n\nCaused by {tracked_input.type_desc} "
f"at index {tracked_input.index}: " f"at index {tracked_input.index}: "
f"{_serialize_sample(tracked_input.val)}" f"{_serialize_sample(tracked_input.val)}"
) )