mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Changes: 1. Always explicit `.absolute()`: `Path(__file__)` -> `Path(__file__).absolute()` 2. Replace `path.resolve()` with `path.absolute()` if the code is resolving the PyTorch repo root directory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129409 Approved by: https://github.com/albanD
179 lines
3.4 KiB
Python
179 lines
3.4 KiB
Python
"""Error reproduction utilities for op consistency tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import difflib
|
|
import platform
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
import onnx
|
|
import onnxruntime as ort
|
|
import onnxscript
|
|
|
|
import torch
|
|
|
|
|
|
_MISMATCH_MARKDOWN_TEMPLATE = """\
|
|
### Summary
|
|
|
|
The output of ONNX Runtime does not match that of PyTorch when executing test
|
|
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.
|
|
|
|
To recreate this report, use
|
|
|
|
```bash
|
|
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
|
|
```
|
|
|
|
### ONNX Model
|
|
|
|
```
|
|
{onnx_model_text}
|
|
```
|
|
|
|
### Inputs
|
|
|
|
Shapes: `{input_shapes}`
|
|
|
|
<details><summary>Details</summary>
|
|
<p>
|
|
|
|
```python
|
|
kwargs = {kwargs}
|
|
inputs = {inputs}
|
|
```
|
|
|
|
</p>
|
|
</details>
|
|
|
|
### Expected output
|
|
|
|
Shape: `{expected_shape}`
|
|
|
|
<details><summary>Details</summary>
|
|
<p>
|
|
|
|
```python
|
|
expected = {expected}
|
|
```
|
|
|
|
</p>
|
|
</details>
|
|
|
|
### Actual output
|
|
|
|
Shape: `{actual_shape}`
|
|
|
|
<details><summary>Details</summary>
|
|
<p>
|
|
|
|
```python
|
|
actual = {actual}
|
|
```
|
|
|
|
</p>
|
|
</details>
|
|
|
|
### Difference
|
|
|
|
<details><summary>Details</summary>
|
|
<p>
|
|
|
|
```diff
|
|
{diff}
|
|
```
|
|
|
|
</p>
|
|
</details>
|
|
|
|
### Full error stack
|
|
|
|
```
|
|
{error_stack}
|
|
```
|
|
|
|
### Environment
|
|
|
|
```
|
|
{sys_info}
|
|
```
|
|
|
|
"""
|
|
|
|
|
|
def create_mismatch_report(
|
|
test_name: str,
|
|
sample_num: int,
|
|
onnx_model: onnx.ModelProto,
|
|
inputs,
|
|
kwargs,
|
|
actual,
|
|
expected,
|
|
error: Exception,
|
|
) -> None:
|
|
torch.set_printoptions(threshold=sys.maxsize)
|
|
|
|
error_text = str(error)
|
|
error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
|
|
short_test_name = test_name.split(".")[-1]
|
|
diff = difflib.unified_diff(
|
|
str(actual).splitlines(),
|
|
str(expected).splitlines(),
|
|
fromfile="actual",
|
|
tofile="expected",
|
|
lineterm="",
|
|
)
|
|
onnx_model_text = onnx.printer.to_text(onnx_model)
|
|
input_shapes = repr(
|
|
[
|
|
f"Tensor<{inp.shape}, dtype={inp.dtype}>"
|
|
if isinstance(inp, torch.Tensor)
|
|
else inp
|
|
for inp in inputs
|
|
]
|
|
)
|
|
sys_info = f"""\
|
|
OS: {platform.platform()}
|
|
Python version: {sys.version}
|
|
onnx=={onnx.__version__}
|
|
onnxruntime=={ort.__version__}
|
|
onnxscript=={onnxscript.__version__}
|
|
numpy=={np.__version__}
|
|
torch=={torch.__version__}"""
|
|
|
|
markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
|
|
test_name=test_name,
|
|
short_test_name=short_test_name,
|
|
sample_num=sample_num,
|
|
input_shapes=input_shapes,
|
|
inputs=inputs,
|
|
kwargs=kwargs,
|
|
expected=expected,
|
|
expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
|
|
actual=actual,
|
|
actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
|
|
diff="\n".join(diff),
|
|
error_stack=error_stack,
|
|
sys_info=sys_info,
|
|
onnx_model_text=onnx_model_text,
|
|
)
|
|
|
|
markdown_file_name = f'mismatch-{short_test_name.replace("/", "-").replace(":", "-")}-{str(time.time()).replace(".", "_")}.md'
|
|
markdown_file_path = save_error_report(markdown_file_name, markdown)
|
|
print(f"Created reproduction report at {markdown_file_path}")
|
|
|
|
|
|
def save_error_report(file_name: str, text: str):
|
|
reports_dir = Path("error_reports")
|
|
reports_dir.mkdir(parents=True, exist_ok=True)
|
|
file_path = reports_dir / file_name
|
|
with open(file_path, "w", encoding="utf-8") as f:
|
|
f.write(text)
|
|
|
|
return file_path
|