Compare commits

...

8 Commits

SHA1 Message Date
1207f9ab93 typing tvm.py 2025-08-11 14:57:13 -07:00
fcfb6bab89 Type backend torchxla 2025-08-11 14:37:34 -07:00
95bd114806 typing registry.py 2025-08-11 14:09:50 -07:00
ec68abdc38 typing inductor and placeholder backends 2025-08-11 13:51:06 -07:00
ee417d1806 typing distributed.py 2025-08-11 13:43:05 -07:00
4a8afeaffb typing debugging.py 2025-08-11 11:33:30 -07:00
90cba401a0 Type cudagraphs.py 2025-08-11 10:24:17 -07:00
6e4c4d9e57 Typing for common.py 2025-08-11 09:50:55 -07:00
12 changed files with 290 additions and 179 deletions

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module provides common utilities and base classes for TorchDynamo backends.
@ -21,6 +19,9 @@ optimization of both forward and backward passes.
import contextlib
import functools
import logging
from collections.abc import Iterable
from typing import Any, Callable
from typing_extensions import ParamSpec, TypeVar
from unittest.mock import patch
import torch
@ -36,13 +37,18 @@ from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class AotAutograd:
def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
self.__name__ = "compiler_fn"
self.kwargs = kwargs
def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
def __call__(
self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
) -> Callable[..., Any]:
if kwargs:
log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)
@ -66,8 +72,8 @@ class AotAutograd:
counters["aot_autograd"]["not_ok"] += 1
return gm
def wrap_bw_compiler(bw_compiler_fn):
def _wrapped_bw_compiler(*args, **kwargs):
def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
# Note [Wrapping bw_compiler in disable]
# The two disables here:
# - stop TorchDynamo from trying to compile the bw_compiler function itself
@ -75,7 +81,7 @@ class AotAutograd:
return disable(
disable(
bw_compiler_fn, reason="do not trace backward compiler function"
)(*args, **kwargs),
)(*args, **kwargs), # type: ignore[misc]
reason="do not trace generated backwards pass",
)
@ -99,7 +105,9 @@ class AotAutograd:
# debug asserts slow down compile time noticeably,
# So only default them on when the aot_eager backend is used.
if self.kwargs.get("fw_compiler", None) == nop:
patch_config = patch("functorch.compile.config.debug_assert", True)
patch_config: contextlib.AbstractContextManager[Any] = patch(
"functorch.compile.config.debug_assert", True
)
else:
patch_config = contextlib.nullcontext()
@ -116,11 +124,11 @@ class AotAutograd:
raise
def aot_autograd(**kwargs) -> AotAutograd:
def aot_autograd(**kwargs: Any) -> AotAutograd:
return AotAutograd(**kwargs)
def mem_efficient_fusion_kwargs(use_decomps):
def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
from functorch.compile import (
default_decompositions,
min_cut_rematerialization_partition,
@ -140,28 +148,30 @@ def mem_efficient_fusion_kwargs(use_decomps):
return kwargs
def fake_tensor_unsupported(fn):
def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
"""
Decorator for backends that need real inputs. We swap out fake
tensors for zero tensors.
"""
@functools.wraps(fn)
def wrapper(model, inputs, **kwargs):
def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
with _disable_current_modes():
inputs = list(map(defake, inputs))
return fn(model, inputs, **kwargs)
return fn(model, inputs, **kwargs) # type: ignore[call-arg]
return wrapper
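
For reference, a minimal sketch of how a backend might apply this decorator; `my_real_input_backend` is a hypothetical name, not part of this module:

import torch
from torch._dynamo.backends.common import fake_tensor_unsupported

@fake_tensor_unsupported
def my_real_input_backend(gm: torch.fx.GraphModule, example_inputs):
    # inside the wrapper, example_inputs have been swapped from FakeTensors
    # to real zero tensors, so shape/dtype-dependent work can proceed
    return gm.forward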
def device_from_inputs(example_inputs) -> torch.device:
def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
for x in example_inputs:
if hasattr(x, "device"):
return x.device
return torch.device("cpu") # Default fallback
def dtype_from_inputs(example_inputs) -> torch.dtype:
def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
for x in example_inputs:
if hasattr(x, "dtype"):
return x.dtype
return torch.float32 # Default fallback
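
A hedged sketch of building a backend from the aot_autograd helper above (assuming the usual torch.compile entry point; `noop_compiler` is illustrative):

import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

def noop_compiler(gm: torch.fx.GraphModule, example_inputs):
    # AOT Autograd expects boxed callables (a single list argument)
    return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=noop_compiler, bw_compiler=noop_compiler)
opt = torch.compile(torch.nn.Linear(4, 4), backend=my_backend)
opt(torch.randn(2, 4))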

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module implements CUDA graphs support for TorchDynamo backends.
@ -25,9 +23,11 @@ Key components:
import functools
from collections import defaultdict
from typing import Optional
from collections.abc import Sequence
from typing import Any, Callable, Optional
import torch
import torch.fx
from torch._dynamo import config
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
@ -51,8 +51,8 @@ from torch.multiprocessing.reductions import StorageWeakRef
from .registry import register_backend
def find_input_mutations(g):
def meta_fk(meta):
def find_input_mutations(g: torch.fx.Graph) -> set[int]:
def meta_fk(meta: dict[str, Any]) -> Any:
return meta["val"] if "val" in meta else meta["fake_result"]
inputs = defaultdict(set)
@ -90,7 +90,9 @@ def find_input_mutations(g):
return mutated_inputs
def get_device_node_mapping(gm: torch.fx.GraphModule):
def get_device_node_mapping(
gm: torch.fx.GraphModule,
) -> dict[torch.device, torch.fx.Node]:
device_node_mapping: dict[torch.device, torch.fx.Node] = {}
for n in gm.graph.nodes:
t = n.meta.get("val", None)
@ -100,7 +102,7 @@ def get_device_node_mapping(gm: torch.fx.GraphModule):
def check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model: torch.fx.GraphModule, num_fixed
aot_model: torch.fx.GraphModule, num_fixed: int
) -> Optional[str]:
mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
if not mutation_indices:
@ -110,7 +112,7 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor(
return get_mutation_stack_trace(placeholders, mutation_indices)
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
if not config.cudagraph_backend_support_input_mutation:
if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
aot_model, num_fixed
@ -128,28 +130,35 @@ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
return None
def get_device_index(gm) -> int:
def get_device_index(gm: torch.fx.GraphModule) -> int:
device = next(iter(get_device_node_mapping(gm)))
assert device.type == "cuda"
return device.index
def get_stack_traces(gm) -> list[Optional[str]]:
def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
output = output_node(gm)
assert len(output.args) == 1
args = output.args[0]
if not hasattr(args, "__iter__"):
return []
return [
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
for arg in output.args[0]
for arg in args # type: ignore[union-attr]
]
def cudagraphs(dynamo_model, dynamo_inputs):
def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
from torch._inductor.cudagraph_trees import cudagraphify_impl
do_cudagraphs = BoxedBool(True)
boxed_device_index = BoxedDeviceIndex(None)
def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
def forward_cudagraphs(
aot_model: torch.fx.GraphModule,
aot_inputs: list[Any],
is_inference: bool = False,
) -> Any:
interp = boxed_nop(aot_model, aot_inputs)
fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
if skip_msg := check_for_skip(aot_model, fixed):
@ -166,15 +175,17 @@ def cudagraphs(dynamo_model, dynamo_inputs):
range(fixed),
device_index=boxed_device_index.value,
is_backward=False,
is_inference=False,
is_inference=False, # Q: should we forward is_inference here?
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
out._boxed_call = True # type: ignore[attr-defined]
return out
def backward_cudagraphs(aot_model, aot_inputs):
def backward_cudagraphs(
aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
) -> Any:
interp = boxed_nop(aot_model, aot_inputs)
if not do_cudagraphs:
return aot_model
@ -182,20 +193,23 @@ def cudagraphs(dynamo_model, dynamo_inputs):
fixed = count_tangents(aot_model)
if skip_msg := check_for_skip(aot_model, fixed):
log_cudagraph_skip_and_bump_counter(
"skipping cudagraphs due to %s", skip_msg
f"skipping cudagraphs due to {skip_msg}"
)
# See [Backward Generation Handling]
device_idx = boxed_device_index.value
if device_idx is None:
device_idx = 0 # Default to device 0 if not set
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_device_index.value, create_if_none_exists=False
device_idx, create_if_none_exists=False
)
assert manager is not None
def fn(inputs):
def fn(inputs: list[Any]) -> Any:
manager.set_to_running_backward()
return aot_model(inputs)
fn._boxed_call = True
fn._boxed_call = True # type: ignore[attr-defined]
return fn
out = cudagraphify_impl(
@ -209,7 +223,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
)
out._boxed_call = True
out._boxed_call = True # type: ignore[attr-defined]
return out
aot_cudagraphs = aot_autograd(
@ -225,13 +239,13 @@ class CudagraphsBackend:
compiler_name = "cudagraphs"
@staticmethod
def reset():
def reset() -> None:
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
reset_cudagraph_trees()
@staticmethod
def __call__(model, inputs):
def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
return cudagraphs(model, inputs)
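
Since this class is registered under the name "cudagraphs" just below, a hedged usage sketch (a CUDA device is assumed):

import torch

model = torch.nn.Linear(8, 8).cuda()
opt_model = torch.compile(model, backend="cudagraphs")
out = opt_model(torch.randn(4, 8, device="cuda"))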
@ -240,7 +254,12 @@ class CudagraphsBackend:
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
def cudagraphs_inner(
model: Callable[..., Any],
inputs: Sequence[Any],
copy_outputs: bool = True,
copy_inputs: bool = True,
) -> Callable[..., Sequence[Any]]:
"""This isn't registered as a backend, but is used in some benchmarks"""
assert isinstance(inputs, (list, tuple))
if copy_inputs:
@ -265,7 +284,7 @@ def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)
def run(*new_inputs):
def run(*new_inputs: Any) -> Sequence[Any]:
assert len(static_inputs) == len(new_inputs)
if copy_inputs:
for dst, src in zip(static_inputs, new_inputs):

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
compilation and execution issues. It includes:
@ -28,40 +26,56 @@ These backends are primarily used for:
import dataclasses
import functools
import logging
from collections.abc import Iterable
from importlib import import_module
from typing import Any, Optional
from typing import Any, Callable, Optional, Union
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch import _guards
from torch._dynamo.output_graph import GraphCompileReason
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile
from torch.fx.node import Target
from .common import aot_autograd
from .registry import register_debug_backend as register_backend
from .registry import (
CompiledFn,
CompilerFn,
lookup_backend,
register_debug_backend as register_backend,
)
log = logging.getLogger(__name__)
@register_backend
def eager(gm, fake_tensor_inputs, **kwargs):
def eager(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
if kwargs:
log.warning("eager backend ignoring extra kwargs %s", kwargs)
return gm.forward
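
The eager backend is mostly useful for isolating Dynamo capture problems from compiler problems; a short sketch of the usual flow:

import torch

def f(x):
    return torch.sin(x) + x

compiled = torch.compile(f, backend="eager")  # Dynamo traces, then the graph runs eagerly
compiled(torch.randn(4))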
def make_eager_backend_with_torch_function_mode(mode):
def make_eager_backend_with_torch_function_mode(
mode: torch.overrides.TorchFunctionMode,
) -> Callable[..., Any]:
return make_eager_backend_with_torch_function_modes([mode])
def make_eager_backend_with_torch_function_modes(modes):
def make_eager_backend_with_torch_function_modes(
modes: Iterable[torch.overrides.TorchFunctionMode],
) -> Callable[..., Any]:
"""Used to trace HOPs (cond and while) for eager execution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
from contextlib import ExitStack
def fn(gm, fake_tensor_inputs, **kwargs):
def fn(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
stack = ExitStack()
for mode in modes:
stack.enter_context(mode)
@ -74,13 +88,15 @@ def make_eager_backend_with_torch_function_modes(modes):
@register_backend
def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
def eager_noexcept(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
if kwargs:
log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)
# This backend is intended to check that dynamo-generated GraphModules
# do not cause errors.
def inner(*args):
def inner(*args: Any) -> Any:
try:
return gm(*args)
except Exception as e:
@ -92,13 +108,15 @@ def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
def pre_dispatch_eager(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> torch.fx.GraphModule:
if kwargs:
log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)
from torch.fx.experimental.proxy_tensor import make_fx
def runnable_gm(*args):
def runnable_gm(*args: Any) -> Any:
return torch.fx.Interpreter(gm).run(*args)
pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
@ -108,7 +126,9 @@ def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
@register_backend
def eager_debug(gm, fake_tensor_inputs, **kwargs):
def eager_debug(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
if kwargs:
log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)
@ -117,42 +137,55 @@ def eager_debug(gm, fake_tensor_inputs, **kwargs):
# We could add more debugging bits here.
# Right now, this backend can be used to check for and error on
# custom dispatcher ops that have incorrect schemas.
def inner(*args):
def inner(*args: Any) -> Any:
with SchemaCheckMode():
return torch.fx.Interpreter(gm).run(*args)
return inner
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
@register_backend(name="ts") # type: ignore[misc]
def torchscript(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
) -> torch.jit.ScriptModule:
return torch.jit.script(gm)
# uses a boxed call so inputs can be discarded when they are no longer needed
def boxed_nop(fx_g, example_inputs):
def run(args):
def boxed_nop(
fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable[..., Any]:
def run(args: Any) -> Any:
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
run._boxed_call = True # type: ignore[attr-defined]
return run
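
A short sketch of the boxed calling convention referred to above: the returned callable takes one mutable list, so the caller's references can be dropped early (names here are illustrative):

import torch

gm = torch.fx.symbolic_trace(lambda x: x + 1)
compiled = boxed_nop(gm, [torch.randn(3)])
args = [torch.randn(3)]
out = compiled(args)  # a single list argument, not *args; boxed_run may consume the list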
def boxed_nop_with_mode(fx_g, example_inputs, *, mode):
def run(args):
def boxed_nop_with_mode(
fx_g: torch.fx.GraphModule,
example_inputs: list[torch.Tensor],
*,
mode: torch.overrides.TorchFunctionMode,
) -> Callable[..., Any]:
def run(args: Any) -> Any:
with mode:
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
run._boxed_call = True # type: ignore[attr-defined]
return run
def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None):
def run(args):
def fake_crossref_boxed_nop(
fx_g: torch.fx.GraphModule,
example_inputs: list[torch.Tensor],
ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
) -> Callable[..., Any]:
def run(args: Any) -> Any:
with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
run._boxed_call = True # type: ignore[attr-defined]
return run
@ -160,7 +193,9 @@ def ignore_builtins(op: torch._ops.OpOverload) -> bool:
return op.namespace in ("aten", "prims", "prim")
def get_nop_func():
def get_nop_func() -> Callable[
[torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any]
]:
if not torch._functorch.config.fake_tensor_crossref:
return boxed_nop
elif torch._functorch.config.fake_tensor_crossref == "all":
@ -173,12 +208,12 @@ def get_nop_func():
# Useful for debugging purposes
# aot_eager uses the AOT Autograd backend with a nop compiler. It is helpful in debugging.
def aot_eager(
gm,
fake_tensor_inputs,
fw_compiler=None,
bw_compiler=None,
**kwargs,
):
gm: torch.fx.GraphModule,
fake_tensor_inputs: list[torch.Tensor],
fw_compiler: Optional[Callable[..., Any]] = None,
bw_compiler: Optional[Callable[..., Any]] = None,
**kwargs: Any,
) -> Callable[..., Any]:
return aot_autograd(
fw_compiler=fw_compiler or boxed_nop,
bw_compiler=bw_compiler or boxed_nop,
@ -201,7 +236,9 @@ register_backend(
# inductor problems.
# aot_eager_decomp_partition just replaces the inductor compiler with nop to help
# isolate inductor vs aot_eager errors
def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
def aot_eager_decomp_partition(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
if kwargs:
log.warning(
"aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
@ -213,7 +250,7 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
if bisect_changes := CompilerBisector.get_config_change(
"aot_eager_decomp_partition"
):
config_patches.update(bisect_changes)
config_patches.update(bisect_changes) # type: ignore[arg-type]
with functorch_config.patch(config_patches):
return aot_autograd(
@ -237,7 +274,12 @@ register_backend(
# aot_eager_decomp_partition_with_mode is similar to aot_eager_decomp_partition,
# except that it takes a TorchDispatchMode and runs the fw/bw in that mode
def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
def aot_eager_decomp_partition_with_mode(
gm: torch.fx.GraphModule,
fake_tensor_inputs: list[torch.Tensor],
mode: Any,
**kwarg: Any,
) -> Callable[..., Any]:
return aot_autograd(
# these are taken from memory_efficient_fusion()
fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
@ -254,11 +296,13 @@ def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
register_backend(
name="aot_eager_decomp_partition_with_mode",
compiler_fn=aot_eager_decomp_partition_with_mode,
compiler_fn=aot_eager_decomp_partition_with_mode, # type: ignore[arg-type]
)
def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
def aot_eager_decomp_partition_crossref(
gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
) -> Callable[..., Any]:
# if the config is set, respect it, otherwise only test custom_ops.
# custom_op bad metas always manifest as an error, whereas aten ones only sometimes do.
# by default, use the less noisy option
@ -296,7 +340,9 @@ class TestingOnlyCompileError(Exception):
@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
def relu_compile_error_TESTING_ONLY(
gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.target == torch.relu:
raise ReluCompileError
@ -304,7 +350,9 @@ def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
def relu_runtime_error_TESTING_ONLY(
gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.target == torch.relu:
node.target = torch._assert
@ -314,7 +362,9 @@ def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
def relu_accuracy_error_TESTING_ONLY(
gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.target == torch.relu:
node.target = torch.add
@ -325,7 +375,9 @@ def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
def non_leaf_compile_error_TESTING_ONLY(
gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> torch.fx.GraphModule:
# Require at least one non-trivial thing in the graph,
# see https://github.com/pytorch/pytorch/issues/102898
for node in gm.graph.nodes:
@ -349,11 +401,9 @@ class ExplainOutput:
graphs: list[torch.fx.GraphModule]
graph_count: int
graph_break_count: int
break_reasons: list[
Any
] # Type is GraphCompileReason but doesn't matter for this purpose
break_reasons: list[GraphCompileReason]
op_count: int
ops_per_graph: Optional[list[torch.fx.Node]] = None
ops_per_graph: Optional[list[Target]] = None
out_guards: Optional[list[_guards.Guard]] = None
compile_times: Optional[str] = None
@ -374,7 +424,7 @@ class ExplainOutput:
output += "Ops per Graph:\n"
for idx, ops in enumerate(self.ops_per_graph):
output += f" Ops {idx + 1}:\n"
for op in ops:
for op in ops: # type: ignore[union-attr]
output += f" {op}\n"
if self.out_guards is not None:
@ -389,8 +439,18 @@ class ExplainOutput:
def _explain_graph_detail(
gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
gm: torch.fx.GraphModule,
graphs: list[torch.fx.GraphModule],
op_count: int,
ops_per_graph: list[list[Target]],
break_reasons: list[GraphCompileReason],
) -> tuple[
torch.fx.GraphModule,
list[torch.fx.GraphModule],
int,
list[list[Target]],
list[GraphCompileReason],
]:
"""
This function is a utility which processes a torch.fx.GraphModule and
accumulates information about its ops, graph breaks, and other details. It
@ -412,8 +472,8 @@ def _explain_graph_detail(
ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
op_count += len(ops)
ops_per_graph.append(ops)
if gm.compile_subgraph_reason.graph_break:
break_reasons.append(gm.compile_subgraph_reason)
if gm.compile_subgraph_reason.graph_break: # type: ignore[union-attr]
break_reasons.append(gm.compile_subgraph_reason) # type: ignore[arg-type]
return gm, graphs, op_count, ops_per_graph, break_reasons
@ -443,17 +503,18 @@ class ExplainWithBackend:
print(eb.output())
"""
def __init__(self, backend) -> None:
from .registry import lookup_backend
def __init__(self, backend: Union[CompilerFn, str]) -> None:
self.backend = lookup_backend(backend)
self.graphs = []
self.graphs: list[torch.fx.GraphModule] = []
self.op_count = 0
self.break_reasons = []
self.break_reasons: list[GraphCompileReason] = []
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
def __call__(
self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> CompiledFn:
ops_per_graph: list[list[Target]] = []
gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
gm, self.graphs, self.op_count, [], self.break_reasons
gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons
)
return self.backend(gm, example_inputs)
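
A usage sketch mirroring the class docstring (toy model for illustration):

import torch
from torch._dynamo.backends.debugging import ExplainWithBackend

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
eb = ExplainWithBackend("inductor")
opt_model = torch.compile(model, backend=eb)
opt_model(torch.randn(2, 4))
print(eb.output())  # graph count, break reasons, op count, ...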

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module implements distributed training optimizations for TorchDynamo backends.
@ -21,11 +19,12 @@ of compilation.
import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any, Callable, Optional
from unittest import mock
import torch
from torch import fx
from torch._dynamo.backends.registry import CompiledFn, CompilerFn
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
@ -39,7 +38,7 @@ log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
def args_str(args):
def args_str(args: Any) -> str:
# a debug helper
if torch.is_tensor(args):
return f"T[{args.shape}]"
@ -58,7 +57,7 @@ class Bucket:
nodes: list[fx.Node] = field(default_factory=list)
# param_ids is just used for unit testing
param_ids: list = field(default_factory=list)
param_ids: list[int] = field(default_factory=list)
# keep track of any buckets that were extended for logging purposes
opcount_increased_to_capture_external_output: int = 0
@ -78,9 +77,9 @@ def bucket_has_external_output(bucket: Bucket) -> bool:
return False
def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
headers = ("Index", "Size (b)", "Param Names")
rows = []
rows: list[tuple[Optional[int], Optional[int], str]] = []
extended_buckets = []
for idx, bucket in enumerate(reversed(buckets)):
if len(bucket.params) > 0:
@ -136,7 +135,7 @@ def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
log.debug("DDPOptimizer captured no parameters and did not split this graph.")
def has_higher_order_op(gm):
def has_higher_order_op(gm: fx.GraphModule) -> bool:
# Check if there is a higher order op in the graph
for node in gm.graph.nodes:
if node.op == "get_attr":
@ -146,7 +145,7 @@ def has_higher_order_op(gm):
return False
def propagate_metadata(orig_gm, split_gm) -> None:
def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
for name, module in split_gm.named_modules():
if "." not in name and len(name):
# TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
@ -154,7 +153,7 @@ def propagate_metadata(orig_gm, split_gm) -> None:
module._param_name_to_source = orig_gm._param_name_to_source
def propagate_dynamo_source(orig_gm, split_gm) -> None:
def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
name_to_dynamo_source = {}
for node in orig_gm.graph.find_nodes(op="placeholder"):
name_to_dynamo_source[node.name] = node._dynamo_source
@ -168,12 +167,19 @@ def propagate_dynamo_source(orig_gm, split_gm) -> None:
# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self, module, compiler, fake_mode) -> None:
def __init__(
self,
module: fx.GraphModule,
compiler: CompilerFn,
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
) -> None:
super().__init__(module)
self.compiler = compiler
self.fake_mode = fake_mode
def compile_submod(self, input_mod, args, kwargs):
def compile_submod(
self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
) -> Any:
"""
Compile the submodule,
using a wrapper to make sure its output is always a tuple,
@ -182,12 +188,14 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
assert len(kwargs) == 0, "We assume only args for these modules"
class WrapperModule(torch.nn.Module):
def __init__(self, submod, unwrap_singleton_tuple) -> None:
def __init__(
self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
) -> None:
super().__init__()
self.submod = submod
self.unwrap_singleton_tuple = unwrap_singleton_tuple
def forward(self, *args):
def forward(self, *args: Any) -> Any:
x = self.submod(*args)
# TODO(whc)
# for some reason the isinstance check is necessary if I split one node per submod
@ -205,12 +213,12 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
sn.args = (sn.args,)
input_mod.recompile()
input_mod.compile_subgraph_reason = GraphCompileReason(
input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment]
"DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
" Set `torch._dynamo.config.optimize_ddp = False` to disable.",
[
# it's close to useless to get a real stacktrace here, and quite verbose.
traceback.FrameSummary(__file__, 0, DDPOptimizer),
traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
],
)
@ -257,7 +265,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
assert isinstance(kwargs, dict)
if n.op == "call_module":
real_mod = self.fetch_attr(n.target)
real_mod = self.fetch_attr(str(n.target))
if self.fake_mode:
curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
else:
@ -287,10 +295,10 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
self.tc.fakify_first_call = True
def __del__(self) -> None:
self.tc.fakify_first_call = False
self.tc.fakify_first_call = False # type: ignore[union-attr]
# For aot_eager and other backends, tracing context is not set
has_tracing_context = torch._guards.TracingContext.try_get() is not None
@ -308,9 +316,9 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
# We update the original (outer) graph with a call into the compiled module
# instead of the uncompiled one.
self.module.delete_submodule(n.target)
n.target = "compiled_" + n.target
self.module.add_submodule(n.target, compiled_submod_real)
self.module.delete_submodule(n.target) # type: ignore[operator]
n.target = "compiled_" + n.target # type: ignore[operator]
self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator]
# Finally, we have to produce inputs for use in compiling the next submodule,
# and these need to be FakeTensors, so we execute the module under fake_mode
@ -398,7 +406,7 @@ class DDPOptimizer:
def __init__(
self,
bucket_bytes_cap: int,
backend_compile_fn,
backend_compile_fn: CompilerFn,
first_bucket_cap: Optional[int] = None,
) -> None:
if first_bucket_cap is not None:
@ -416,21 +424,27 @@ class DDPOptimizer:
self.backend_compile_fn = backend_compile_fn
def _ignore_parameter(self, parameter):
def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
def add_param(self, bucket, param, name):
def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
bucket.size += param.untyped_storage().nbytes()
bucket.params.append(name)
bucket.param_ids.append(id(param))
def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
def add_module_params_to_bucket(
self,
mod: torch.nn.Module,
bucket: Bucket,
processed_modules: set[torch.nn.Module],
prefix: str,
) -> None:
processed_modules.add(mod)
for name, param in mod.named_parameters():
if param.requires_grad and not self._ignore_parameter(param):
self.add_param(bucket, param, f"{prefix}_{name}")
def add_param_args(self, bucket, node):
def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
for arg in node.args:
if not isinstance(arg, torch.fx.node.Node):
continue
@ -442,9 +456,11 @@ class DDPOptimizer:
and param.requires_grad
and not self._ignore_parameter(param)
):
self.add_param(bucket, param, arg.target)
self.add_param(bucket, param, str(arg.target))
def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]):
def compile_fn(
self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
) -> CompiledFn:
"""
Implements graph splitting, first determining a set of buckets by counting
parameter sizes in reverse graph order, then invoking the user/backend compiler
@ -453,7 +469,7 @@ class DDPOptimizer:
"""
# 1: compute the partition map according to DDP bucket logic
buckets = [Bucket()] # (size, param_names)
processed_modules = set()
processed_modules: set[torch.nn.Module] = set()
for node in reversed(gm.graph.nodes):
if node.op in ("output", "placeholder"):
continue
@ -533,7 +549,9 @@ class DDPOptimizer:
partition_map[node] = idx
split_gm = fx.passes.split_module.split_module(
gm, None, lambda node: partition_map[node]
gm,
None, # type: ignore[arg-type]
lambda node: partition_map[node],
)
# See note [Assumption on Dynamo Metadata]
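
For context, a self-contained sketch of the split_module call pattern used here, with a trivial partitioner rather than DDPOptimizer's bucket logic:

import torch
from torch import fx
from torch.fx.passes.split_module import split_module

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x.sum()

gm = fx.symbolic_trace(Toy())
# split_module queries the callback for every non-placeholder/non-output node
partition_map = {n: 0 for n in gm.graph.nodes if n.op not in ("placeholder", "output")}
split_gm = split_module(gm, None, lambda node: partition_map[node])  # None mirrors the call above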

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module provides the TorchInductor backend integration for TorchDynamo.
@ -12,12 +10,14 @@ The inductor backend can be used with torch.compile():
model = torch.compile(model, backend="inductor")
"""
from typing import Any
from torch._dynamo import register_backend
from torch._dynamo.utils import dynamo_timed
@register_backend
def inductor(*args, **kwargs):
def inductor(*args: Any, **kwargs: Any) -> Any:
with dynamo_timed("inductor_import", log_pt2_compile_event=True):
# do import here to avoid loading inductor into memory when it is not used
# The AsyncCompile subproc pool can be slow to start, so warm it up as early

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
# This backend is maintained by ONNX team. To direct issues
# to the right people, please tag related GitHub issues with `module: onnx`.
#

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module implements TorchDynamo's backend registry system for managing compiler backends.
@ -65,7 +63,7 @@ import logging
import sys
from collections.abc import Sequence
from importlib.metadata import EntryPoint
from typing import Callable, Optional, Protocol
from typing import Any, Callable, Optional, Protocol, Union
import torch
from torch import fx
@ -88,7 +86,7 @@ def register_backend(
compiler_fn: Optional[CompilerFn] = None,
name: Optional[str] = None,
tags: Sequence[str] = (),
):
) -> Any:
"""
Decorator to add a given compiler to the registry to allow calling
`torch.compile` with string shorthand. Note: for projects not
@ -102,14 +100,14 @@ def register_backend(
"""
if compiler_fn is None:
# @register_backend(name="") syntax
return functools.partial(register_backend, name=name, tags=tags)
return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value]
assert callable(compiler_fn)
name = name or compiler_fn.__name__
assert name not in _COMPILER_FNS, f"duplicate name: {name}"
if compiler_fn not in _BACKENDS:
_BACKENDS[name] = None
_COMPILER_FNS[name] = compiler_fn
compiler_fn._tags = tuple(tags)
compiler_fn._tags = tuple(tags) # type: ignore[attr-defined]
return compiler_fn
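
Both decorator forms described in the docstring, in a hedged sketch (`my_compiler` and `my_tagged_compiler` are illustrative):

from torch._dynamo.backends.registry import lookup_backend, register_backend

@register_backend
def my_compiler(gm, example_inputs):
    return gm.forward

@register_backend(name="my_alias", tags=("experimental",))
def my_tagged_compiler(gm, example_inputs):
    return gm.forward

# torch.compile(model, backend="my_compiler") then resolves via lookup_backend("my_compiler")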
@ -119,7 +117,7 @@ register_experimental_backend = functools.partial(
)
def lookup_backend(compiler_fn):
def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn:
"""Expand backend strings to functions"""
if isinstance(compiler_fn, str):
if compiler_fn not in _BACKENDS:
@ -131,31 +129,32 @@ def lookup_backend(compiler_fn):
if compiler_fn not in _COMPILER_FNS:
entry_point = _BACKENDS[compiler_fn]
register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
if entry_point is not None:
register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
compiler_fn = _COMPILER_FNS[compiler_fn]
return compiler_fn
def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
def list_backends(exclude_tags: Sequence[str] = ("debug", "experimental")) -> list[str]:
"""
Return valid strings that can be passed to:
torch.compile(..., backend="name")
"""
_lazy_import()
exclude_tags = set(exclude_tags or ())
exclude_tags_set = set(exclude_tags or ())
backends = [
name
for name in _BACKENDS.keys()
if name not in _COMPILER_FNS
or not exclude_tags.intersection(_COMPILER_FNS[name]._tags)
or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined]
]
return sorted(backends)
@functools.cache
def _lazy_import():
def _lazy_import() -> None:
from .. import backends
from ..utils import import_submodule
@ -169,7 +168,7 @@ def _lazy_import():
@functools.cache
def _discover_entrypoint_backends():
def _discover_entrypoint_backends() -> None:
# importing here so it will pick up the mocked version in test_backends.py
from importlib.metadata import entry_points
@ -177,9 +176,9 @@ def _discover_entrypoint_backends():
if sys.version_info < (3, 10):
eps = entry_points()
eps = eps[group_name] if group_name in eps else []
eps = {ep.name: ep for ep in eps}
eps_dict = {ep.name: ep for ep in eps}
else:
eps = entry_points(group=group_name)
eps = {name: eps[name] for name in eps.names}
for backend_name in eps:
_BACKENDS[backend_name] = eps[backend_name]
eps_dict = {name: eps[name] for name in eps.names}
for backend_name in eps_dict:
_BACKENDS[backend_name] = eps_dict[backend_name]

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
# import torch # type: ignore[import]
# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import]
# from .registry import register_backend # type: ignore[import]

View File

@ -1,26 +1,33 @@
# mypy: ignore-errors
import logging
from typing import Any, Callable
import torch
from functorch.compile import make_boxed_func
from torch import fx
from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend
from .registry import CompiledFn, register_backend, register_experimental_backend
log = logging.getLogger(__name__)
@register_experimental_backend
def openxla_eval(model, fake_tensor_inputs):
def openxla_eval(
model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
) -> CompiledFn:
return xla_backend_helper(model, fake_tensor_inputs, boxed=False)
def openxla_eval_boxed(model, fake_tensor_inputs):
def openxla_eval_boxed(
model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
) -> Callable[..., Any]:
return xla_backend_helper(model, fake_tensor_inputs, boxed=True)
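
A hedged usage sketch, assuming torch_xla is installed and that `model` and its inputs live on an XLA device; the registered name is assumed to come from the function name above:

import torch

opt_model = torch.compile(model, backend="openxla_eval")  # experimental backend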
def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
def xla_backend_helper(
model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False
) -> Callable[..., Any]:
try:
import torch_xla.core.dynamo_bridge as bridge
except ImportError as e:
@ -30,7 +37,7 @@ def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
compiled_graph = None
def fwd(*args):
def fwd(*args: torch.Tensor) -> Any:
nonlocal model
nonlocal compiled_graph
if compiled_graph is None:

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module provides TVM backend integration for TorchDynamo.
@ -29,9 +27,10 @@ import os
import sys
import tempfile
from types import MappingProxyType
from typing import Optional
from typing import Any, Callable, Optional
import torch
from torch import fx
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend
@ -41,15 +40,16 @@ log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
@fake_tensor_unsupported # type: ignore[arg-type]
def tvm(
gm,
example_inputs,
gm: fx.GraphModule,
example_inputs: list[torch.Tensor],
*,
options: Optional[MappingProxyType] = MappingProxyType(
{"scheduler": None, "trials": 20000, "opt_level": 3}
),
):
options: Optional[MappingProxyType[str, Any]] = None,
) -> Callable[..., Any]:
if options is None:
options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3})
assert options is not None
import tvm # type: ignore[import]
from tvm import relay # type: ignore[import]
from tvm.contrib import graph_executor # type: ignore[import]
@ -147,7 +147,7 @@ def tvm(
)
m = graph_executor.GraphModule(lib["default"](dev))
def to_torch_tensor(nd_tensor):
def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor:
"""A helper function to transfer a NDArray to torch.tensor."""
if nd_tensor.dtype == "bool":
# DLPack does not support boolean so it can't be handled by
@ -156,7 +156,7 @@ def tvm(
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def to_tvm_tensor(torch_tensor):
def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array:
"""A helper function to transfer a torch.tensor to NDArray."""
if torch_tensor.dtype == torch.bool:
# same reason as above, fallback to numpy conversion which
@ -164,7 +164,7 @@ def tvm(
return tvm.nd.array(torch_tensor.cpu().numpy())
return tvm.nd.from_dlpack(torch_tensor)
def exec_tvm(*i_args):
def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]:
args = [a.contiguous() for a in i_args]
shape_info, _ = m.get_input_info()
active_inputs = {name for name, _ in shape_info.items()}
@ -193,7 +193,7 @@ tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
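
A hedged usage sketch, assuming a working TVM installation (has_tvm() below guards for it):

import torch

if has_tvm():
    model = torch.nn.Linear(4, 4)
    opt_model = torch.compile(model, backend="tvm")
    opt_model(torch.randn(2, 4))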
def has_tvm():
def has_tvm() -> bool:
try:
importlib.import_module("tvm")
return True
@ -202,7 +202,7 @@ def has_tvm():
@functools.cache
def llvm_target():
def llvm_target() -> str:
if sys.platform == "linux":
cpuinfo = open("/proc/cpuinfo").read()
if "avx512" in cpuinfo:

View File

@ -12,7 +12,7 @@ from torch.utils._ordered_set import OrderedSet
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Sequence, Set as AbstractSet
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@ -108,7 +108,8 @@ def format_default_skip_message(reason: str) -> str:
def get_mutation_stack_trace(
placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
placeholders: Sequence[PlaceholderInfo],
mutation_indices: Union[AbstractSet[int], Sequence[int]],
) -> str:
stack_trace: Optional[str] = ""

View File

@ -5,8 +5,8 @@ import logging
import operator
import types
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import ParamSpec, TypeAliasType, TypeVar
import torch
from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase
@ -46,7 +46,7 @@ BaseArgumentTypes = Union[
]
base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined]
Target = Union[Callable[..., Any], str]
Target = TypeAliasType("Target", Union[Callable[..., Any], str])
Argument = Optional[
Union[