# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""

from __future__ import annotations

import contextlib
import gzip
from typing import TYPE_CHECKING

import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace


if TYPE_CHECKING:
    from collections.abc import Generator


def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
    """Returns the current C++ call stack.

    This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
    The returned C++ call stack is a concatenated string of the C++ call stack frames.
    Each frame is separated by a newline character, in the same format as
    r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
    """
    frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
    frame_messages = []
    for frame in frames:
        segments = frame.split(":", 1)
        if len(segments) == 2:
            frame_messages.append(segments[1].strip())
        else:
            frame_messages.append("<unknown frame>")
    return infra.Stack(
        frames=[
            infra.StackFrame(location=infra.Location(message=message))
            for message in frame_messages
        ]
    )
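

# Illustrative note (an assumed sample, not part of the runtime logic):
# `get_cpp_backtrace` returns newline-separated lines such as
#     frame #0: c10::FunctionSchema::checkAndNormalizeInputs(...) + 0x5b
# and `_cpp_call_stack` keeps only the text after the first ":", so the resulting
# `infra.Stack` would hold a frame whose location message reads
#     "c10::FunctionSchema::checkAndNormalizeInputs(...) + 0x5b"
# The exact symbol text depends on the build and call site.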


class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    This class is used to represent all export diagnostics. It is a subclass of
    infra.Diagnostic, and adds additional methods to record more information in the
    diagnostic.
    """

    python_call_stack: infra.Stack | None = None
    cpp_call_stack: infra.Stack | None = None

    def __init__(
        self,
        *args,
        frames_to_skip: int = 1,
        cpp_stack: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.python_call_stack = self.record_python_call_stack(
            frames_to_skip=frames_to_skip
        )
        if cpp_stack:
            self.cpp_call_stack = self.record_cpp_call_stack(
                frames_to_skip=frames_to_skip
            )

    def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
        """Records the current C++ call stack in the diagnostic."""
        stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
        stack.message = "C++ call stack"
        self.with_stack(stack)
        return stack
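

# Example usage (a minimal sketch; `some_rule` stands in for a real `infra.Rule`
# from the exporter's rule set and is not defined here):
#
#     diagnostic = TorchScriptOnnxExportDiagnostic(
#         some_rule, infra.Level.WARNING, "message", cpp_stack=True
#     )
#     # The Python call stack is always recorded; the C++ call stack is recorded
#     # only when `cpp_stack=True` is passed, as above.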


class ExportDiagnosticEngine:
    """PyTorch ONNX Export diagnostic engine.

    The only purpose of creating this class instead of using `DiagnosticContext`
    directly is to provide a background context for `diagnose` calls inside the
    exporter.

    By design, one `torch.onnx.export` call should initialize one diagnostic context.
    All `diagnose` calls inside the exporter should be made in the context of that
    export. However, since the diagnostic context is currently accessed via a global
    variable, there is no guarantee that the context is properly initialized.
    Therefore, we need to provide a default background context to fall back to;
    otherwise any invocation of exporter internals, e.g. unit tests, would fail due
    to a missing diagnostic context. This can be removed once the pipeline for the
    context to flow through the exporter is established.
    """

    contexts: list[infra.DiagnosticContext]
    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        self.contexts = []
        self._background_context = infra.DiagnosticContext(
            name="torch.onnx",
            version=torch.__version__,
        )

    @property
    def background_context(self) -> infra.DiagnosticContext:
        return self._background_context

    def create_diagnostic_context(
        self,
        name: str,
        version: str,
        options: infra.DiagnosticOptions | None = None,
    ) -> infra.DiagnosticContext:
        """Creates a new diagnostic context.

        Args:
            name: The subject name for the diagnostic context.
            version: The subject version for the diagnostic context.
            options: The options for the diagnostic context.

        Returns:
            A new diagnostic context.
        """
        if options is None:
            options = infra.DiagnosticOptions()
        context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
            name, version, options
        )
        self.contexts.append(context)
        return context

    def clear(self):
        """Clears all diagnostic contexts."""
        self.contexts.clear()
        self._background_context.diagnostics.clear()

    def to_json(self) -> str:
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file, optionally gzip-compressed."""
        if compress:
            with gzip.open(file_path, "wt") as f:
                f.write(self.to_json())
        else:
            with open(file_path, "w") as f:
                f.write(self.to_json())

    def sarif_log(self):
        log = sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self.contexts],
        )

        log.runs.append(self._background_context.sarif())
        return log
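

# Example usage (a sketch of the intended flow; the file path is hypothetical).
# In practice the module-level `engine` defined below is used rather than a new
# instance:
#
#     context = engine.create_diagnostic_context("torch.onnx.export", torch.__version__)
#     ...  # `diagnose` calls made during export are logged into `context`
#     engine.dump("/tmp/export.sarif.gz", compress=True)  # gzip-compressed SARIF
#
# `sarif_log()` emits one SARIF run per created context plus one run for the
# background context, so diagnostics recorded outside an export are not lost.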


engine = ExportDiagnosticEngine()
_context = engine.background_context


@contextlib.contextmanager
def create_export_diagnostic_context() -> (
    Generator[infra.DiagnosticContext, None, None]
):
    """Create a diagnostic context for export.

    This is a workaround for code robustness since the diagnostic context is accessed
    by export internals via a global variable. See `ExportDiagnosticEngine` for more
    details.
    """
    global _context
    assert (
        _context == engine.background_context
    ), "Export context is already set. Nested export is not supported."
    _context = engine.create_diagnostic_context(
        "torch.onnx.export",
        torch.__version__,
    )
    try:
        yield _context
    finally:
        _context = engine.background_context
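

# Example usage (a minimal sketch; the exporter wraps an export roughly like this):
#
#     with create_export_diagnostic_context() as context:
#         ...  # `diagnose` calls here are recorded in `context`
#     # On exit the global context falls back to `engine.background_context`.
#
# The assertion above rejects nesting, so only one export context can be active
# at a time.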


def diagnose(
    rule: infra.Rule,
    level: infra.Level,
    message: str | None = None,
    frames_to_skip: int = 2,
    **kwargs,
) -> TorchScriptOnnxExportDiagnostic:
    """Creates a diagnostic and records it in the global diagnostic context.

    This is a wrapper around `context.log` that uses the global diagnostic
    context.
    """
    diagnostic = TorchScriptOnnxExportDiagnostic(
        rule, level, message, frames_to_skip=frames_to_skip, **kwargs
    )
    export_context().log(diagnostic)
    return diagnostic
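

# Example usage (a sketch; the rule name is an assumption for illustration and
# would come from the exporter's registered rule set, e.g.
# `from torch.onnx._internal.diagnostics import rules`):
#
#     diagnose(
#         rules.node_missing_onnx_shape_inference,  # assumed rule
#         infra.Level.WARNING,
#         message="The shape inference of the custom operator is missing.",
#     )
#
# The returned diagnostic is also logged into the active export context, or into
# the background context when no export context has been created.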


def export_context() -> infra.DiagnosticContext:
    global _context
    return _context