Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-12 14:54:55 +08:00

Compare commits: 8 commits on ciflow/tru... by lucaskabel

| SHA1 |
|---|
| 1207f9ab93 |
| fcfb6bab89 |
| 95bd114806 |
| ec68abdc38 |
| ee417d1806 |
| 4a8afeaffb |
| 90cba401a0 |
| 6e4c4d9e57 |
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides common utilities and base classes for TorchDynamo backends.

@@ -21,6 +19,9 @@ optimization of both forward and backward passes.
 import contextlib
 import functools
 import logging
+from collections.abc import Iterable
+from typing import Any, Callable
+from typing_extensions import ParamSpec, TypeVar
 from unittest.mock import patch

 import torch
@@ -36,13 +37,18 @@ from torch.utils._python_dispatch import _disable_current_modes

 log = logging.getLogger(__name__)

+P = ParamSpec("P")
+R = TypeVar("R")
+

 class AotAutograd:
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         self.__name__ = "compiler_fn"
         self.kwargs = kwargs

-    def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
+    def __call__(
+        self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any
+    ) -> Callable[..., Any]:
         if kwargs:
             log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs)

@@ -66,8 +72,8 @@ class AotAutograd:
             counters["aot_autograd"]["not_ok"] += 1
             return gm

-        def wrap_bw_compiler(bw_compiler_fn):
-            def _wrapped_bw_compiler(*args, **kwargs):
+        def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]:
+            def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R:
                 # Note [Wrapping bw_compiler in disable]
                 # The two disables here:
                 # - stop TorchDynamo from trying to compile the bw_compiler function itself
@@ -75,7 +81,7 @@ class AotAutograd:
                 return disable(
                     disable(
                         bw_compiler_fn, reason="do not trace backward compiler function"
-                    )(*args, **kwargs),
+                    )(*args, **kwargs),  # type: ignore[misc]
                     reason="do not trace generated backwards pass",
                 )

@@ -99,7 +105,9 @@ class AotAutograd:
         # debug asserts slow down compile time noticeably,
         # So only default them on when the aot_eager backend is used.
         if self.kwargs.get("fw_compiler", None) == nop:
-            patch_config = patch("functorch.compile.config.debug_assert", True)
+            patch_config: contextlib.AbstractContextManager[Any] = patch(
+                "functorch.compile.config.debug_assert", True
+            )
         else:
             patch_config = contextlib.nullcontext()

@@ -116,11 +124,11 @@ class AotAutograd:
             raise


-def aot_autograd(**kwargs) -> AotAutograd:
+def aot_autograd(**kwargs: Any) -> AotAutograd:
     return AotAutograd(**kwargs)


-def mem_efficient_fusion_kwargs(use_decomps):
+def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]:
     from functorch.compile import (
         default_decompositions,
         min_cut_rematerialization_partition,
@@ -140,28 +148,30 @@ def mem_efficient_fusion_kwargs(use_decomps):
     return kwargs


-def fake_tensor_unsupported(fn):
+def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any:
     """
     Decorator for backends that need real inputs. We swap out fake
     tensors for zero tensors.
     """

     @functools.wraps(fn)
-    def wrapper(model, inputs, **kwargs):
+    def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any:
         with _disable_current_modes():
             inputs = list(map(defake, inputs))
-            return fn(model, inputs, **kwargs)
+            return fn(model, inputs, **kwargs)  # type: ignore[call-arg]

     return wrapper


-def device_from_inputs(example_inputs) -> torch.device:
+def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device:
     for x in example_inputs:
         if hasattr(x, "device"):
             return x.device
+    return torch.device("cpu")  # Default fallback


-def dtype_from_inputs(example_inputs) -> torch.dtype:
+def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype:
     for x in example_inputs:
         if hasattr(x, "dtype"):
             return x.dtype
+    return torch.float32  # Default fallback
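The helpers above are easiest to read next to a usage sketch (not part of the change; `my_compiler` and `my_backend` are hypothetical names). `aot_autograd(...)` packages a forward/backward compiler pair into a Dynamo backend, and `device_from_inputs` reads the device off the example inputs:

```python
# Hypothetical usage sketch for the torch/_dynamo backend helpers shown above.
import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd, device_from_inputs


def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Simplest valid AOT compiler: report the device, then run the graph as-is.
    print("compiling graph for", device_from_inputs(example_inputs))
    return make_boxed_func(gm.forward)


my_backend = aot_autograd(fw_compiler=my_compiler, bw_compiler=my_compiler)
compiled = torch.compile(torch.nn.Linear(4, 4), backend=my_backend)
compiled(torch.randn(2, 4)).sum().backward()
```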
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements CUDA graphs support for TorchDynamo backends.

@@ -25,9 +23,11 @@ Key components:

 import functools
 from collections import defaultdict
-from typing import Optional
+from collections.abc import Sequence
+from typing import Any, Callable, Optional

 import torch
+import torch.fx
 from torch._dynamo import config
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.backends.debugging import boxed_nop
@@ -51,8 +51,8 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from .registry import register_backend


-def find_input_mutations(g):
-    def meta_fk(meta):
+def find_input_mutations(g: torch.fx.Graph) -> set[int]:
+    def meta_fk(meta: dict[str, Any]) -> Any:
         return meta["val"] if "val" in meta else meta["fake_result"]

     inputs = defaultdict(set)
@@ -90,7 +90,9 @@ def find_input_mutations(g):
     return mutated_inputs


-def get_device_node_mapping(gm: torch.fx.GraphModule):
+def get_device_node_mapping(
+    gm: torch.fx.GraphModule,
+) -> dict[torch.device, torch.fx.Node]:
     device_node_mapping: dict[torch.device, torch.fx.Node] = {}
     for n in gm.graph.nodes:
         t = n.meta.get("val", None)
@@ -100,7 +102,7 @@ def get_device_node_mapping(gm: torch.fx.GraphModule):


 def check_for_mutation_ignore_cuda_graph_managed_tensor(
-    aot_model: torch.fx.GraphModule, num_fixed
+    aot_model: torch.fx.GraphModule, num_fixed: int
 ) -> Optional[str]:
     mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
     if not mutation_indices:
@@ -110,7 +112,7 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor(
     return get_mutation_stack_trace(placeholders, mutation_indices)


-def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
     if not config.cudagraph_backend_support_input_mutation:
         if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
             aot_model, num_fixed
@@ -128,28 +130,35 @@ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
     return None


-def get_device_index(gm) -> int:
+def get_device_index(gm: torch.fx.GraphModule) -> int:
     device = next(iter(get_device_node_mapping(gm)))
     assert device.type == "cuda"
     return device.index


-def get_stack_traces(gm) -> list[Optional[str]]:
+def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
     output = output_node(gm)
     assert len(output.args) == 1
+    args = output.args[0]
+    if not hasattr(args, "__iter__"):
+        return []
     return [
         (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-        for arg in output.args[0]
+        for arg in args  # type: ignore[union-attr]
     ]


-def cudagraphs(dynamo_model, dynamo_inputs):
+def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
     from torch._inductor.cudagraph_trees import cudagraphify_impl

     do_cudagraphs = BoxedBool(True)
     boxed_device_index = BoxedDeviceIndex(None)

-    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
+    def forward_cudagraphs(
+        aot_model: torch.fx.GraphModule,
+        aot_inputs: list[Any],
+        is_inference: bool = False,
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
         if skip_msg := check_for_skip(aot_model, fixed):
@@ -166,15 +175,17 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             range(fixed),
             device_index=boxed_device_index.value,
             is_backward=False,
-            is_inference=False,
+            is_inference=False,  # Q: should forward is_inference here?
             stack_traces=get_stack_traces(aot_model),
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out

-    def backward_cudagraphs(aot_model, aot_inputs):
+    def backward_cudagraphs(
+        aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         if not do_cudagraphs:
             return aot_model
@@ -182,20 +193,23 @@ def cudagraphs(dynamo_model, dynamo_inputs):
         fixed = count_tangents(aot_model)
         if skip_msg := check_for_skip(aot_model, fixed):
             log_cudagraph_skip_and_bump_counter(
-                "skipping cudagraphs due to %s", skip_msg
+                f"skipping cudagraphs due to {skip_msg}"
             )

             # See [Backward Generation Handling]
+            device_idx = boxed_device_index.value
+            if device_idx is None:
+                device_idx = 0  # Default to device 0 if not set
             manager = torch._inductor.cudagraph_trees.get_manager(
-                boxed_device_index.value, create_if_none_exists=False
+                device_idx, create_if_none_exists=False
             )
             assert manager is not None

-            def fn(inputs):
+            def fn(inputs: list[Any]) -> Any:
                 manager.set_to_running_backward()
                 return aot_model(inputs)

-            fn._boxed_call = True
+            fn._boxed_call = True  # type: ignore[attr-defined]
             return fn

         out = cudagraphify_impl(
@@ -209,7 +223,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out

     aot_cudagraphs = aot_autograd(
@@ -225,13 +239,13 @@ class CudagraphsBackend:
     compiler_name = "cudagraphs"

     @staticmethod
-    def reset():
+    def reset() -> None:
         from torch._inductor.cudagraph_trees import reset_cudagraph_trees

         reset_cudagraph_trees()

     @staticmethod
-    def __call__(model, inputs):
+    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
         return cudagraphs(model, inputs)


@@ -240,7 +254,12 @@ class CudagraphsBackend:
 register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())


-def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
+def cudagraphs_inner(
+    model: Callable[..., Any],
+    inputs: Sequence[Any],
+    copy_outputs: bool = True,
+    copy_inputs: bool = True,
+) -> Callable[..., Sequence[Any]]:
     """This isn't registered as a backend, but is used in some benchmarks"""
     assert isinstance(inputs, (list, tuple))
     if copy_inputs:
@@ -265,7 +284,7 @@ def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
     if not isinstance(static_outputs, (list, tuple)):
         static_outputs = (static_outputs,)

-    def run(*new_inputs):
+    def run(*new_inputs: Any) -> Sequence[Any]:
         assert len(static_inputs) == len(new_inputs)
         if copy_inputs:
             for dst, src in zip(static_inputs, new_inputs):
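For reference, a sketch of how the backend registered above is selected (not part of the diff; it requires a CUDA device, and the model and shapes here are arbitrary):

```python
import torch

if torch.cuda.is_available():
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
    compiled = torch.compile(model, backend="cudagraphs")
    x = torch.randn(8, 64, device="cuda")
    # Early iterations warm up and record the CUDA graph; later ones replay it.
    for _ in range(3):
        compiled(x).sum().backward()
```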
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
 compilation and execution issues. It includes:
@@ -28,40 +26,56 @@ These backends are primarily used for:
 import dataclasses
 import functools
 import logging
+from collections.abc import Iterable
 from importlib import import_module
-from typing import Any, Optional
+from typing import Any, Callable, Optional, Union

 import torch
 from functorch.compile import min_cut_rematerialization_partition
 from torch import _guards
+from torch._dynamo.output_graph import GraphCompileReason
 from torch._functorch import config as functorch_config
 from torch._functorch.compilers import ts_compile
+from torch.fx.node import Target

 from .common import aot_autograd
-from .registry import register_debug_backend as register_backend
+from .registry import (
+    CompiledFn,
+    CompilerFn,
+    lookup_backend,
+    register_debug_backend as register_backend,
+)


 log = logging.getLogger(__name__)


 @register_backend
-def eager(gm, fake_tensor_inputs, **kwargs):
+def eager(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager backend ignoring extra kwargs %s", kwargs)
     return gm.forward


-def make_eager_backend_with_torch_function_mode(mode):
+def make_eager_backend_with_torch_function_mode(
+    mode: torch.overrides.TorchFunctionMode,
+) -> Callable[..., Any]:
     return make_eager_backend_with_torch_function_modes([mode])


-def make_eager_backend_with_torch_function_modes(modes):
+def make_eager_backend_with_torch_function_modes(
+    modes: Iterable[torch.overrides.TorchFunctionMode],
+) -> Callable[..., Any]:
     """Used to trace HOPs (cond and while) for eager execution, the metadata
     TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
     in the HOP, so we need to externally run this mode and not trace it."""
     from contextlib import ExitStack

-    def fn(gm, fake_tensor_inputs, **kwargs):
+    def fn(
+        gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+    ) -> Callable[..., Any]:
         stack = ExitStack()
         for mode in modes:
             stack.enter_context(mode)
@@ -74,13 +88,15 @@ def make_eager_backend_with_torch_function_modes(modes):


 @register_backend
-def eager_noexcept(gm, fake_tensor_inputs, **kwargs):
+def eager_noexcept(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs)

     # This backend is intended to check that dynamo-generated GraphModules
     # do not cause errors.
-    def inner(*args):
+    def inner(*args: Any) -> Any:
         try:
             return gm(*args)
         except Exception as e:
@@ -92,13 +108,15 @@ def eager_noexcept(gm, fake_tensor_inputs, **kwargs):


 @register_backend
-def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):
+def pre_dispatch_eager(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> torch.fx.GraphModule:
     if kwargs:
         log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs)

     from torch.fx.experimental.proxy_tensor import make_fx

-    def runnable_gm(*args):
+    def runnable_gm(*args: Any) -> Any:
         return torch.fx.Interpreter(gm).run(*args)

     pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
@@ -108,7 +126,9 @@ def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs):


 @register_backend
-def eager_debug(gm, fake_tensor_inputs, **kwargs):
+def eager_debug(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning("eager_debug backend ignoring extra kwargs %s", kwargs)

@@ -117,42 +137,55 @@ def eager_debug(gm, fake_tensor_inputs, **kwargs):
     # We could add more debugging bits here.
     # Right now, this backend can be used to check for and error on
     # custom dispatcher ops that have incorrect schemas.
-    def inner(*args):
+    def inner(*args: Any) -> Any:
         with SchemaCheckMode():
             return torch.fx.Interpreter(gm).run(*args)

     return inner


-@register_backend(name="ts")
-def torchscript(gm, fake_tensor_inputs):
+@register_backend(name="ts")  # type: ignore[misc]
+def torchscript(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
+) -> torch.jit.ScriptModule:
     return torch.jit.script(gm)


 # used boxed call to discard inputs when they are no longer needed
-def boxed_nop(fx_g, example_inputs):
-    def run(args):
+def boxed_nop(
+    fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         return torch.fx.Interpreter(fx_g).boxed_run(args)

-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run


-def boxed_nop_with_mode(fx_g, example_inputs, *, mode):
-    def run(args):
+def boxed_nop_with_mode(
+    fx_g: torch.fx.GraphModule,
+    example_inputs: list[torch.Tensor],
+    *,
+    mode: torch.overrides.TorchFunctionMode,
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         with mode:
             return torch.fx.Interpreter(fx_g).boxed_run(args)

-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run


-def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None):
-    def run(args):
+def fake_crossref_boxed_nop(
+    fx_g: torch.fx.GraphModule,
+    example_inputs: list[torch.Tensor],
+    ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
+) -> Callable[..., Any]:
+    def run(args: Any) -> Any:
         with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
             return torch.fx.Interpreter(fx_g).boxed_run(args)

-    run._boxed_call = True
+    run._boxed_call = True  # type: ignore[attr-defined]
     return run


@@ -160,7 +193,9 @@ def ignore_builtins(op: torch._ops.OpOverload) -> bool:
     return op.namespace in ("aten", "prims", "prim")


-def get_nop_func():
+def get_nop_func() -> Callable[
+    [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any]
+]:
     if not torch._functorch.config.fake_tensor_crossref:
         return boxed_nop
     elif torch._functorch.config.fake_tensor_crossref == "all":
@@ -173,12 +208,12 @@ def get_nop_func():
 # Useful for debugging purpose
 # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
 def aot_eager(
-    gm,
-    fake_tensor_inputs,
-    fw_compiler=None,
-    bw_compiler=None,
-    **kwargs,
-):
+    gm: torch.fx.GraphModule,
+    fake_tensor_inputs: list[torch.Tensor],
+    fw_compiler: Optional[Callable[..., Any]] = None,
+    bw_compiler: Optional[Callable[..., Any]] = None,
+    **kwargs: Any,
+) -> Callable[..., Any]:
     return aot_autograd(
         fw_compiler=fw_compiler or boxed_nop,
         bw_compiler=bw_compiler or boxed_nop,
@@ -201,7 +236,9 @@ register_backend(
 # inductor problems.
 # aot_eager_decomp_partition just replaces the inductor compiler with nop to help
 # isolate inductor vs aot_eager errors
-def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
+def aot_eager_decomp_partition(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     if kwargs:
         log.warning(
             "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs
@@ -213,7 +250,7 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs):
     if bisect_changes := CompilerBisector.get_config_change(
         "aot_eager_decomp_partition"
     ):
-        config_patches.update(bisect_changes)
+        config_patches.update(bisect_changes)  # type: ignore[arg-type]

     with functorch_config.patch(config_patches):
         return aot_autograd(
@@ -237,7 +274,12 @@ register_backend(

 # aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition,
 # except that it takes a TorchDispatchMode mode and run the fw/bw in the mode
-def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):
+def aot_eager_decomp_partition_with_mode(
+    gm: torch.fx.GraphModule,
+    fake_tensor_inputs: list[torch.Tensor],
+    mode: Any,
+    **kwarg: Any,
+) -> Callable[..., Any]:
     return aot_autograd(
         # these are taken from memory_efficient_fusion()
         fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode),
@@ -254,11 +296,13 @@ def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg):

 register_backend(
     name="aot_eager_decomp_partition_with_mode",
-    compiler_fn=aot_eager_decomp_partition_with_mode,
+    compiler_fn=aot_eager_decomp_partition_with_mode,  # type: ignore[arg-type]
 )


-def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs):
+def aot_eager_decomp_partition_crossref(
+    gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any
+) -> Callable[..., Any]:
     # if the config is set, respect it, otherwise only test custom_ops.
     # custom_op bad metas always manifest as an error whereas aten will only sometimes.
     # by default, use the less noisy option
@@ -296,7 +340,9 @@ class TestingOnlyCompileError(Exception):


 @register_backend
-def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_compile_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             raise ReluCompileError
@@ -304,7 +350,9 @@ def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):


 @register_backend
-def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_runtime_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             node.target = torch._assert
@@ -314,7 +362,9 @@ def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):


 @register_backend
-def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def relu_accuracy_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     for node in gm.graph.nodes:
         if node.target == torch.relu:
             node.target = torch.add
@@ -325,7 +375,9 @@ def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):


 @register_backend
-def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
+def non_leaf_compile_error_TESTING_ONLY(
+    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+) -> torch.fx.GraphModule:
     # Require at least one non-trivial thing in the graph,
     # see https://github.com/pytorch/pytorch/issues/102898
     for node in gm.graph.nodes:
@@ -349,11 +401,9 @@ class ExplainOutput:
     graphs: list[torch.fx.GraphModule]
     graph_count: int
     graph_break_count: int
-    break_reasons: list[
-        Any
-    ]  # Type is GraphCompileReason but doesn't matter for this purpose
+    break_reasons: list[GraphCompileReason]
     op_count: int
-    ops_per_graph: Optional[list[torch.fx.Node]] = None
+    ops_per_graph: Optional[list[Target]] = None
     out_guards: Optional[list[_guards.Guard]] = None
     compile_times: Optional[str] = None

@@ -374,7 +424,7 @@ class ExplainOutput:
            output += "Ops per Graph:\n"
            for idx, ops in enumerate(self.ops_per_graph):
                output += f" Ops {idx + 1}:\n"
-               for op in ops:
+               for op in ops:  # type: ignore[union-attr]
                    output += f" {op}\n"

        if self.out_guards is not None:
@@ -389,8 +439,18 @@ class ExplainOutput:


 def _explain_graph_detail(
-    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
-):
+    gm: torch.fx.GraphModule,
+    graphs: list[torch.fx.GraphModule],
+    op_count: int,
+    ops_per_graph: list[list[Target]],
+    break_reasons: list[GraphCompileReason],
+) -> tuple[
+    torch.fx.GraphModule,
+    list[torch.fx.GraphModule],
+    int,
+    list[list[Target]],
+    list[GraphCompileReason],
+]:
     """
     This function is a utility which processes a torch.fx.GraphModule and
     accumulates information about its ops, graph breaks, and other details. It
@@ -412,8 +472,8 @@ def _explain_graph_detail(
     ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
     op_count += len(ops)
     ops_per_graph.append(ops)
-    if gm.compile_subgraph_reason.graph_break:
-        break_reasons.append(gm.compile_subgraph_reason)
+    if gm.compile_subgraph_reason.graph_break:  # type: ignore[union-attr]
+        break_reasons.append(gm.compile_subgraph_reason)  # type: ignore[arg-type]

     return gm, graphs, op_count, ops_per_graph, break_reasons

@@ -443,17 +503,18 @@ class ExplainWithBackend:
         print(eb.output())
     """

-    def __init__(self, backend) -> None:
-        from .registry import lookup_backend
-
+    def __init__(self, backend: Union[CompilerFn, str]) -> None:
         self.backend = lookup_backend(backend)
-        self.graphs = []
+        self.graphs: list[torch.fx.GraphModule] = []
         self.op_count = 0
-        self.break_reasons = []
+        self.break_reasons: list[GraphCompileReason] = []

-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(
+        self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
+    ) -> CompiledFn:
+        ops_per_graph: list[list[Target]] = []
         gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
-            gm, self.graphs, self.op_count, [], self.break_reasons
+            gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons
         )
         return self.backend(gm, example_inputs)
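A sketch of exercising the debug backends typed above (not part of the diff): `ExplainWithBackend` accepts any backend name that `lookup_backend` resolves, and the forced graph break below is only there so `break_reasons` is non-empty:

```python
import torch
from torch._dynamo.backends.debugging import ExplainWithBackend


def fn(x):
    x = torch.relu(x)
    torch._dynamo.graph_break()  # illustration only: guarantees one graph break
    return x.sin()


eb = ExplainWithBackend("eager")
opt_fn = torch.compile(fn, backend=eb)
opt_fn(torch.randn(4))
print(eb.output())  # ExplainOutput: graph count, break reasons, op counts
```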
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements distributed training optimizations for TorchDynamo backends.

@@ -21,11 +19,12 @@ of compilation.
 import logging
 import traceback
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 from unittest import mock

 import torch
 from torch import fx
+from torch._dynamo.backends.registry import CompiledFn, CompilerFn
 from torch._dynamo.output_graph import GraphCompileReason
 from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
 from torch._logging import trace_structured
@@ -39,7 +38,7 @@ log = logging.getLogger(__name__)
 ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")


-def args_str(args):
+def args_str(args: Any) -> str:
     # a debug helper
     if torch.is_tensor(args):
         return f"T[{args.shape}]"
@@ -58,7 +57,7 @@ class Bucket:
     nodes: list[fx.Node] = field(default_factory=list)

     # param_ids is just used for unit testing
-    param_ids: list = field(default_factory=list)
+    param_ids: list[int] = field(default_factory=list)

     # keep track of any buckets that were extended for logging purposes
     opcount_increased_to_capture_external_output: int = 0
@@ -78,9 +77,9 @@ def bucket_has_external_output(bucket: Bucket) -> bool:
     return False


-def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
+def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
     headers = ("Index", "Size (b)", "Param Names")
-    rows = []
+    rows: list[tuple[Optional[int], Optional[int], str]] = []
     extended_buckets = []
     for idx, bucket in enumerate(reversed(buckets)):
         if len(bucket.params) > 0:
@@ -136,7 +135,7 @@ def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int):
     log.debug("DDPOptimizer captured no parameters and did not split this graph.")


-def has_higher_order_op(gm):
+def has_higher_order_op(gm: fx.GraphModule) -> bool:
     # Check if there is a higher order op in the graph
     for node in gm.graph.nodes:
         if node.op == "get_attr":
@@ -146,7 +145,7 @@ def has_higher_order_op(gm):
     return False


-def propagate_metadata(orig_gm, split_gm) -> None:
+def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
     for name, module in split_gm.named_modules():
         if "." not in name and len(name):
             # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
@@ -154,7 +153,7 @@ def propagate_metadata(orig_gm, split_gm) -> None:
             module._param_name_to_source = orig_gm._param_name_to_source


-def propagate_dynamo_source(orig_gm, split_gm) -> None:
+def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
     name_to_dynamo_source = {}
     for node in orig_gm.graph.find_nodes(op="placeholder"):
         name_to_dynamo_source[node.name] = node._dynamo_source
@@ -168,12 +167,19 @@ def propagate_dynamo_source(orig_gm, split_gm) -> None:

 # compile each of the partitioned submodules using the user-provided compiler
 class SubmodCompiler(torch.fx.interpreter.Interpreter):
-    def __init__(self, module, compiler, fake_mode) -> None:
+    def __init__(
+        self,
+        module: fx.GraphModule,
+        compiler: CompilerFn,
+        fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
+    ) -> None:
         super().__init__(module)
         self.compiler = compiler
         self.fake_mode = fake_mode

-    def compile_submod(self, input_mod, args, kwargs):
+    def compile_submod(
+        self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
+    ) -> Any:
         """
         Compile the submodule,
         using a wrapper to make sure its output is always a tuple,
@@ -182,12 +188,14 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
         assert len(kwargs) == 0, "We assume only args for these modules"

         class WrapperModule(torch.nn.Module):
-            def __init__(self, submod, unwrap_singleton_tuple) -> None:
+            def __init__(
+                self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
+            ) -> None:
                 super().__init__()
                 self.submod = submod
                 self.unwrap_singleton_tuple = unwrap_singleton_tuple

-            def forward(self, *args):
+            def forward(self, *args: Any) -> Any:
                 x = self.submod(*args)
                 # TODO(whc)
                 # for some reason the isinstance check is necessary if I split one node per submod
@@ -205,12 +213,12 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
                 sn.args = (sn.args,)

         input_mod.recompile()
-        input_mod.compile_subgraph_reason = GraphCompileReason(
+        input_mod.compile_subgraph_reason = GraphCompileReason(  # type: ignore[assignment]
             "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
             " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
             [
                 # it's close to useless to get a real stacktrace here, and quite verbose.
-                traceback.FrameSummary(__file__, 0, DDPOptimizer),
+                traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
             ],
         )

@@ -257,7 +265,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
         assert isinstance(kwargs, dict)

         if n.op == "call_module":
-            real_mod = self.fetch_attr(n.target)
+            real_mod = self.fetch_attr(str(n.target))
             if self.fake_mode:
                 curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
             else:
@@ -287,10 +295,10 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
                 def __init__(self) -> None:
                     self.tc = torch._guards.TracingContext.try_get()
                     assert self.tc
-                    torch._guards.TracingContext.try_get().fakify_first_call = True
+                    self.tc.fakify_first_call = True

                 def __del__(self) -> None:
-                    self.tc.fakify_first_call = False
+                    self.tc.fakify_first_call = False  # type: ignore[union-attr]

             # For aot_eager and other backends, tracing context is not set
             has_tracing_context = torch._guards.TracingContext.try_get() is not None
@@ -308,9 +316,9 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):

             # We update the original (outer) graph with a call into the compiled module
             # instead of the uncompiled one.
-            self.module.delete_submodule(n.target)
-            n.target = "compiled_" + n.target
-            self.module.add_submodule(n.target, compiled_submod_real)
+            self.module.delete_submodule(n.target)  # type: ignore[operator]
+            n.target = "compiled_" + n.target  # type: ignore[operator]
+            self.module.add_submodule(n.target, compiled_submod_real)  # type: ignore[operator]

             # Finally, we have to produce inputs for use compiling the next submodule,
             # and these need to be FakeTensors, so we execute the module under fake_mode
@@ -398,7 +406,7 @@ class DDPOptimizer:
     def __init__(
         self,
         bucket_bytes_cap: int,
-        backend_compile_fn,
+        backend_compile_fn: CompilerFn,
         first_bucket_cap: Optional[int] = None,
     ) -> None:
         if first_bucket_cap is not None:
@@ -416,21 +424,27 @@ class DDPOptimizer:

         self.backend_compile_fn = backend_compile_fn

-    def _ignore_parameter(self, parameter):
+    def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
         return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored

-    def add_param(self, bucket, param, name):
+    def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
         bucket.size += param.untyped_storage().nbytes()
         bucket.params.append(name)
         bucket.param_ids.append(id(param))

-    def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
+    def add_module_params_to_bucket(
+        self,
+        mod: torch.nn.Module,
+        bucket: Bucket,
+        processed_modules: set[torch.nn.Module],
+        prefix: str,
+    ) -> None:
         processed_modules.add(mod)
         for name, param in mod.named_parameters():
             if param.requires_grad and not self._ignore_parameter(param):
                 self.add_param(bucket, param, f"{prefix}_{name}")

-    def add_param_args(self, bucket, node):
+    def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
         for arg in node.args:
             if not isinstance(arg, torch.fx.node.Node):
                 continue
@@ -442,9 +456,11 @@ class DDPOptimizer:
                 and param.requires_grad
                 and not self._ignore_parameter(param)
             ):
-                self.add_param(bucket, param, arg.target)
+                self.add_param(bucket, param, str(arg.target))

-    def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]):
+    def compile_fn(
+        self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
+    ) -> CompiledFn:
         """
         Implements graph splitting, first determining a set of of buckets by counting
         parameter sizes in reverse graph order, then invoking the user/backend compiler
@@ -453,7 +469,7 @@ class DDPOptimizer:
         """
         # 1: compute the partition map according to DDP bucket logic
         buckets = [Bucket()]  # (size, param_names)
-        processed_modules = set()
+        processed_modules: set[torch.nn.Module] = set()
         for node in reversed(gm.graph.nodes):
             if node.op in ("output", "placeholder"):
                 continue
@@ -533,7 +549,9 @@ class DDPOptimizer:
             partition_map[node] = idx

         split_gm = fx.passes.split_module.split_module(
-            gm, None, lambda node: partition_map[node]
+            gm,
+            None,  # type: ignore[arg-type]
+            lambda node: partition_map[node],
         )

         # See note [Assumption on Dynamo Metadata]
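For orientation, a sketch of where DDPOptimizer sits (not part of the diff, and not runnable without the usual distributed launcher and process-group setup): it only kicks in when a DistributedDataParallel-wrapped model is compiled with `optimize_ddp` enabled, so allreduce for each bucket can overlap with backward compute:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# torch.distributed.init_process_group("nccl")  # assumed to be done by the launcher
model = torch.nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model, bucket_cap_mb=25)

torch._dynamo.config.optimize_ddp = True  # enables the graph-splitting path above
compiled = torch.compile(ddp_model, backend="inductor")
```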
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides the TorchInductor backend integration for TorchDynamo.

@@ -12,12 +10,14 @@ The inductor backend can be used with torch.compile():
     model = torch.compile(model, backend="inductor")
 """

+from typing import Any
+
 from torch._dynamo import register_backend
 from torch._dynamo.utils import dynamo_timed


 @register_backend
-def inductor(*args, **kwargs):
+def inductor(*args: Any, **kwargs: Any) -> Any:
     with dynamo_timed("inductor_import", log_pt2_compile_event=True):
         # do import here to avoid loading inductor into memory when it is not used
         # The AsyncCompile subproc pool can be slow to start, so warm it up as early
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 # This backend is maintained by ONNX team. To direct issues
 # to the right people, please tag related GitHub issues with `module: onnx`.
 #
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements TorchDynamo's backend registry system for managing compiler backends.

@@ -65,7 +63,7 @@ import logging
 import sys
 from collections.abc import Sequence
 from importlib.metadata import EntryPoint
-from typing import Callable, Optional, Protocol
+from typing import Any, Callable, Optional, Protocol, Union

 import torch
 from torch import fx
@@ -88,7 +86,7 @@ def register_backend(
     compiler_fn: Optional[CompilerFn] = None,
     name: Optional[str] = None,
     tags: Sequence[str] = (),
-):
+) -> Any:
     """
     Decorator to add a given compiler to the registry to allow calling
     `torch.compile` with string shorthand. Note: for projects not
@@ -102,14 +100,14 @@
     """
     if compiler_fn is None:
         # @register_backend(name="") syntax
-        return functools.partial(register_backend, name=name, tags=tags)
+        return functools.partial(register_backend, name=name, tags=tags)  # type: ignore[return-value]
     assert callable(compiler_fn)
     name = name or compiler_fn.__name__
     assert name not in _COMPILER_FNS, f"duplicate name: {name}"
     if compiler_fn not in _BACKENDS:
         _BACKENDS[name] = None
     _COMPILER_FNS[name] = compiler_fn
-    compiler_fn._tags = tuple(tags)
+    compiler_fn._tags = tuple(tags)  # type: ignore[attr-defined]
     return compiler_fn


@@ -119,7 +117,7 @@ register_experimental_backend = functools.partial(
 )


-def lookup_backend(compiler_fn):
+def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn:
     """Expand backend strings to functions"""
     if isinstance(compiler_fn, str):
         if compiler_fn not in _BACKENDS:
@@ -131,31 +129,32 @@ def lookup_backend(compiler_fn):

     if compiler_fn not in _COMPILER_FNS:
         entry_point = _BACKENDS[compiler_fn]
-        register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
+        if entry_point is not None:
+            register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
     compiler_fn = _COMPILER_FNS[compiler_fn]
     return compiler_fn


-def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
+def list_backends(exclude_tags: Sequence[str] = ("debug", "experimental")) -> list[str]:
     """
     Return valid strings that can be passed to:

         torch.compile(..., backend="name")
     """
     _lazy_import()
-    exclude_tags = set(exclude_tags or ())
+    exclude_tags_set = set(exclude_tags or ())

     backends = [
         name
         for name in _BACKENDS.keys()
         if name not in _COMPILER_FNS
-        or not exclude_tags.intersection(_COMPILER_FNS[name]._tags)
+        or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags)  # type: ignore[attr-defined]
     ]
     return sorted(backends)


 @functools.cache
-def _lazy_import():
+def _lazy_import() -> None:
     from .. import backends
     from ..utils import import_submodule

@@ -169,7 +168,7 @@ def _lazy_import():


 @functools.cache
-def _discover_entrypoint_backends():
+def _discover_entrypoint_backends() -> None:
     # importing here so it will pick up the mocked version in test_backends.py
     from importlib.metadata import entry_points

@@ -177,9 +176,9 @@ def _discover_entrypoint_backends():
     if sys.version_info < (3, 10):
         eps = entry_points()
         eps = eps[group_name] if group_name in eps else []
-        eps = {ep.name: ep for ep in eps}
+        eps_dict = {ep.name: ep for ep in eps}
     else:
         eps = entry_points(group=group_name)
-        eps = {name: eps[name] for name in eps.names}
-    for backend_name in eps:
-        _BACKENDS[backend_name] = eps[backend_name]
+        eps_dict = {name: eps[name] for name in eps.names}
+    for backend_name in eps_dict:
+        _BACKENDS[backend_name] = eps_dict[backend_name]
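A sketch of the registry API in action (not part of the diff; the backend name `my_nop` is hypothetical): `register_backend` stores the compiler under a string name so `torch.compile` can resolve it, and `list_backends()` reflects it afterwards:

```python
import torch
from torch._dynamo import list_backends, register_backend


@register_backend
def my_nop(gm: torch.fx.GraphModule, example_inputs):
    # No real compilation: hand back the captured graph's forward.
    return gm.forward


assert "my_nop" in list_backends(exclude_tags=())
compiled = torch.compile(lambda x: x + 1, backend="my_nop")
print(compiled(torch.ones(3)))
```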
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 # import torch  # type: ignore[import]
 # from .common import device_from_inputs, fake_tensor_unsupported  # type: ignore[import]
 # from .registry import register_backend  # type: ignore[import]
@@ -1,26 +1,33 @@
-# mypy: ignore-errors
-
 import logging
+from typing import Any, Callable

 import torch
 from functorch.compile import make_boxed_func
+from torch import fx

 from ..backends.common import aot_autograd
-from .registry import register_backend, register_experimental_backend
+from .registry import CompiledFn, register_backend, register_experimental_backend


 log = logging.getLogger(__name__)


 @register_experimental_backend
-def openxla_eval(model, fake_tensor_inputs):
+def openxla_eval(
+    model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
+) -> CompiledFn:
     return xla_backend_helper(model, fake_tensor_inputs, boxed=False)


-def openxla_eval_boxed(model, fake_tensor_inputs):
+def openxla_eval_boxed(
+    model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
+) -> Callable[..., Any]:
     return xla_backend_helper(model, fake_tensor_inputs, boxed=True)


-def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
+def xla_backend_helper(
+    model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False
+) -> Callable[..., Any]:
     try:
         import torch_xla.core.dynamo_bridge as bridge
     except ImportError as e:
@@ -30,7 +37,7 @@ def xla_backend_helper(model, fake_tensor_inputs, boxed=False):

     compiled_graph = None

-    def fwd(*args):
+    def fwd(*args: torch.Tensor) -> Any:
         nonlocal model
         nonlocal compiled_graph
         if compiled_graph is None:
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides TVM backend integration for TorchDynamo.

@@ -29,9 +27,10 @@ import os
 import sys
 import tempfile
 from types import MappingProxyType
-from typing import Optional
+from typing import Any, Callable, Optional

 import torch
+from torch import fx

 from .common import device_from_inputs, fake_tensor_unsupported
 from .registry import register_backend
@@ -41,15 +40,16 @@ log = logging.getLogger(__name__)


 @register_backend
-@fake_tensor_unsupported
+@fake_tensor_unsupported  # type: ignore[arg-type]
 def tvm(
-    gm,
-    example_inputs,
+    gm: fx.GraphModule,
+    example_inputs: list[torch.Tensor],
     *,
-    options: Optional[MappingProxyType] = MappingProxyType(
-        {"scheduler": None, "trials": 20000, "opt_level": 3}
-    ),
-):
+    options: Optional[MappingProxyType[str, Any]] = None,
+) -> Callable[..., Any]:
+    if options is None:
+        options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3})
+    assert options is not None
     import tvm  # type: ignore[import]
     from tvm import relay  # type: ignore[import]
     from tvm.contrib import graph_executor  # type: ignore[import]
@@ -147,7 +147,7 @@ def tvm(
     )
     m = graph_executor.GraphModule(lib["default"](dev))

-    def to_torch_tensor(nd_tensor):
+    def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor:
         """A helper function to transfer a NDArray to torch.tensor."""
         if nd_tensor.dtype == "bool":
             # DLPack does not support boolean so it can't be handled by
@@ -156,7 +156,7 @@ def tvm(
             return torch.from_numpy(nd_tensor.numpy())
         return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

-    def to_tvm_tensor(torch_tensor):
+    def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array:
         """A helper function to transfer a torch.tensor to NDArray."""
         if torch_tensor.dtype == torch.bool:
             # same reason as above, fallback to numpy conversion which
@@ -164,7 +164,7 @@ def tvm(
             return tvm.nd.array(torch_tensor.cpu().numpy())
         return tvm.nd.from_dlpack(torch_tensor)

-    def exec_tvm(*i_args):
+    def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]:
         args = [a.contiguous() for a in i_args]
         shape_info, _ = m.get_input_info()
         active_inputs = {name for name, _ in shape_info.items()}
@@ -193,7 +193,7 @@ tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
 tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")


-def has_tvm():
+def has_tvm() -> bool:
     try:
         importlib.import_module("tvm")
         return True
@@ -202,7 +202,7 @@ def has_tvm():


 @functools.cache
-def llvm_target():
+def llvm_target() -> str:
     if sys.platform == "linux":
         cpuinfo = open("/proc/cpuinfo").read()
         if "avx512" in cpuinfo:
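A sketch of selecting this backend by name (not part of the diff; it assumes a local TVM installation, which `has_tvm()` checks for, and uses the default scheduler/trials/opt_level from the `options` mapping above):

```python
import torch
from torch._dynamo.backends.tvm import has_tvm

if has_tvm():
    model = torch.nn.Linear(16, 16).eval()
    compiled = torch.compile(model, backend="tvm")
    with torch.no_grad():
        print(compiled(torch.randn(2, 16)).shape)
```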
@@ -12,7 +12,7 @@ from torch.utils._ordered_set import OrderedSet


 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Sequence, Set as AbstractSet


 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -108,7 +108,8 @@ def format_default_skip_message(reason: str) -> str:


 def get_mutation_stack_trace(
-    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
+    placeholders: Sequence[PlaceholderInfo],
+    mutation_indices: Union[AbstractSet[int], Sequence[int]],
 ) -> str:
     stack_trace: Optional[str] = ""

@@ -5,8 +5,8 @@ import logging
 import operator
 import types
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing_extensions import ParamSpec, TypeAliasType, TypeVar

 import torch
 from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase
@@ -46,7 +46,7 @@ BaseArgumentTypes = Union[
 ]
 base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

-Target = Union[Callable[..., Any], str]
+Target = TypeAliasType("Target", Union[Callable[..., Any], str])

 Argument = Optional[
     Union[
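A sketch of what the `Target` alias covers in practice (not part of the diff): an fx Node's `.target` is a callable for call_function nodes and a string for placeholder, call_method, call_module and get_attr nodes, which is why fields such as `ExplainOutput.ops_per_graph` above are typed with `Target`:

```python
import torch
from torch.fx.node import Target


def describe(target: Target) -> str:
    # Strings name methods/attributes/inputs; callables are the actual ops.
    return target if isinstance(target, str) else getattr(target, "__name__", repr(target))


gm = torch.fx.symbolic_trace(lambda x: torch.relu(x).sum())
print([(n.op, describe(n.target)) for n in gm.graph.nodes])
```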