@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
 compilation and execution issues. It includes:
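The backends in this file are registered by name and selected through torch.compile; a minimal usage sketch (the toy function is illustrative, not part of the patch):

    import torch

    def f(x):
        return torch.relu(x) + 1

    # "eager" just runs the FX graph Dynamo captured, with no compilation;
    # it isolates Dynamo capture problems from backend/compiler problems.
    compiled = torch.compile(f, backend="eager")
    compiled(torch.randn(4))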
@@ -28,40 +26,54 @@ These backends are primarily used for:
 import dataclasses
 import functools
 import logging
+from collections.abc import Iterable
 from importlib import import_module
-from typing import Any, Optional
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from functorch.compile import min_cut_rematerialization_partition
 from torch import _guards
+from torch._dynamo.output_graph import GraphCompileReason
 from torch._functorch import config as functorch_config
 from torch._functorch.compilers import ts_compile
 
 from .common import aot_autograd
-from .registry import register_debug_backend as register_backend
+from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend
+
+
+if TYPE_CHECKING:
+    from torch.fx.node import Target
 
 
 log = logging.getLogger(__name__)
 
 
 @register_backend
-def eager(gm, fake_tensor_inputs, **kwargs):
+def eager(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager backend ignoring extra kwargs %s", kwargs)
     return gm.forward
 
 
-def make_eager_backend_with_torch_function_mode(mode):
+def make_eager_backend_with_torch_function_mode(
+    mode: torch.overrides.TorchFunctionMode,
+) -> Callable[..., Any]:
     return make_eager_backend_with_torch_function_modes([mode])
 
 
-def make_eager_backend_with_torch_function_modes(modes):
+def make_eager_backend_with_torch_function_modes(
+    modes: Iterable[torch.overrides.TorchFunctionMode],
+) -> Callable[..., Any]:
     """Used to trace HOPs (cond and while) for eager execution. The metadata
     TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
     in the HOP, so we need to externally run this mode and not trace it."""
     from contextlib import ExitStack
 
-    def fn(gm, fake_tensor_inputs, **kwargs):
+    def fn(
+        gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+    ) -> Callable[..., Any]:
         stack = ExitStack()
         for mode in modes:
             stack.enter_context(mode)
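For orientation, the returned backend enters each TorchFunction mode around the call to the compiled graph instead of tracing the mode into it. A standalone sketch of that pattern, with a hypothetical logging mode:

    from contextlib import ExitStack

    import torch
    from torch.overrides import TorchFunctionMode

    class LoggingMode(TorchFunctionMode):  # hypothetical, for illustration only
        def __torch_function__(self, func, types, args=(), kwargs=None):
            print("called:", func)
            return func(*args, **(kwargs or {}))

    def run_under_modes(fn, modes, *args):
        # enter every mode, run the function, then unwind them all
        with ExitStack() as stack:
            for mode in modes:
                stack.enter_context(mode)
            return fn(*args)

    run_under_modes(torch.relu, [LoggingMode()], torch.randn(3))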
@@ -74,13 +86,15 @@ def make_eager_backend_with_torch_function_modes(modes):
 
 
 @register_backend
-def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
+def eager_noexcept(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
 
     # This backend is intended to check that dynamo-generated GraphModules
     # do not cause errors.
-    def inner(*args):
+    def inner(*args: Any) -> Any:
         try:
             return gm(*args)
         except Exception as e:
@@ -92,13 +106,15 @@ def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
 
 
 @register_backend
-def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
+def pre_dispatch_eager(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> torch.fx.GraphModule:
     if kwargs:
         log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
 
     from torch.fx.experimental.proxy_tensor import make_fx
 
-    def runnable_gm(*args):
+    def runnable_gm(*args: Any) -> Any:
         return torch.fx.Interpreter(gm).run(*args)
 
     pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
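make_fx(..., pre_dispatch=True) retraces the graph at the pre-dispatch level (above autograd and functionalization), which helps separate tracing problems from execution problems. The same retracing step in isolation, on a toy function:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.sin(x) * 2

    # The result is itself a runnable torch.fx.GraphModule.
    gm = make_fx(f, pre_dispatch=True)(torch.randn(3))
    print(gm.graph)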
@@ -108,7 +124,9 @@ def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
 
 
 @register_backend
-def eager_debug(gm, fake_tensor_inputs, **kwargs):
+def eager_debug(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
 
@@ -117,42 +135,55 @@ def eager_debug(gm, fake_tensor_inputs, **kwargs):
     # We could add more debugging bits here.
     # Right now, this backend can be used to check for and error on
     # custom dispatcher ops that have incorrect schemas.
-    def inner(*args):
+    def inner(*args: Any) -> Any:
         with SchemaCheckMode():
             return torch.fx.Interpreter(gm).run(*args)
 
     return inner
 
 
-@register_backend(name="ts")
-def torchscript(gm, fake_tensor_inputs):
+@register_backend(name="ts")  # type: ignore[misc]
+def torchscript(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
+) -> torch.jit.ScriptModule:
     return torch.jit.script(gm)
 
 
 # uses a boxed call to discard inputs when they are no longer needed
-def boxed_nop(fx_g, example_inputs):
-    def run(args):
+def boxed_nop(
+    fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         return torch.fx.Interpreter(fx_g).boxed_run(args)
 
-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run
 
 
-def boxed_nop_with_mode(fx_g, example_inputs, *, mode):
-    def run(args):
+def boxed_nop_with_mode(
+    fx_g: torch.fx.GraphModule,
+    example_inputs: list[torch.Tensor],
+    *,
+    mode: torch.overrides.TorchFunctionMode,
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         with mode:
             return torch.fx.Interpreter(fx_g).boxed_run(args)
 
-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run
 
 
-def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None):
-    def run(args):
+def fake_crossref_boxed_nop(
+    fx_g: torch.fx.GraphModule,
+    example_inputs: list[torch.Tensor],
+    ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
             return torch.fx.Interpreter(fx_g).boxed_run(args)
 
-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run
 
 
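The _boxed_call flag opts into AOT Autograd's boxed calling convention: inputs arrive as a single mutable list, and Interpreter.boxed_run empties that list as it binds placeholders, so input tensors can be freed as early as possible. A small sketch of the convention:

    import torch

    def f(x):
        return x + 1

    gm = torch.fx.symbolic_trace(f)

    args = [torch.randn(3)]
    out = torch.fx.Interpreter(gm).boxed_run(args)
    assert args == []  # boxed_run cleared the argument list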
@@ -160,7 +191,9 @@ def ignore_builtins(op: torch._ops.OpOverload) -> bool:
     return op.namespace in ("aten", "prims", "prim")
 
 
-def get_nop_func():
+def get_nop_func() -> Callable[
+    [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any]
+]:
     if not torch._functorch.config.fake_tensor_crossref:
         return boxed_nop
     elif torch._functorch.config.fake_tensor_crossref == "all":
@@ -173,12 +206,12 @@ def get_nop_func():
 # Useful for debugging purposes
 # aot_eager uses the AOT Autograd backend with a nop compiler. It is helpful in debugging.
 def aot_eager(
-    gm,
-    fake_tensor_inputs,
-    fw_compiler=None,
-    bw_compiler=None,
-    **kwargs,
-):
+    gm: torch.fx.GraphModule,
+    fake_tensor_inputs: list[torch.Tensor],
+    fw_compiler: Optional[Callable[..., Any]] = None,
+    bw_compiler: Optional[Callable[..., Any]] = None,
+    **kwargs: Any,
+) -> Callable[..., Any]:
     return aot_autograd(
         fw_compiler=fw_compiler or boxed_nop,
         bw_compiler=bw_compiler or boxed_nop,
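aot_eager runs AOT Autograd's partitioned forward and backward graphs through the nop compiler above, so autograd handling can be debugged without Inductor in the picture. Sketch:

    import torch

    def f(x):
        return (x * x).sum()

    compiled = torch.compile(f, backend="aot_eager")
    x = torch.randn(4, requires_grad=True)
    compiled(x).backward()  # exercises both forward and backward graphs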
@@ -201,7 +234,9 @@ register_backend(
 # inductor problems.
 # aot_eager_decomp_partition just replaces the inductor compiler with nop to help
 # isolate inductor vs aot_eager errors
-def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
+def aot_eager_decomp_partition(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning(
             "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
@@ -213,7 +248,7 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
     if bisect_changes := CompilerBisector.get_config_change(
         "aot_eager_decomp_partition"
     ):
-        config_patches.update(bisect_changes)
+        config_patches.update(bisect_changes)  # type: ignore[arg-type]
 
     with functorch_config.patch(config_patches):
         return aot_autograd(
@@ -237,7 +272,12 @@ register_backend(
 
 # aot_eager_decomp_partition_with_mode is similar to aot_eager_decomp_partition,
 # except that it takes a TorchDispatchMode and runs the fw/bw in that mode
-def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
+def aot_eager_decomp_partition_with_mode(
+    gm: torch.fx.GraphModule,
+    fake_tensor_inputs: list[torch.Tensor],
+    mode: Any,
+    **kwarg: Any,
+) -> Callable[..., Any]:
     return aot_autograd(
         # these are taken from memory_efficient_fusion()
         fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
@@ -254,11 +294,13 @@ def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
 
 register_backend(
     name="aot_eager_decomp_partition_with_mode",
-    compiler_fn=aot_eager_decomp_partition_with_mode,
+    compiler_fn=aot_eager_decomp_partition_with_mode,  # type: ignore[arg-type]
 )
 
 
-def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
+def aot_eager_decomp_partition_crossref(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     # if the config is set, respect it, otherwise only test custom_ops.
     # custom_op bad metas always manifest as an error whereas aten will only sometimes.
     # by default, use the less noisy option
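CrossRefFakeMode (used via fake_crossref_boxed_nop above) re-checks each op's fake-tensor metadata against its real output, which is how bad custom-op metas surface. Assuming this function is registered under its own name like its siblings, usage would look like:

    import torch

    def f(x):
        return torch.sin(x)

    # By default only custom ops are cross-checked, which keeps the noise down.
    compiled = torch.compile(f, backend="aot_eager_decomp_partition_crossref")
    compiled(torch.randn(3))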
@@ -296,7 +338,9 @@ class TestingOnlyCompileError(Exception):
 
 
 @register_backend
-def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_compile_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             raise ReluCompileError
@@ -304,7 +348,9 @@ def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 
 
 @register_backend
-def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_runtime_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             node.target = torch._assert
@@ -314,7 +360,9 @@ def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 
 
 @register_backend
-def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_accuracy_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             node.target = torch.add
@@ -325,7 +373,9 @@ def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 
 
 @register_backend
-def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def non_leaf_compile_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     # Require at least one non-trivial thing in the graph,
     # see https://github.com/pytorch/pytorch/issues/102898
     for node in gm.graph.nodes:
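The *_TESTING_ONLY backends inject failures on purpose (at compile time, at run time, or as a silent accuracy error) so that error handling and the minifier can be exercised. A sketch of triggering the compile-time variant:

    import torch

    def f(x):
        return torch.relu(x)

    compiled = torch.compile(f, backend="relu_compile_error_TESTING_ONLY")
    try:
        compiled(torch.randn(3))  # compilation happens on the first call
    except Exception as e:
        # Dynamo surfaces this as a backend compilation failure
        # wrapping ReluCompileError.
        print(type(e).__name__)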
@@ -349,11 +399,9 @@ class ExplainOutput:
 
     graphs: list[torch.fx.GraphModule]
     graph_count: int
     graph_break_count: int
-    break_reasons: list[
-        Any
-    ]  # Type is GraphCompileReason but doesn't matter for this purpose
+    break_reasons: list[GraphCompileReason]
     op_count: int
-    ops_per_graph: Optional[list[torch.fx.Node]] = None
+    ops_per_graph: Optional[list[list["Target"]]] = None
     out_guards: Optional[list[_guards.Guard]] = None
     compile_times: Optional[str] = None
@@ -389,8 +437,18 @@ class ExplainOutput:
 
 
 def _explain_graph_detail(
-    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
-):
+    gm: torch.fx.GraphModule,
+    graphs: list[torch.fx.GraphModule],
+    op_count: int,
+    ops_per_graph: list[list["Target"]],
+    break_reasons: list[GraphCompileReason],
+) -> tuple[
+    torch.fx.GraphModule,
+    list[torch.fx.GraphModule],
+    int,
+    list[list["Target"]],
+    list[GraphCompileReason],
+]:
     """
     This function is a utility which processes a torch.fx.GraphModule and
     accumulates information about its ops, graph breaks, and other details. It
@@ -412,8 +470,8 @@ def _explain_graph_detail(
 
     ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
     op_count += len(ops)
     ops_per_graph.append(ops)
-    if gm.compile_subgraph_reason.graph_break:
-        break_reasons.append(gm.compile_subgraph_reason)
+    if gm.compile_subgraph_reason.graph_break:  # type: ignore[union-attr]
+        break_reasons.append(gm.compile_subgraph_reason)  # type: ignore[arg-type]
 
     return gm, graphs, op_count, ops_per_graph, break_reasons
@@ -443,17 +501,20 @@ class ExplainWithBackend:
         print(eb.output())
     """
 
-    def __init__(self, backend) -> None:
+    def __init__(self, backend: Union[CompilerFn, str]) -> None:
         from .registry import lookup_backend
 
         self.backend = lookup_backend(backend)
-        self.graphs = []
+        self.graphs: list[torch.fx.GraphModule] = []
         self.op_count = 0
-        self.break_reasons = []
+        self.break_reasons: list[GraphCompileReason] = []
 
-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(
+        self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+    ) -> CompiledFn:
+        ops_per_graph: list[list[Target]] = []
         gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
-            gm, self.graphs, self.op_count, [], self.break_reasons
+            gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons
         )
         return self.backend(gm, example_inputs)
 
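A short follow-up on the docstring example: output() returns the ExplainOutput dataclass defined earlier, so the accumulated statistics are plain attributes. A sketch, assuming the module's usual import path:

    import torch
    from torch._dynamo.backends.debugging import ExplainWithBackend

    def f(x):
        x = torch.sin(x)
        torch._dynamo.graph_break()  # force a second graph
        return torch.cos(x)

    eb = ExplainWithBackend("eager")
    torch.compile(f, backend=eb)(torch.randn(4))
    result = eb.output()
    print(result.graph_count, result.graph_break_count)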