Files
pytorch/torch/export/_draft_export.py
Maggie Moss f414aa8e0d Add pyrefly suppressions (3/n) (#164588)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d

after:

 0 errors (1,970 ignored)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
2025-10-03 22:03:03 +00:00

545 lines
20 KiB
Python

import getpass
import json
import logging
import os
import re
import tempfile
import time
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Optional, Union
import torch
import torch._logging._internal
import torch.utils._pytree as pytree
from torch._dynamo.exc import UserError, UserErrorType
from torch._export.passes.insert_custom_op_guards import (
get_op_profiles,
insert_custom_op_guards,
OpProfile,
)
from torch._utils_internal import log_draft_export_usage
from ._trace import _export, get_ep_stats
from .dynamic_shapes import _DimHint, _DimHintType, Dim
from .exported_program import ExportedProgram
log = logging.getLogger(__name__)
class FailureType(IntEnum):
    """Categories of issues that draft export can detect and report."""

    MISSING_FAKE_KERNEL = 1
    DATA_DEPENDENT_ERROR = 2
    GUARD_ADDED = 3
    MISMATCHED_FAKE_KERNEL = 4

    def __str__(self) -> str:
        # Show the symbolic name (e.g. "GUARD_ADDED") rather than the int value.
        return f"{self.name}"
def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str]) -> str:
    """Render a recorded stacktrace as a human-readable string.

    Frames whose filename id is missing from ``str_to_filename`` are skipped;
    the source location (``loc``) of the innermost frame is appended last.
    """
    pieces: list[str] = []
    for frame in stack:
        filename_id = frame["filename"]
        if filename_id not in str_to_filename:
            continue
        pieces.append(
            f"""
    File {str_to_filename[filename_id]}, lineno {frame["line"]}, in {frame["name"]}"""  # type: ignore[index]
        )
    pieces.append(f"\n    {stack[-1]['loc']}")
    return "".join(pieces)
def prettify_frame_locals(
    loc: str, locals: dict[str, Any], symbols: dict[str, Any]
) -> str:
    """Format a frame's local variables (and any resolved symbols) for display.

    ``loc`` is accepted for signature compatibility with the recorded
    ``frame_locals`` payload (which is splatted into this call) but is not
    rendered here.
    """
    rendered_locals = "\n".join(f" {k}: {v}" for k, v in locals.items())
    res = f"""
    Locals:
{rendered_locals}
"""
    # Only symbols that actually resolved to a value are worth showing.
    resolved = {k: v for k, v in symbols.items() if v is not None}
    if resolved:
        rendered_symbols = "\n".join(f" {k}: {v}" for k, v in resolved.items())
        res += f"""
    Symbols:
{rendered_symbols}
"""
    return res
def get_loc(filename: str, lineno: int) -> Optional[str]:
    """Return the stripped contents of line ``lineno`` (1-based) of ``filename``.

    Returns ``None`` when the file does not exist or has fewer lines.
    """
    try:
        with open(filename) as f:
            for current_lineno, line in enumerate(f, start=1):
                if current_lineno == lineno:
                    return line.strip()
    except FileNotFoundError:
        pass
    return None
class FailureReport:
    """A single issue discovered during draft export, with enough structured
    context to render a human-readable explanation via :meth:`print`."""

    def __init__(
        self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False
    ) -> None:
        # Category of the failure (missing fake kernel, guard added, ...).
        self.failure_type: FailureType = failure_type
        # Structured-log payload; the keys consumed depend on failure_type
        # (see the branches of print() below).
        self.data: dict[str, Any] = data
        # When True this failure is expected and does not make the overall
        # report unsuccessful (see DraftExportReport.successful).
        self.xfail: bool = xfail

    def __repr__(self) -> str:
        return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"

    def print(self, str_to_filename: dict[int, str]) -> str:
        """Render this failure as a human-readable message.

        Args:
            str_to_filename: maps interned filename ids (from the structured
                trace log) back to real file paths for stacktrace display.

        Raises:
            ValueError: if ``failure_type`` is not a known ``FailureType``.
        """
        if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
            op = self.data["op"]
            return f"""Missing fake kernel.
torch.ops.{op} is missing a fake kernel implementation.
Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation.
"""  # noqa: B950
        elif self.failure_type == FailureType.GUARD_ADDED:
            # frame_locals may be absent; only render locals when recorded.
            locals_info = (
                prettify_frame_locals(**self.data["frame_locals"])
                if self.data["frame_locals"]
                else ""
            )
            return f"""Guard Added.
A guard was added during tracing, which might've resulted in some incorrect
tracing or constraint violation error.
Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}.
This occurred at the following stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}:
{locals_info}
And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n
Because of this, we have modified the dynamic shapes structure to be the
following. You can also use torch.export.Dim.AUTO instead to specify your
dynamic shapes, and we will automatically infer the dynamism for you.
```
dynamic_shapes = {self.data["new_dynamic_shapes"]}
```
"""
        elif self.failure_type == FailureType.DATA_DEPENDENT_ERROR:
            locals_info = (
                prettify_frame_locals(**self.data["frame_locals"])
                if self.data["frame_locals"]
                else ""
            )
            return f"""Data dependent error.
When exporting, we were unable to evaluate the value of `{self.data["expr"]}`.
This was encountered {self.data["occurrences"]} times.
This occurred at the following user stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}
{locals_info}
And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n
As a result, it was specialized to a constant (e.g. `{self.data["result"]}` in the 1st occurrence), and asserts were inserted into the graph.
Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
"""  # noqa: B950
        elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL:
            op = self.data["op"]
            reason = self.data["reason"]
            return f"""Mismatched fake kernel.
torch.ops.{op} has a fake kernel implementation, but it has incorrect behavior, based on the real kernel.
The reason for the mismatch is: {reason}.
Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation.
"""  # noqa: B950
        else:
            raise ValueError(f"Unknown failure type: {self.failure_type}")
class DraftExportReport:
    """Aggregated results of a draft-export run.

    Holds the failures found during tracing, the data needed to render them
    (interned filename map, expression-provenance records), and observed
    profiles for custom ops that lack fake kernels.
    """

    def __init__(
        self,
        failures: list[FailureReport],
        str_to_filename: dict[int, str],
        expressions_created: dict[int, dict[str, Any]],
        op_profiles: dict[str, set[OpProfile]],
    ):
        self.failures: list[FailureReport] = failures
        # Maps interned filename ids (from structured trace logs) to paths,
        # used when pretty-printing failure stacktraces.
        self.str_to_filename = str_to_filename
        # Provenance records for symbolic expressions referenced by failures.
        self.expressions_created: dict[int, dict[str, Any]] = expressions_created
        # Observed call profiles for custom ops without fake kernels; can be
        # fed to torch._library.fake_profile.unsafe_generate_fake_kernels.
        self.op_profiles = op_profiles

    def successful(self) -> bool:
        """True when there are no failures, or every failure is expected (xfail)."""
        return len(self.failures) == 0 or all(
            failure.xfail for failure in self.failures
        )

    def __repr__(self) -> str:
        return f"DraftExportReport({self.failures})"

    def __str__(self) -> str:
        # ANSI escape codes for colored terminal output.
        WARNING_COLOR = "\033[93m"
        GREEN_COLOR = "\033[92m"
        END_COLOR = "\033[0m"
        if self.successful():
            # Fixed typo: "Congratuations" -> "Congratulations".
            return f"""{GREEN_COLOR}
##############################################################################################
Congratulations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
{END_COLOR}"""
        error = f"""{WARNING_COLOR}
###################################################################################################
WARNING: {len(self.failures)} issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################
"""
        # Number each failure's rendered explanation, starting at 1.
        for i, failure in enumerate(self.failures):
            error += f"{i + 1}. {failure.print(self.str_to_filename)}\n"
        error += END_COLOR
        return error

    def apply_suggested_fixes(self) -> None:
        """Placeholder for automatically applying suggested fixes; not implemented."""
        raise NotImplementedError("Not implemented yet")
@dataclass
class ExpressionCreatedNode:
    """Node in the "expression_created" provenance graph captured from the
    structured trace, linking a created symbolic expression to the
    expressions it was built from."""

    # Id of the symbolic expression this log record created.
    result_id: int
    # Ids of the expressions used as arguments when creating ``result_id``.
    argument_ids: list[int]
    # The raw structured-trace log record, replayed later if relevant.
    record: dict[str, object]
    # Set to True once this node has been emitted so it is only logged once.
    visited: bool = False
class LogRecord:
    """Deduplicating accumulator for structured-trace log entries.

    Entries are identified by a key-specific hash so repeats of the same
    logical event are counted rather than stored again.
    """

    def __init__(self) -> None:
        # Identity hash of a seen entry -> number of times it was encountered.
        self.log_count: dict[int, int] = {}
        # Unique entries, in first-seen order.
        self.logs: list[tuple[str, dict[str, Any]]] = []

    def _hash(self, element: tuple[str, dict[str, Any]]) -> int:
        """Hash an entry by only the fields that identify it for its log key."""
        key, data = element
        if key == "missing_fake_kernel":
            identity: tuple[Any, ...] = (key, data["op"])
        elif key == "mismatched_fake_kernel":
            identity = (key, data["op"], data["reason"])
        elif key in (
            "propagate_real_tensors_provenance",
            "guard_added",
            "create_unbacked_symbol",
        ):
            # These events are considered duplicates when raised from the
            # same user stacktrace.
            identity = (key, json.dumps(data["user_stack"]))
        else:
            identity = (key, json.dumps(data))
        return hash(identity)

    def try_add(self, element: tuple[str, dict[str, str]]) -> bool:
        """Record ``element``; return True iff it was not seen before."""
        hash_value = self._hash(element)
        already_seen = hash_value in self.log_count
        self.log_count[hash_value] = self.log_count.get(hash_value, 0) + 1
        if already_seen:
            return False
        self.logs.append(element)
        return True

    def get_log_count(self, element: tuple[str, dict[str, Any]]) -> int:
        """Number of times an entry identical to ``element`` has been added."""
        return self.log_count[self._hash(element)]
class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
    """Logging handler that captures the structured-trace records draft export
    cares about (guards, data-dependent errors, fake-kernel issues) while also
    writing the trace to disk like a normal ``LazyTraceHandler``.

    Used as a context manager: on entry it replaces any existing lazy trace
    handler on the ``torch.__trace`` logger with itself and enables DTRACE
    capture; on exit it restores the previous state and clears its buffers.
    """

    def __init__(self) -> None:
        # Structured-log keys this handler records; records with other keys
        # are ignored by emit().
        self.specific_log_keys = [
            "str",
            "exported_program",
            "propagate_real_tensors_provenance",
            "guard_added",
            "missing_fake_kernel",
            "mismatched_fake_kernel",
            "expression_created",
            "create_unbacked_symbol",
        ]
        # Deduplicated record of captured structured logs.
        self.log_record: LogRecord = LogRecord()
        # result_id -> provenance node for "expression_created" records; only
        # replayed to the trace when later referenced by a guard or
        # data-dependent log (see emit()).
        self.expression_created_logs: dict[int, ExpressionCreatedNode] = {}
        # NOTE(review): appears unused within this file — possibly consumed
        # elsewhere or vestigial.
        self.symbol_to_expressions: dict[str, list[dict[str, Any]]] = {}
        self.logger = logging.getLogger("torch.__trace")
        # Saved value of GET_DTRACE_STRUCTURED, restored on __exit__.
        self.prev_get_dtrace = False
        # Honor an explicit trace directory from the env var; otherwise fall
        # back to a per-user temp directory (username sanitized for paths).
        if root_dir := os.environ.get(torch._logging._internal.DTRACE_ENV_VAR):
            super().__init__(root_dir)
        else:
            sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
            root_dir = os.path.join(
                tempfile.gettempdir(),
                "export_" + sanitized_username,
            )
            super().__init__(root_dir)
        self.setFormatter(torch._logging._internal.TorchLogsFormatter(trace=True))

    def __enter__(self) -> "CaptureStructuredTrace":
        # Start from a clean slate for each capture session.
        self.log_record = LogRecord()
        self.expression_created_logs = {}
        # Remove the lazy trace handler if it exists, so this handler is the
        # one receiving trace records.
        possible_lazy_trace_handlers = [
            handler
            for handler in self.logger.handlers
            if isinstance(handler, torch._logging._internal.LazyTraceHandler)
        ]
        for handler in possible_lazy_trace_handlers:
            self.logger.removeHandler(handler)
        self.logger.addHandler(self)
        # Enable DTRACE capture, remembering the previous setting.
        self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
        # pyrefly: ignore # bad-assignment
        torch._logging._internal.GET_DTRACE_STRUCTURED = True
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:  # type: ignore[no-untyped-def]
        # Drop captured state and detach from the logger. Note: callers must
        # read log_record before exiting the context.
        self.log_record = LogRecord()
        self.expression_created_logs = {}
        self.logger.removeHandler(self)
        # pyrefly: ignore # bad-assignment
        torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
        self.prev_get_dtrace = False

    def emit(self, record: Any) -> None:
        """Capture relevant structured-trace records and forward them to the
        on-disk trace, deduplicating repeats and deferring
        "expression_created" records until they are actually referenced."""

        def _log_expression_created(
            emit_func: Callable[[Any], None], sym_node_id: int
        ) -> None:
            # Log all the relevant expression_created logs, recursively
            # emitting each node's arguments before the node itself.
            if sym_node_id is None:
                return
            if res := self.expression_created_logs.get(sym_node_id, None):
                # Don't log the expression if we have already
                # printed it beforehand
                if not res.visited:
                    res.visited = True
                    for arg in res.argument_ids:
                        _log_expression_created(emit_func, arg)
                    emit_func(res.record)

        metadata = record.metadata
        for key in self.specific_log_keys:
            if key in metadata:
                # Only the first occurrence of a logical event is processed.
                if self.log_record.try_add((key, metadata[key])):
                    if key == "expression_created":
                        # We don't want to log all expression_created logs, only
                        # the ones that are relevant to the
                        # guards/propagate_real_tensor
                        self.expression_created_logs[metadata[key]["result_id"]] = (
                            ExpressionCreatedNode(
                                metadata[key]["result_id"],
                                metadata[key].get("argument_ids", []),
                                record,
                            )
                        )
                        # Defer emission: the record is replayed later only if
                        # referenced by a guard/data-dependent log.
                        return
                    elif key == "propagate_real_tensors_provenance":
                        _log_expression_created(
                            super().emit, metadata[key].get("expr_node_id")
                        )
                    elif key == "guard_added":
                        if len(metadata[key]["symbol_to_sources"]) == 0:
                            # We only want to include guards added that are relevant to
                            # the symbolic shapes corresponding to the inputs which were
                            # specified in the dynamic_shapes arg. These have a source.
                            return
                        elif metadata[key]["prefix"] == "runtime_assert":
                            # This should've been captured by a
                            # propagate_real_tensors log
                            return
                        _log_expression_created(
                            super().emit, metadata[key].get("expr_node_id")
                        )
                    # Forward the record to the on-disk trace file.
                    super().emit(record)
def draft_export(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: Optional[Mapping[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
    preserve_module_call_signature: tuple[str, ...] = (),
    strict: bool = False,
    pre_dispatch: bool = True,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
    """Export ``mod`` while tolerating and recording soundness issues.

    Traces with real-tensor propagation enabled so that issues which would
    normally fail export (missing/mismatched fake kernels, data-dependent
    errors, added guards) are recorded instead. The returned
    ``ExportedProgram`` carries a ``DraftExportReport`` on ``ep._report``
    describing every issue found; asserts are inserted into the graph around
    custom ops with missing/incorrect fake kernels.

    The signature mirrors ``torch.export.export`` (plus internal knobs), so
    callers can switch back once the report is clean.
    """
    start_time = time.time()
    kwargs = kwargs or {}
    dynamic_shapes = dynamic_shapes or {}
    constraint_violation_msg = None
    capture_structured_log = CaptureStructuredTrace()
    with (
        torch._functorch.config.patch(
            fake_tensor_propagate_real_tensors=True,
            generate_fake_kernels_from_real_mismatches=True,
        ),
        capture_structured_log,
    ):
        try:
            # new_shapes stays None unless we have to retry with auto-inferred
            # dynamic shapes; the guard_added handling below keys off this.
            new_shapes = None
            ep = _export(
                mod,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes,
                strict=strict,
                pre_dispatch=pre_dispatch,
                preserve_module_call_signature=preserve_module_call_signature,
                prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
            )
        except Exception as exc:
            if (
                isinstance(exc, UserError)
                and exc.error_type == UserErrorType.CONSTRAINT_VIOLATION
            ):
                # The user-specified dynamic shapes were too strict. Retry
                # with Dim.AUTO (preserving the user's min/max bounds) so the
                # dynamism is inferred, and record the original violation.
                constraint_violation_msg = exc.msg

                def convert_dim_to_auto(dim: Any) -> Any:
                    # Replace explicit Dims and DYNAMIC hints with AUTO hints.
                    if isinstance(dim, Dim):
                        return Dim.AUTO(min=dim.min, max=dim.max)
                    elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC:
                        return Dim.AUTO(min=dim.min, max=dim.max)
                    return dim

                new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes)
                ep = _export(
                    mod,
                    args,
                    kwargs,
                    dynamic_shapes=new_shapes,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                    preserve_module_call_signature=preserve_module_call_signature,
                    prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
                )
            else:
                log_draft_export_usage(
                    error=True,
                    export_time=time.time() - start_time,
                    strict=strict,
                    message=str(exc),
                    type=f"{type(exc).__name__}.{type(exc).__qualname__}",
                )
                # Bare raise preserves the original traceback (was `raise exc`).
                raise
        torch._logging.dtrace_structured("exported_program", payload_fn=lambda: str(ep))

        # Turn the captured structured logs into FailureReports. This must
        # happen inside the `with` block, before __exit__ clears the capture.
        str_to_filename: dict[int, str] = {}
        failures: list[FailureReport] = []
        incorrect_custom_ops: set[str] = set()
        expressions_created: dict[int, dict[str, Any]] = {}
        for log_name, log_contents in capture_structured_log.log_record.logs:
            failure_type = None
            if log_name == "str":
                # Filename-interning record: maps an id to a path.
                str_to_filename[log_contents[1]] = log_contents[0]  # type: ignore[index]
                continue
            elif log_name == "propagate_real_tensors_provenance":
                log_contents["occurrences"] = (
                    capture_structured_log.log_record.get_log_count(
                        (log_name, log_contents)
                    )
                )
                failure_type = FailureType.DATA_DEPENDENT_ERROR
            elif log_name == "guard_added":
                # Guards are only reported when we retried export with
                # auto-inferred dynamic shapes.
                if new_shapes is None:
                    continue
                failure_type = FailureType.GUARD_ADDED
                log_contents["new_dynamic_shapes"] = new_shapes
            elif log_name == "missing_fake_kernel":
                failure_type = FailureType.MISSING_FAKE_KERNEL
                incorrect_custom_ops.add(log_contents["op"])
            elif log_name == "mismatched_fake_kernel":
                failure_type = FailureType.MISMATCHED_FAKE_KERNEL
                incorrect_custom_ops.add(log_contents["op"])
            else:
                continue
            assert failure_type is not None
            failures.append(
                FailureReport(
                    failure_type,
                    log_contents,
                )
            )
        # Keep only the expression-provenance records that were actually
        # replayed (i.e. referenced by some failure).
        for k, v in capture_structured_log.expression_created_logs.items():
            if v.visited:
                expressions_created[k] = v.record
        op_profiles = get_op_profiles(ep.graph_module, incorrect_custom_ops)
        report = DraftExportReport(
            failures, str_to_filename, expressions_created, op_profiles
        )
        # Add asserts around custom ops
        insert_custom_op_guards(ep.graph_module, incorrect_custom_ops)
        ep._report = report
        if not report.successful():
            log_filename = capture_structured_log.stream.name
            warning_msg = f"""
###################################################################################################
WARNING: {len(report.failures)} issue(s) found during export, and it was not able to soundly produce a graph.
To view the report of failures in an html page, please run the command:
    `tlparse {log_filename} --export`
Or, you can view the errors in python by inspecting `print(ep._report)`.
"""
            if len(report.op_profiles) > 0:
                warning_msg += f"""
While tracing we found {len(report.op_profiles)} operator(s) which do not have a fake kernel registered.
If you intend to retrace the exported graph or run it with fake tensors, please run it under the
following context manager, which will register a fake kernel for those operators.
```
with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
    # run with fake tensors
```
"""
            warning_msg += """#################################################################################################"""
            log.warning(warning_msg)
        else:
            # Fixed typo: "Congratuations" -> "Congratulations".
            log.info(
                """
##############################################################################################
Congratulations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
"""
            )
        log_draft_export_usage(
            error=False,
            export_time=time.time() - start_time,
            strict=strict,
            constraint_violations=constraint_violation_msg,
            report=ep._report,
            **get_ep_stats(ep),
        )
        return ep