mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: uncomment lines in the pyrefly.toml file step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d after: 0 errors (1,970 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588 Approved by: https://github.com/oulgen
545 lines
20 KiB
Python
545 lines
20 KiB
Python
import getpass
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import tempfile
|
|
import time
|
|
from collections.abc import Callable, Mapping
|
|
from dataclasses import dataclass
|
|
from enum import IntEnum
|
|
from typing import Any, Optional, Union
|
|
|
|
import torch
|
|
import torch._logging._internal
|
|
import torch.utils._pytree as pytree
|
|
from torch._dynamo.exc import UserError, UserErrorType
|
|
from torch._export.passes.insert_custom_op_guards import (
|
|
get_op_profiles,
|
|
insert_custom_op_guards,
|
|
OpProfile,
|
|
)
|
|
from torch._utils_internal import log_draft_export_usage
|
|
|
|
from ._trace import _export, get_ep_stats
|
|
from .dynamic_shapes import _DimHint, _DimHintType, Dim
|
|
from .exported_program import ExportedProgram
|
|
|
|
|
|
# Module-level logger used for the draft-export warning/info banners below.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class FailureType(IntEnum):
    """Categories of issues that draft export detects and reports.

    The integer values appear in report data, so they should stay stable.
    """

    # A custom op was called that has no fake/meta kernel implementation.
    MISSING_FAKE_KERNEL = 1
    # A data-dependent expression could not be evaluated and was specialized.
    DATA_DEPENDENT_ERROR = 2
    # A guard was added during tracing (possible constraint violation).
    GUARD_ADDED = 3
    # A fake kernel exists but disagrees with the real kernel's behavior.
    MISMATCHED_FAKE_KERNEL = 4

    def __str__(self) -> str:
        # Render as the bare member name (e.g. "GUARD_ADDED") in reports.
        return self.name
|
|
|
|
|
|
def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str]) -> str:
    """Format a recorded stacktrace into a human-readable multi-line string.

    Args:
        stack: Frames as dicts with "filename", "line", and "name" keys, plus
            a "loc" key on the innermost (last) frame. "filename" holds an
            interned id which is resolved through ``str_to_filename``.
        str_to_filename: Mapping from interned filename id to the actual path.

    Returns:
        One "File ..., lineno ..., in ..." entry per resolvable frame,
        followed by the source location of the innermost frame. Returns ""
        for an empty stack.
    """
    # Fix: the original unconditionally indexed stack[-1] below, raising
    # IndexError when no frames were recorded.
    if not stack:
        return ""

    res = ""
    for frame in stack:
        # Frames whose filename id was never interned cannot be resolved.
        if frame["filename"] not in str_to_filename:
            continue

        res += f"""
        File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}"""  # type: ignore[index]

    res += f"\n        {stack[-1]['loc']}"
    return res
|
|
|
|
|
|
def prettify_frame_locals(
    loc: str, locals: dict[str, Any], symbols: dict[str, Any]
) -> str:
    """Render a frame's local variables, and any bound symbols, for a report.

    Args:
        loc: Source location of the frame. Accepted because callers splat the
            whole ``frame_locals`` payload in, but not rendered here.
        locals: Local variable name -> recorded value.
        symbols: Symbol name -> value; entries whose value is None are
            omitted, and the Symbols section is skipped entirely if nothing
            remains.
    """
    rendered_locals = "\n".join(
        f"        {name}: {value}" for name, value in locals.items()
    )
    parts = [f"\n    Locals:\n{rendered_locals}\n"]

    bound_symbols = {
        name: value for name, value in symbols.items() if value is not None
    }
    if bound_symbols:
        rendered_symbols = "\n".join(
            f"        {name}: {value}" for name, value in bound_symbols.items()
        )
        parts.append(f"\n    Symbols:\n{rendered_symbols}\n")

    return "".join(parts)
|
|
|
|
|
|
def get_loc(filename: str, lineno: int) -> Optional[str]:
    """Return the stripped text of line ``lineno`` (1-based) of ``filename``.

    Returns None when the file does not exist or has fewer lines.
    """
    try:
        with open(filename) as src:
            for current, text in enumerate(src, start=1):
                if current == lineno:
                    return text.strip()
    except FileNotFoundError:
        return None
    # File exists but is shorter than lineno.
    return None
|
|
|
|
|
|
class FailureReport:
    """A single issue found during draft export, carrying the raw
    structured-log payload needed to render a human-readable explanation.
    """

    def __init__(
        self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False
    ) -> None:
        # Which category of failure this is (see FailureType).
        self.failure_type: FailureType = failure_type
        # Raw structured-log payload; the keys read in print() depend on
        # failure_type (e.g. "op", "expr", "user_stack", ...).
        self.data: dict[str, Any] = data
        # True marks an expected failure, which does not fail the report.
        self.xfail: bool = xfail

    def __repr__(self) -> str:
        return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"

    def print(self, str_to_filename: dict[int, str]) -> str:
        """Render this failure as a human-readable message.

        Args:
            str_to_filename: Mapping from interned filename id to path, used
                to prettify recorded stacktraces.

        Raises:
            ValueError: If ``failure_type`` is not a known ``FailureType``.
        """
        if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
            op = self.data["op"]

            return f"""Missing fake kernel.
torch.ops.{op} is missing a fake kernel implementation.

Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation.
"""  # noqa: B950

        elif self.failure_type == FailureType.GUARD_ADDED:
            # Frame locals are optional; render them only when captured.
            locals_info = (
                prettify_frame_locals(**self.data["frame_locals"])
                if self.data["frame_locals"]
                else ""
            )
            return f"""Guard Added.
A guard was added during tracing, which might've resulted in some incorrect
tracing or constraint violation error.
Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}.
This occurred at the following stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}:
{locals_info}
And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n
Because of this, we have modified the dynamic shapes structure to be the
following. You can also use torch.export.Dim.AUTO instead to specify your
dynamic shapes, and we will automatically infer the dynamism for you.
```
dynamic_shapes = {self.data["new_dynamic_shapes"]}
```
"""

        elif self.failure_type == FailureType.DATA_DEPENDENT_ERROR:
            # Frame locals are optional; render them only when captured.
            locals_info = (
                prettify_frame_locals(**self.data["frame_locals"])
                if self.data["frame_locals"]
                else ""
            )
            return f"""Data dependent error.
When exporting, we were unable to evaluate the value of `{self.data["expr"]}`.
This was encountered {self.data["occurrences"]} times.
This occurred at the following user stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}
{locals_info}
And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n
As a result, it was specialized to a constant (e.g. `{self.data["result"]}` in the 1st occurrence), and asserts were inserted into the graph.

Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
"""  # noqa: B950

        elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL:
            op = self.data["op"]
            reason = self.data["reason"]
            return f"""Mismatched fake kernel.
torch.ops.{op} has a fake kernel implementation, but it has incorrect behavior, based on the real kernel.
The reason for the mismatch is: {reason}.

Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation.
"""  # noqa: B950

        else:
            raise ValueError(f"Unknown failure type: {self.failure_type}")
|
|
|
|
|
|
class DraftExportReport:
    """Aggregated result of a draft export run: the failures found, the
    interned-filename table for prettifying stacktraces, the relevant
    expression-created trace records, and op profiles for custom ops that
    were missing fake kernels.
    """

    def __init__(
        self,
        failures: list[FailureReport],
        str_to_filename: dict[int, str],
        expressions_created: dict[int, dict[str, Any]],
        op_profiles: dict[str, set[OpProfile]],
    ):
        # All failures, including expected ones marked xfail.
        self.failures: list[FailureReport] = failures
        # Interned filename id -> path, used when rendering stacktraces.
        self.str_to_filename = str_to_filename
        # expression_created trace records referenced by the failures.
        self.expressions_created: dict[int, dict[str, Any]] = expressions_created
        # Profiles recorded for custom ops without (correct) fake kernels.
        self.op_profiles = op_profiles

    def successful(self) -> bool:
        """True if there were no failures, or every failure was expected."""
        return len(self.failures) == 0 or all(
            failure.xfail for failure in self.failures
        )

    def __repr__(self) -> str:
        return f"DraftExportReport({self.failures})"

    def __str__(self) -> str:
        # ANSI color codes for terminal output.
        WARNING_COLOR = "\033[93m"
        GREEN_COLOR = "\033[92m"
        END_COLOR = "\033[0m"

        if self.successful():
            # Fix: original banner misspelled "Congratuations".
            return f"""{GREEN_COLOR}
##############################################################################################
Congratulations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
{END_COLOR}"""

        error = f"""{WARNING_COLOR}
###################################################################################################
WARNING: {len(self.failures)} issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

"""

        # Number each failure's rendered message, 1-based.
        for i, failure in enumerate(self.failures):
            error += f"{i + 1}. {failure.print(self.str_to_filename)}\n"
        error += END_COLOR
        return error

    def apply_suggested_fixes(self) -> None:
        """Placeholder for automatically applying suggested fixes."""
        raise NotImplementedError("Not implemented yet")
|
|
|
|
|
|
@dataclass
class ExpressionCreatedNode:
    """One "expression_created" structured-trace record, held back so it can
    be emitted lazily only if a failure ends up referencing it.
    """

    # Id of the symbolic expression this record produced.
    result_id: int
    # Ids of the expressions used as arguments; emitted (recursively) first.
    argument_ids: list[int]
    # The raw logging record to re-emit. NOTE(review): annotated as a dict,
    # but CaptureStructuredTrace.emit stores the whole logging record object
    # here — confirm before tightening the type.
    record: dict[str, object]
    # Set once emitted so each record is only logged once.
    visited: bool = False
|
|
|
|
|
|
class LogRecord:
    """Deduplicating store for structured-trace entries.

    Each entry is a ``(key, data)`` pair. Entries are fingerprinted by
    ``_hash`` (which, for known keys, only looks at the fields that identify
    the failure rather than the full payload); the first entry with a given
    fingerprint is kept in ``logs`` and later duplicates only bump a counter.
    """

    def __init__(self) -> None:
        # Fingerprint -> number of times an entry with it was offered.
        self.log_count: dict[int, int] = {}
        # First-seen entries, in insertion order.
        self.logs: list[tuple[str, dict[str, Any]]] = []

    def _hash(self, element: tuple[str, dict[str, Any]]) -> int:
        """Fingerprint an entry on its identifying fields only."""
        key, data = element

        if key == "missing_fake_kernel":
            return hash((key, data["op"]))
        if key == "mismatched_fake_kernel":
            return hash((key, data["op"], data["reason"]))
        # These three all dedupe purely on the user stacktrace.
        if key in (
            "propagate_real_tensors_provenance",
            "guard_added",
            "create_unbacked_symbol",
        ):
            return hash((key, json.dumps(data["user_stack"])))

        # Unknown keys dedupe on the entire payload.
        return hash((key, json.dumps(data)))

    def try_add(self, element: tuple[str, dict[str, str]]) -> bool:
        """Record *element*; return True only the first time it is seen."""
        fingerprint = self._hash(element)
        already_seen = fingerprint in self.log_count
        if already_seen:
            self.log_count[fingerprint] += 1
        else:
            self.log_count[fingerprint] = 1
            self.logs.append(element)
        return not already_seen

    def get_log_count(self, element: tuple[str, dict[str, Any]]) -> int:
        """Return how many times an entry fingerprinting like *element* was offered."""
        return self.log_count[self._hash(element)]
|
|
|
|
|
|
class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
    """Logging handler that, while installed as a context manager, captures
    the "torch.__trace" structured-log records draft export cares about
    (deduplicated via LogRecord) while still forwarding them to the
    LazyTraceHandler log file.
    """

    def __init__(self) -> None:
        # Only records whose metadata contains one of these keys are handled.
        self.specific_log_keys = [
            "str",
            "exported_program",
            "propagate_real_tensors_provenance",
            "guard_added",
            "missing_fake_kernel",
            "mismatched_fake_kernel",
            "expression_created",
            "create_unbacked_symbol",
        ]
        # Deduplicated (key, metadata) entries captured during the trace.
        self.log_record: LogRecord = LogRecord()
        # result_id -> ExpressionCreatedNode, emitted lazily only when a
        # guard / data-dependent failure references them.
        self.expression_created_logs: dict[int, ExpressionCreatedNode] = {}
        # NOTE(review): never written in this class — looks unused here;
        # confirm against other call sites before removing.
        self.symbol_to_expressions: dict[str, list[dict[str, Any]]] = {}
        self.logger = logging.getLogger("torch.__trace")
        # Saved value of GET_DTRACE_STRUCTURED, restored on __exit__.
        self.prev_get_dtrace = False

        # Pick the trace output dir: the DTRACE env var if set, otherwise a
        # per-user directory under the system temp dir.
        if root_dir := os.environ.get(torch._logging._internal.DTRACE_ENV_VAR):
            super().__init__(root_dir)
        else:
            # Strip characters that are invalid in directory names.
            sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
            root_dir = os.path.join(
                tempfile.gettempdir(),
                "export_" + sanitized_username,
            )
            super().__init__(root_dir)

        self.setFormatter(torch._logging._internal.TorchLogsFormatter(trace=True))

    def __enter__(self) -> "CaptureStructuredTrace":
        # Start each capture with fresh state.
        self.log_record = LogRecord()
        self.expression_created_logs = {}

        # Remove the lazy trace handler if it exists
        # NOTE(review): removed handlers are not re-added on __exit__ —
        # confirm whether that is intentional.
        possible_lazy_trace_handlers = [
            handler
            for handler in self.logger.handlers
            if isinstance(handler, torch._logging._internal.LazyTraceHandler)
        ]
        for handler in possible_lazy_trace_handlers:
            self.logger.removeHandler(handler)

        self.logger.addHandler(self)
        self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
        # pyrefly: ignore # bad-assignment
        torch._logging._internal.GET_DTRACE_STRUCTURED = True
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:  # type: ignore[no-untyped-def]
        # Drop captured state and detach this handler, restoring the previous
        # GET_DTRACE_STRUCTURED flag.
        self.log_record = LogRecord()
        self.expression_created_logs = {}
        self.logger.removeHandler(self)
        # pyrefly: ignore # bad-assignment
        torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
        self.prev_get_dtrace = False

    def emit(self, record: Any) -> None:
        """Handle one structured-trace record: dedupe it, stash
        expression_created records for lazy emission, and forward the rest to
        the LazyTraceHandler log file.
        """

        def _log_expression_created(
            emit_func: Callable[[Any], None], sym_node_id: int
        ) -> None:
            # Log all the relevant expression_created logs
            if sym_node_id is None:
                return
            if res := self.expression_created_logs.get(sym_node_id, None):
                # Don't log the expression if we have already
                # printed it beforehand
                if not res.visited:
                    res.visited = True
                    # Emit argument expressions (recursively) before the
                    # expression that uses them.
                    for arg in res.argument_ids:
                        _log_expression_created(emit_func, arg)

                    emit_func(res.record)

        metadata = record.metadata
        for key in self.specific_log_keys:
            if key in metadata:
                # Only forward the first occurrence of each deduped entry.
                if self.log_record.try_add((key, metadata[key])):
                    if key == "expression_created":
                        # We don't want to log all expression_created logs, only
                        # the ones that are relevant to the
                        # guards/propagate_real_tensor
                        self.expression_created_logs[metadata[key]["result_id"]] = (
                            ExpressionCreatedNode(
                                metadata[key]["result_id"],
                                metadata[key].get("argument_ids", []),
                                record,
                            )
                        )
                        return

                    elif key == "propagate_real_tensors_provenance":
                        _log_expression_created(
                            super().emit, metadata[key].get("expr_node_id")
                        )

                    elif key == "guard_added":
                        if len(metadata[key]["symbol_to_sources"]) == 0:
                            # We only want to include guards added that are relevant to
                            # the symbolic shapes corresponding to the inputs which were
                            # specified in the dynamic_shapes arg. These have a source.
                            return
                        elif metadata[key]["prefix"] == "runtime_assert":
                            # This should've been captured by a
                            # propagate_real_tensors log
                            return

                        _log_expression_created(
                            super().emit, metadata[key].get("expr_node_id")
                        )

                    super().emit(record)
|
|
|
|
|
|
def draft_export(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: Optional[Mapping[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
    preserve_module_call_signature: tuple[str, ...] = (),
    strict: bool = False,
    pre_dispatch: bool = True,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
    """Export ``mod`` while collecting a report of export-blocking issues.

    Runs ``_export`` with real-tensor propagation enabled and a structured-log
    capture handler installed, so data-dependent errors, added guards, and
    missing/mismatched fake kernels are collected into a
    ``DraftExportReport`` attached to the returned program as ``ep._report``.
    On a constraint-violation error, the user's dynamic shapes are relaxed to
    ``Dim.AUTO`` and export is retried once. Parameters mirror ``_export``.
    """
    start_time = time.time()
    kwargs = kwargs or {}
    dynamic_shapes = dynamic_shapes or {}

    constraint_violation_msg = None
    capture_structured_log = CaptureStructuredTrace()

    with (
        torch._functorch.config.patch(
            fake_tensor_propagate_real_tensors=True,
            generate_fake_kernels_from_real_mismatches=True,
        ),
        capture_structured_log,
    ):
        try:
            new_shapes = None
            ep = _export(
                mod,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes,
                strict=strict,
                pre_dispatch=pre_dispatch,
                preserve_module_call_signature=preserve_module_call_signature,
                prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
            )
        except Exception as exc:
            if (
                isinstance(exc, UserError)
                and exc.error_type == UserErrorType.CONSTRAINT_VIOLATION
            ):
                # Retry once with every user-specified dynamic dim relaxed to
                # Dim.AUTO, letting export infer the dynamism instead.
                constraint_violation_msg = exc.msg

                def convert_dim_to_auto(dim: Any) -> Any:
                    # Preserve the user's min/max bounds while switching to AUTO.
                    if isinstance(dim, Dim):
                        return Dim.AUTO(min=dim.min, max=dim.max)
                    elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC:
                        return Dim.AUTO(min=dim.min, max=dim.max)
                    return dim

                new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes)
                ep = _export(
                    mod,
                    args,
                    kwargs,
                    dynamic_shapes=new_shapes,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                    preserve_module_call_signature=preserve_module_call_signature,
                    prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
                )
            else:
                # Non-retryable failure: record telemetry and re-raise.
                # NOTE(review): "__name__.__qualname__" is usually redundant
                # for exception classes; "__module__.__qualname__" may have
                # been intended — confirm before changing the logged format.
                log_draft_export_usage(
                    error=True,
                    export_time=time.time() - start_time,
                    strict=strict,
                    message=str(exc),
                    type=f"{type(exc).__name__}.{type(exc).__qualname__}",
                )
                raise exc

        torch._logging.dtrace_structured("exported_program", payload_fn=lambda: str(ep))

        str_to_filename: dict[int, str] = {}
        failures: list[FailureReport] = []
        incorrect_custom_ops: set[str] = set()
        expressions_created: dict[int, dict[str, Any]] = {}

        # Convert each captured structured-log entry into a FailureReport.
        for log_name, log_contents in capture_structured_log.log_record.logs:
            failure_type = None

            if log_name == "str":
                # "str" entries intern filenames as (filename, id) pairs.
                str_to_filename[log_contents[1]] = log_contents[0]  # type: ignore[index]
                continue

            elif log_name == "propagate_real_tensors_provenance":
                # Record how many times this data-dependent error occurred.
                log_contents["occurrences"] = (
                    capture_structured_log.log_record.get_log_count(
                        (log_name, log_contents)
                    )
                )

                failure_type = FailureType.DATA_DEPENDENT_ERROR

            elif log_name == "guard_added":
                # Guards are only reported when the retry path rewrote the
                # dynamic shapes; otherwise there is no suggestion to show.
                if new_shapes is None:
                    continue

                failure_type = FailureType.GUARD_ADDED
                log_contents["new_dynamic_shapes"] = new_shapes

            elif log_name == "missing_fake_kernel":
                failure_type = FailureType.MISSING_FAKE_KERNEL
                incorrect_custom_ops.add(log_contents["op"])

            elif log_name == "mismatched_fake_kernel":
                failure_type = FailureType.MISMATCHED_FAKE_KERNEL
                incorrect_custom_ops.add(log_contents["op"])

            else:
                continue

            assert failure_type is not None
            failures.append(
                FailureReport(
                    failure_type,
                    log_contents,
                )
            )

        # Keep only the expression_created records that were actually reached
        # while emitting failures.
        for k, v in capture_structured_log.expression_created_logs.items():
            if v.visited:
                expressions_created[k] = v.record

        op_profiles = get_op_profiles(ep.graph_module, incorrect_custom_ops)
        report = DraftExportReport(
            failures, str_to_filename, expressions_created, op_profiles
        )

        # Add asserts around custom ops
        insert_custom_op_guards(ep.graph_module, incorrect_custom_ops)

        ep._report = report
        if not report.successful():
            log_filename = capture_structured_log.stream.name

            warning_msg = f"""
###################################################################################################
WARNING: {len(report.failures)} issue(s) found during export, and it was not able to soundly produce a graph.
To view the report of failures in an html page, please run the command:
    `tlparse {log_filename} --export`
Or, you can view the errors in python by inspecting `print(ep._report)`.
"""

            if len(report.op_profiles) > 0:
                warning_msg += f"""
While tracing we found {len(report.op_profiles)} operator(s) which do not have a fake kernel registered.
If you intend to retrace the exported graph or run it with fake tensors, please run it under the
following context manager, which will register a fake kernel for those operators.

```
with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
    # run with fake tensors
```
"""

            warning_msg += """#################################################################################################"""

            log.warning(warning_msg)

        else:
            # Fix: original banner misspelled "Congratuations".
            log.info(
                """
##############################################################################################
Congratulations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
"""
            )

        # Record success telemetry regardless of whether failures were found.
        log_draft_export_usage(
            error=False,
            export_time=time.time() - start_time,
            strict=strict,
            constraint_violations=constraint_violation_msg,
            report=ep._report,
            **get_ep_stats(ep),
        )
        return ep
|