Files
pytorch/test/onnx/error_reproduction.py
Xuehai Pan 7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00

178 lines
3.4 KiB
Python

"""Error reproduction utilities for op consistency tests."""
from __future__ import annotations
import difflib
import pathlib
import platform
import sys
import time
import traceback
import numpy as np
import onnx
import onnxruntime as ort
import onnxscript
import torch
_MISMATCH_MARKDOWN_TEMPLATE = """\
### Summary
The output of ONNX Runtime does not match that of PyTorch when executing test
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.
To recreate this report, use
```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
```
### ONNX Model
```
{onnx_model_text}
```
### Inputs
Shapes: `{input_shapes}`
<details><summary>Details</summary>
<p>
```python
kwargs = {kwargs}
inputs = {inputs}
```
</p>
</details>
### Expected output
Shape: `{expected_shape}`
<details><summary>Details</summary>
<p>
```python
expected = {expected}
```
</p>
</details>
### Actual output
Shape: `{actual_shape}`
<details><summary>Details</summary>
<p>
```python
actual = {actual}
```
</p>
</details>
### Difference
<details><summary>Details</summary>
<p>
```diff
{diff}
```
</p>
</details>
### Full error stack
```
{error_stack}
```
### Environment
```
{sys_info}
```
"""
def create_mismatch_report(
test_name: str,
sample_num: int,
onnx_model: onnx.ModelProto,
inputs,
kwargs,
actual,
expected,
error: Exception,
) -> None:
torch.set_printoptions(threshold=sys.maxsize)
error_text = str(error)
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
short_test_name = test_name.split(".")[-1]
diff = difflib.unified_diff(
str(actual).splitlines(),
str(expected).splitlines(),
fromfile="actual",
tofile="expected",
lineterm="",
)
onnx_model_text = onnx.printer.to_text(onnx_model)
input_shapes = repr(
[
f"Tensor<{inp.shape}, dtype={inp.dtype}>"
if isinstance(inp, torch.Tensor)
else inp
for inp in inputs
]
)
sys_info = f"""\
OS: {platform.platform()}
Python version: {sys.version}
onnx=={onnx.__version__}
onnxruntime=={ort.__version__}
onnxscript=={onnxscript.__version__}
numpy=={np.__version__}
torch=={torch.__version__}"""
markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
test_name=test_name,
short_test_name=short_test_name,
sample_num=sample_num,
input_shapes=input_shapes,
inputs=inputs,
kwargs=kwargs,
expected=expected,
expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
actual=actual,
actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
diff="\n".join(diff),
error_stack=error_stack,
sys_info=sys_info,
onnx_model_text=onnx_model_text,
)
markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
markdown_file_path = save_error_report(markdown_file_name, markdown)
print(f"Created reproduction report at {markdown_file_path}")
def save_error_report(file_name: str, text: str):
reports_dir = pathlib.Path("error_reports")
reports_dir.mkdir(parents=True, exist_ok=True)
file_path = reports_dir / file_name
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
return file_path