typing debugging.py (#160364)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160364
Approved by: https://github.com/Skylion007
ghstack dependencies: #160362, #160363
Authored by Lucas Kabela on 2025-08-13 15:40:09 -07:00; committed by PyTorch MergeBot
parent 6fe6dd9fdc
commit 9faca5f260
9 changed files with 141 additions and 68 deletions


@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
 compilation and execution issues. It includes:
@@ -28,40 +26,54 @@ These backends are primarily used for:
 import dataclasses
 import functools
 import logging
+from collections.abc import Iterable
 from importlib import import_module
-from typing import Any, Optional
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from functorch.compile import min_cut_rematerialization_partition
 from torch import _guards
+from torch._dynamo.output_graph import GraphCompileReason
 from torch._functorch import config as functorch_config
 from torch._functorch.compilers import ts_compile
 
 from .common import aot_autograd
-from .registry import register_debug_backend as register_backend
+from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend
+
+if TYPE_CHECKING:
+    from torch.fx.node import Target
 
 
 log = logging.getLogger(__name__)
 
 
 @register_backend
-def eager(gm, fake_tensor_inputs, **kwargs):
+def eager(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager backend ignoring extra kwargs %s", kwargs)
 
     return gm.forward
 
 
-def make_eager_backend_with_torch_function_mode(mode):
+def make_eager_backend_with_torch_function_mode(
+    mode: torch.overrides.TorchFunctionMode,
+) -> Callable[..., Any]:
     return make_eager_backend_with_torch_function_modes([mode])
 
 
-def make_eager_backend_with_torch_function_modes(modes):
+def make_eager_backend_with_torch_function_modes(
+    modes: Iterable[torch.overrides.TorchFunctionMode],
+) -> Callable[..., Any]:
     """Used to trace HOPs (cond and while) for eager execution, the metadata
     TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
     in the HOP, so we need to externally run this mode and not trace it."""
     from contextlib import ExitStack
 
-    def fn(gm, fake_tensor_inputs, **kwargs):
+    def fn(
+        gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+    ) -> Callable[..., Any]:
         stack = ExitStack()
         for mode in modes:
             stack.enter_context(mode)
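
A minimal usage sketch (illustrative, not part of this commit): since "eager" simply returns gm.forward, compiled results should match uncompiled execution, which makes it handy for isolating Dynamo tracing bugs from backend bugs.

import torch

def f(x):
    return torch.relu(x) + 1

compiled_f = torch.compile(f, backend="eager")
print(compiled_f(torch.randn(4)))  # should match f(torch.randn(4))
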
@@ -74,13 +86,15 @@ def make_eager_backend_with_torch_function_modes(modes):
 
 
 @register_backend
-def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
+def eager_noexcept(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
 
     # This backend is intended to check that dynamo-generated GraphModules
     # do not cause errors.
-    def inner(*args):
+    def inner(*args: Any) -> Any:
         try:
             return gm(*args)
         except Exception as e:
@@ -92,13 +106,15 @@ def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
 
 
 @register_backend
-def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
+def pre_dispatch_eager(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> torch.fx.GraphModule:
     if kwargs:
         log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
 
     from torch.fx.experimental.proxy_tensor import make_fx
 
-    def runnable_gm(*args):
+    def runnable_gm(*args: Any) -> Any:
         return torch.fx.Interpreter(gm).run(*args)
 
     pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
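
This backend leans on make_fx with pre_dispatch=True; a standalone sketch of that capture step (illustrative only):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.relu(x) * 2

# pre_dispatch=True captures the graph above the dispatcher, before
# autograd and functionalization rewrite the ops.
gm = make_fx(f, pre_dispatch=True)(torch.randn(3))
gm.print_readable()
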
@@ -108,7 +124,9 @@ def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
 
 
 @register_backend
-def eager_debug(gm, fake_tensor_inputs, **kwargs):
+def eager_debug(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
@@ -117,42 +135,55 @@ def eager_debug(gm, fake_tensor_inputs, **kwargs):
     # We could add more debugging bits here.
     # Right now, this backend can be used to check for and error on
     # custom dispatcher ops that have incorrect schemas.
-    def inner(*args):
+    def inner(*args: Any) -> Any:
         with SchemaCheckMode():
             return torch.fx.Interpreter(gm).run(*args)
 
     return inner
 
 
-@register_backend(name="ts")
-def torchscript(gm, fake_tensor_inputs):
+@register_backend(name="ts")  # type: ignore[misc]
+def torchscript(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
+) -> torch.jit.ScriptModule:
     return torch.jit.script(gm)
 
 
 # used boxed call to discard inputs when they are no longer needed
-def boxed_nop(fx_g, example_inputs):
-    def run(args):
+def boxed_nop(
+    fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         return torch.fx.Interpreter(fx_g).boxed_run(args)
 
-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run
 
 
-def boxed_nop_with_mode(fx_g, example_inputs, *, mode):
-    def run(args):
+def boxed_nop_with_mode(
+    fx_g: torch.fx.GraphModule,
+    example_inputs: list[torch.Tensor],
+    *,
+    mode: torch.overrides.TorchFunctionMode,
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         with mode:
             return torch.fx.Interpreter(fx_g).boxed_run(args)
 
-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run
 
 
-def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None):
-    def run(args):
+def fake_crossref_boxed_nop(
+    fx_g: torch.fx.GraphModule,
+    example_inputs: list[torch.Tensor],
+    ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
             return torch.fx.Interpreter(fx_g).boxed_run(args)
 
-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run
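
An illustration of the boxed calling convention these helpers rely on (hypothetical callee, not from this commit): the compiled function receives one mutable list and may drain it, so input tensors can be freed as soon as they are consumed.

import torch

def boxed_add(args: list):
    y = args.pop()
    x = args.pop()
    return x + y

boxed_add._boxed_call = True  # AOTAutograd checks this attribute

inputs = [torch.randn(2), torch.randn(2)]
out = boxed_add(inputs)
assert inputs == []  # the callee drained the list, releasing references early
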
@@ -160,7 +191,9 @@ def ignore_builtins(op: torch._ops.OpOverload) -> bool:
     return op.namespace in ("aten", "prims", "prim")
 
 
-def get_nop_func():
+def get_nop_func() -> Callable[
+    [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any]
+]:
     if not torch._functorch.config.fake_tensor_crossref:
         return boxed_nop
     elif torch._functorch.config.fake_tensor_crossref == "all":
@@ -173,12 +206,12 @@ def get_nop_func():
 # Useful for debugging purpose
 # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
 def aot_eager(
-    gm,
-    fake_tensor_inputs,
-    fw_compiler=None,
-    bw_compiler=None,
-    **kwargs,
-):
+    gm: torch.fx.GraphModule,
+    fake_tensor_inputs: list[torch.Tensor],
+    fw_compiler: Optional[Callable[..., Any]] = None,
+    bw_compiler: Optional[Callable[..., Any]] = None,
+    **kwargs: Any,
+) -> Callable[..., Any]:
     return aot_autograd(
         fw_compiler=fw_compiler or boxed_nop,
         bw_compiler=bw_compiler or boxed_nop,
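
A quick way to exercise this backend end to end (illustrative):

import torch

def f(x):
    return (x * x).sum()

# aot_eager runs the full AOTAutograd pipeline (joint graph, partitioning)
# but leaves forward and backward uncompiled via boxed_nop.
x = torch.randn(8, requires_grad=True)
torch.compile(f, backend="aot_eager")(x).backward()
print(x.grad.shape)
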
@@ -201,7 +234,9 @@ register_backend(
 # inductor problems.
 # aot_eager_decomp_partition just replaces the inductor compiler with nop to help
 # isolate inductor vs aot_eager errors
-def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
+def aot_eager_decomp_partition(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning(
             "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
@@ -213,7 +248,7 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
     if bisect_changes := CompilerBisector.get_config_change(
         "aot_eager_decomp_partition"
     ):
-        config_patches.update(bisect_changes)
+        config_patches.update(bisect_changes)  # type: ignore[arg-type]
 
     with functorch_config.patch(config_patches):
         return aot_autograd(
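
Illustrative usage: this backend applies the same decompositions and min-cut partitioning as inductor but with a no-op compiler, so a failure here points at AOTAutograd rather than codegen.

import torch

def f(x):
    return torch.nn.functional.gelu(x).sum()

x = torch.randn(4, requires_grad=True)
torch.compile(f, backend="aot_eager_decomp_partition")(x).backward()
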
@@ -237,7 +272,12 @@ register_backend(
 
 # aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition,
 # except that it takes a TorchDispatchMode mode and run the fw/bw in the mode
-def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
+def aot_eager_decomp_partition_with_mode(
+    gm: torch.fx.GraphModule,
+    fake_tensor_inputs: list[torch.Tensor],
+    mode: Any,
+    **kwarg: Any,
+) -> Callable[..., Any]:
     return aot_autograd(
         # these are taken from memory_efficient_fusion()
         fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
@@ -254,11 +294,13 @@ def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
 register_backend(
     name="aot_eager_decomp_partition_with_mode",
-    compiler_fn=aot_eager_decomp_partition_with_mode,
+    compiler_fn=aot_eager_decomp_partition_with_mode,  # type: ignore[arg-type]
 )
 
 
-def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
+def aot_eager_decomp_partition_crossref(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     # if the config is set, respect it, otherwise only test custom_ops.
     # custom_op bad metas always manifest as an error whereas aten will only sometimes.
     # by default, use the less noisy option
@@ -296,7 +338,9 @@ class TestingOnlyCompileError(Exception):
 
 
 @register_backend
-def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_compile_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             raise ReluCompileError
@@ -304,7 +348,9 @@ def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 
 
 @register_backend
-def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_runtime_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             node.target = torch._assert
@@ -314,7 +360,9 @@ def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 
 
 @register_backend
-def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_accuracy_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             node.target = torch.add
@@ -325,7 +373,9 @@ def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 
 
 @register_backend
-def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def non_leaf_compile_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     # Require at least one non-trivial thing in the graph,
     # see https://github.com/pytorch/pytorch/issues/102898
     for node in gm.graph.nodes:
@@ -349,11 +399,9 @@ class ExplainOutput:
     graphs: list[torch.fx.GraphModule]
     graph_count: int
    graph_break_count: int
-    break_reasons: list[
-        Any
-    ]  # Type is GraphCompileReason but doesn't matter for this purpose
+    break_reasons: list[GraphCompileReason]
     op_count: int
-    ops_per_graph: Optional[list[torch.fx.Node]] = None
+    ops_per_graph: Optional[list[list["Target"]]] = None
     out_guards: Optional[list[_guards.Guard]] = None
     compile_times: Optional[str] = None
@@ -389,8 +437,18 @@ class ExplainOutput:
 
 
 def _explain_graph_detail(
-    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
-):
+    gm: torch.fx.GraphModule,
+    graphs: list[torch.fx.GraphModule],
+    op_count: int,
+    ops_per_graph: list[list["Target"]],
+    break_reasons: list[GraphCompileReason],
+) -> tuple[
+    torch.fx.GraphModule,
+    list[torch.fx.GraphModule],
+    int,
+    list[list["Target"]],
+    list[GraphCompileReason],
+]:
     """
     This function is a utility which processes a torch.fx.GraphModule and
     accumulates information about its ops, graph breaks, and other details. It
@@ -412,8 +470,8 @@ def _explain_graph_detail(
     ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
     op_count += len(ops)
     ops_per_graph.append(ops)
-    if gm.compile_subgraph_reason.graph_break:
-        break_reasons.append(gm.compile_subgraph_reason)
+    if gm.compile_subgraph_reason.graph_break:  # type: ignore[union-attr]
+        break_reasons.append(gm.compile_subgraph_reason)  # type: ignore[arg-type]
 
     return gm, graphs, op_count, ops_per_graph, break_reasons
@@ -443,17 +501,20 @@ class ExplainWithBackend:
         print(eb.output())
     """
 
-    def __init__(self, backend) -> None:
+    def __init__(self, backend: Union[CompilerFn, str]) -> None:
         from .registry import lookup_backend
 
         self.backend = lookup_backend(backend)
-        self.graphs = []
+        self.graphs: list[torch.fx.GraphModule] = []
         self.op_count = 0
-        self.break_reasons = []
+        self.break_reasons: list[GraphCompileReason] = []
 
-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(
+        self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+    ) -> CompiledFn:
+        ops_per_graph: list[list[Target]] = []
         gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
-            gm, self.graphs, self.op_count, [], self.break_reasons
+            gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons
         )
         return self.backend(gm, example_inputs)
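
The usage pattern from the class docstring, spelled out as a runnable sketch:

import torch
from torch._dynamo.backends.debugging import ExplainWithBackend

def f(x):
    x = x.sin()
    torch._dynamo.graph_break()  # force a split into two graphs
    return x.cos()

eb = ExplainWithBackend("inductor")
opt_f = torch.compile(f, backend=eb)
opt_f(torch.randn(4))
print(eb.output())  # ExplainOutput: graph count, break reasons, op counts
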


@@ -1245,7 +1245,7 @@ def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any
     graphs: list[torch.fx.GraphModule] = []
     break_reasons: list[Any] = []
     op_count: int = 0
-    ops_per_graph: list[torch.fx.Node] = []
+    ops_per_graph: list[list[Target]] = []
     out_guards: list[_guards.Guard] = []
 
     def dynamo_graph_accumulating_compiler(
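
Illustrative call (assuming the call-the-result form of torch._dynamo.explain used in recent releases):

import torch

def f(x):
    x = x.sin()
    torch._dynamo.graph_break()
    return x.cos()

explanation = torch._dynamo.explain(f)(torch.randn(4))
print(explanation.graph_break_count)
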


@@ -191,7 +191,9 @@ def cond(
     ):
         with _temp_remove_metadata_torch_function_mode() as metadata_mode:
             if metadata_mode:
-                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+                backend: Union[str, Callable[..., Any]] = (
+                    make_eager_backend_with_torch_function_mode(metadata_mode)
+                )
             else:
                 backend = "eager"
         return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
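
Illustrative torch.cond call that goes through this code path:

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

# Both branches are traced; if a metadata TorchFunction mode is active,
# the wrapper backend re-enters it around execution instead of tracing it.
x = torch.randn(4)
out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
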


@@ -3,7 +3,7 @@
 import contextlib
 from contextlib import nullcontext
 from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -134,7 +134,9 @@ def invoke_subgraph_placeholder(func, *args, **kwargs):
     ):
         with _temp_remove_metadata_torch_function_mode() as metadata_mode:
             if metadata_mode:
-                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+                backend: Union[str, Callable[..., Any]] = (
+                    make_eager_backend_with_torch_function_mode(metadata_mode)
+                )
             else:
                 backend = "eager"


@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from typing import Any, Callable, Union
+
 import torch
 import torch._subclasses.functional_tensor
 import torch.utils._pytree as pytree
@@ -33,7 +35,9 @@ def strict_mode(callable, operands):
         modes = [metadata_mode, predispatch_mode]
         modes = [mode for mode in modes if mode is not None]
         if modes:
-            backend = make_eager_backend_with_torch_function_modes(modes)
+            backend: Union[str, Callable[..., Any]] = (
+                make_eager_backend_with_torch_function_modes(modes)
+            )
         else:
             backend = "eager"
     with torch._dynamo.utils.disable_cache_limit():


@@ -103,7 +103,9 @@ def _maybe_compile_and_run_fn(fn, *args):
     with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
         with _temp_remove_metadata_torch_function_mode() as metadata_mode:
             if metadata_mode:
-                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+                backend: Union[str, Callable[..., Any]] = (
+                    make_eager_backend_with_torch_function_mode(metadata_mode)
+                )
             else:
                 backend = "eager"
             return torch.compile(fn, backend=backend, fullgraph=True)(*args)


@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import contextlib
-from typing import Callable, Union
+from typing import Any, Callable, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -171,7 +171,9 @@ def while_loop(cond_fn, body_fn, carried_inputs):
     with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
         with _temp_remove_metadata_torch_function_mode() as metadata_mode:
             if metadata_mode:
-                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+                backend: Union[str, Callable[..., Any]] = (
+                    make_eager_backend_with_torch_function_mode(metadata_mode)
+                )
             else:
                 backend = "eager"
             return torch.compile(
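
Illustrative while_loop call that exercises this path:

import torch
from torch._higher_order_ops import while_loop

def cond_fn(i, x):
    return i < 3

def body_fn(i, x):
    return i + 1, x * 2.0

# Carried values must keep their shape and dtype across iterations.
i, x = while_loop(cond_fn, body_fn, (torch.tensor(0), torch.ones(2)))
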


@@ -5,8 +5,8 @@ import logging
 import operator
 import types
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing_extensions import ParamSpec, TypeAlias, TypeVar
 
 import torch
 from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase
@@ -46,7 +46,7 @@ BaseArgumentTypes = Union[
 ]
 base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]
 
-Target = Union[Callable[..., Any], str]
+Target: TypeAlias = Union[Callable[..., Any], str]
 
 Argument = Optional[
     Union[
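
With Target now an explicit TypeAlias, annotations like the following type-check cleanly (illustrative):

import torch
from torch.fx.node import Target

def count_call_targets(gm: torch.fx.GraphModule) -> dict[Target, int]:
    # node.target is Union[Callable[..., Any], str], i.e. Target
    counts: dict[Target, int] = {}
    for node in gm.graph.nodes:
        if node.op == "call_function":
            counts[node.target] = counts.get(node.target, 0) + 1
    return counts
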


@@ -9,7 +9,7 @@ import math
 import operator
 import warnings
 from enum import Enum
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import Tensor
@@ -1607,8 +1607,8 @@ def flex_attention(
     with _temp_remove_pre_dispatch_torch_function_mode():
         with _temp_remove_metadata_torch_function_mode() as metadata_mode:
             if metadata_mode:
-                backend = make_eager_backend_with_torch_function_mode(
-                    metadata_mode
+                backend: Union[str, Callable[..., Any]] = (
+                    make_eager_backend_with_torch_function_mode(metadata_mode)
                 )
             else:
                 backend = "eager"