Optimize dynamo typing (#147499)

Optimize type annotations for dynamo methods: defer annotation evaluation with "from __future__ import annotations", drop quoted forward references, and move annotation-only imports under TYPE_CHECKING.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147499
Approved by: https://github.com/anijain2305
zeshengzong
2025-08-25 13:20:40 +00:00
committed by PyTorch MergeBot
parent ab7787fb82
commit 510825e5fe
6 changed files with 72 additions and 52 deletions
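The diffs below all apply the same typing pattern: add "from __future__ import annotations" so annotations are evaluated lazily (PEP 563), drop the quotes around forward references, and move imports needed only for annotations under an "if TYPE_CHECKING:" guard. A minimal, hypothetical sketch of that pattern (the class and names are illustrative, not taken from this PR):

from __future__ import annotations  # annotations are no longer evaluated at runtime

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for the type checker; no import cost at runtime.
    from collections.abc import Sequence


class Tracker:  # illustrative class, not part of this PR
    parent: Optional[Tracker]  # forward reference works without quotes

    def __init__(self, parent: Optional[Tracker] = None) -> None:
        self.parent = parent

    @staticmethod
    def create(origins: Sequence[int]) -> Tracker:
        # Sequence appears only in the signature, so the guarded import
        # above is enough; nothing extra is imported at runtime.
        return Tracker()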

View File

@ -16,6 +16,8 @@ Key classes:
- BuckTargetWriter: Manages Buck build system integration
"""
+from __future__ import annotations
import atexit
import copy
import cProfile
@ -31,9 +33,8 @@ import sys
import tempfile
import textwrap
from collections import Counter
-from collections.abc import Sequence
from importlib import import_module
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
import torch
import torch._prims_common as utils
@ -42,15 +43,20 @@ from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._inductor.cpp_builder import normalize_path_separator
from torch._prims_common import is_float_dtype
-from torch.hub import tqdm
from torch.multiprocessing.reductions import StorageWeakRef
-from torch.storage import UntypedStorage
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
from . import config
from .utils import clone_inputs, get_debug_dir
+if TYPE_CHECKING:
+from collections.abc import Sequence
+from torch.hub import tqdm
+from torch.storage import UntypedStorage
log = logging.getLogger(__name__)
T = TypeVar("T")
@ -534,10 +540,10 @@ def backend_accuracy_fails(
def _stride_or_default(
-stride: Optional["torch._prims_common.StrideType"],
+stride: Optional[torch._prims_common.StrideType],
*,
-shape: "torch._prims_common.ShapeType",
-) -> "torch._prims_common.StrideType":
+shape: torch._prims_common.ShapeType,
+) -> torch._prims_common.StrideType:
return stride if stride is not None else utils.make_contiguous_strides_for(shape)
@ -561,7 +567,7 @@ class NopInputReader:
storage_hash: Optional[str],
nbytes: int,
*,
-device: Optional["torch._prims_common.DeviceLikeType"] = None,
+device: Optional[torch._prims_common.DeviceLikeType] = None,
dtype_hint: Optional[torch.dtype] = None,
) -> None:
self.total += 1
@ -592,7 +598,7 @@ class InputReader:
storage_hash: Optional[str],
nbytes: int,
*,
-device: Optional["torch._prims_common.DeviceLikeType"] = None,
+device: Optional[torch._prims_common.DeviceLikeType] = None,
dtype_hint: Optional[torch.dtype] = None,
) -> UntypedStorage:
if self.pbar is not None:
@ -619,8 +625,8 @@ class InputReader:
def tensor(
self,
storage: UntypedStorage,
-shape: "torch._prims_common.ShapeType",
-stride: Optional["torch._prims_common.StrideType"] = None,
+shape: torch._prims_common.ShapeType,
+stride: Optional[torch._prims_common.StrideType] = None,
*,
storage_offset: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
@ -698,7 +704,7 @@ class InputWriter:
self,
untyped_storage: UntypedStorage,
*,
-device_hint: Optional["torch._prims_common.DeviceLikeType"] = None,
+device_hint: Optional[torch._prims_common.DeviceLikeType] = None,
dtype_hint: Optional[torch.dtype] = None,
) -> str:
ws = StorageWeakRef(untyped_storage)
@ -841,9 +847,7 @@ def aot_graph_input_parser(
)
return sym_shapes_dict.get(symint, default_sym_shape) # type: ignore[return-value]
-def gen_tensor(
-shape: "torch._prims_common.ShapeType", dtype: torch.dtype
-) -> Tensor:
+def gen_tensor(shape: torch._prims_common.ShapeType, dtype: torch.dtype) -> Tensor:
# Resolve symbolic shapes to concrete values
resolved_shape = []
dynamic_dims = []

View File

@ -13,6 +13,8 @@ mappings between nodes and their duplicates, enabling efficient graph analysis a
optimization operations.
"""
+from __future__ import annotations
import copyreg
import io
import logging
@ -163,7 +165,7 @@ class BackwardBfsArgIter:
self._queue: deque[Optional[Node]] = deque()
@staticmethod
-def create(origin: Node) -> "BackwardBfsArgIter":
+def create(origin: Node) -> BackwardBfsArgIter:
it = BackwardBfsArgIter(origin)
it.add_children(origin)
# pop the origin node, since it is the origin of
@ -238,7 +240,7 @@ class GraphRegionTracker:
and n0 is not n1
)
-def track_node(self, tx: "InstructionTranslatorBase", node: Node) -> None:
+def track_node(self, tx: InstructionTranslatorBase, node: Node) -> None:
"""
The main entry point for tracking a node. This function will hash the node argument and group
nodes with the same hash together. It updates the hash_to_duplicates and node_to_duplicates dictionaries

View File

@ -13,12 +13,17 @@ The metrics system enables comprehensive monitoring and analysis of both compila
execution performance.
"""
+from __future__ import annotations
import heapq
import logging
import time
-from collections.abc import Iterator
-from typing import Any, Callable, Optional
-from typing_extensions import TypeAlias
+from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing_extensions import Self, TypeAlias
+if TYPE_CHECKING:
+from collections.abc import Iterator
from torch.utils._traceback import CapturedTraceback
@ -67,7 +72,7 @@ class MetricsContext:
self._level: int = 0
self._edits: list[tuple[CapturedTraceback, set[str]]] = []
-def __enter__(self) -> "MetricsContext":
+def __enter__(self) -> Self:
"""
Initialize metrics recording.
"""

View File

@ -12,6 +12,8 @@ The profiler helps measure and optimize the performance of Dynamo-compiled code
by tracking both captured and total operations, timing, and graph statistics.
"""
+from __future__ import annotations
import dataclasses
import os
from typing import Any
@ -35,7 +37,7 @@ class ProfileMetrics:
self.fusions += other.fusions
return self
-def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics":
+def __add__(self, other: ProfileMetrics) -> ProfileMetrics:
assert isinstance(other, ProfileMetrics)
return ProfileMetrics(
self.microseconds + other.microseconds,
@ -43,7 +45,7 @@ class ProfileMetrics:
self.fusions + other.fusions,
)
-def __truediv__(self, other: Any) -> "ProfileMetrics":
+def __truediv__(self, other: Any) -> ProfileMetrics:
if isinstance(other, int):
other = ProfileMetrics(other, other, other)
return ProfileMetrics(

View File

@ -17,6 +17,8 @@ This is primarily used by PyTorch developers and researchers to debug issues in
the Dynamo AOT compilation pipeline, particularly for the Inductor backend.
"""
+from __future__ import annotations
import argparse
import copy
import functools
@ -28,7 +30,6 @@ import subprocess
import sys
import textwrap
import uuid
-from collections.abc import Sequence
from importlib import import_module
from tempfile import TemporaryFile
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
@ -76,7 +77,6 @@ from torch._dynamo.utils import clone_inputs, counters, same
from torch._environment import is_fbcode
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
from torch._inductor.cpp_builder import normalize_path_separator
-from torch._inductor.output_code import OutputCode
from torch._library.fake_class_registry import FakeScriptObject
from torch._ops import OpOverload
from torch.fx.experimental.proxy_tensor import make_fx
@ -90,7 +90,10 @@ from .. import config
if TYPE_CHECKING:
+from collections.abc import Sequence
from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs
+from torch._inductor.output_code import OutputCode
from torch._inductor.utils import InputType
@ -106,9 +109,9 @@ use_buck = is_fbcode()
def wrap_compiler_debug(
-unconfigured_compiler_fn: "_CompileFxCallable",
+unconfigured_compiler_fn: _CompileFxCallable,
compiler_name: str,
-) -> "_CompileFxCallable":
+) -> _CompileFxCallable:
"""
Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
forward and backward call separately with the backend compiler_fn - like
@ -120,8 +123,8 @@ def wrap_compiler_debug(
@functools.wraps(unconfigured_compiler_fn)
def debug_wrapper(
gm: torch.fx.GraphModule,
-example_inputs: Sequence["InputType"],
-**kwargs: Unpack["_CompileFxKwargs"],
+example_inputs: Sequence[InputType],
+**kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
from torch._subclasses import FakeTensorMode
@ -161,7 +164,7 @@ def wrap_compiler_debug(
# We may run regular PyTorch compute that may trigger Dynamo, do NOT
# recursively attempt to accuracy minify in that case!
def deferred_for_real_inputs(
-real_inputs: Sequence["InputType"], **_kwargs: object
+real_inputs: Sequence[InputType], **_kwargs: object
) -> Any:
# This is a bit obscure: if we recursively try to accuracy minify
# the SAME function, this would trigger. But most of the time
@ -173,7 +176,7 @@ def wrap_compiler_debug(
with config.patch(repro_after=None):
return inner_debug_fn(real_inputs)
-def inner_debug_fn(real_inputs: Sequence["InputType"]) -> Any:
+def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any:
"""
Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
example_inputs can be fake tensors. We can call compiler_fn (which is

View File

@ -22,6 +22,8 @@ This is a core part of TorchDynamo's tracing system that enables ahead-of-time
optimization of PyTorch programs.
"""
+from __future__ import annotations
import collections
import collections.abc
import contextlib
@ -41,7 +43,6 @@ import threading
import traceback
import types
import weakref
-from collections.abc import Generator, Sequence
from traceback import StackSummary
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias, TypeIs
@ -52,7 +53,6 @@ import torch._logging
from torch._dynamo.exc import ObservedException, TensorifyScalarRestartAnalysis
from torch._guards import tracing, TracingContext
from torch._logging.structured import dump_file
-from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import guard_bool
from torch.utils._functools import cache_method
@ -177,6 +177,10 @@ from .variables.user_defined import (
if TYPE_CHECKING:
+from collections.abc import Generator, Sequence
+from torch._subclasses.fake_tensor import FakeTensorMode
from .package import CompilePackage
log = logging.getLogger(__name__)
@ -238,7 +242,7 @@ class SpeculationEntry:
restart_reason = "Unknown fail_and_restart_analysis"
raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
-def failed(self, tx: "InstructionTranslatorBase") -> bool:
+def failed(self, tx: InstructionTranslatorBase) -> bool:
if self._failed:
assert self.error_on_graph_break is not None
tx.error_on_graph_break = self.error_on_graph_break
@ -364,7 +368,7 @@ def _step_logger() -> Callable[..., None]:
@contextlib.contextmanager
def save_and_restart_speculation_log(
-tx: "InstructionTranslatorBase",
+tx: InstructionTranslatorBase,
) -> Generator[None, None, None]:
# When reconstructing a generator after a graph break, we advance it until
# it is fully exhausted. This process adds new entries to the speculation
@ -384,7 +388,7 @@ def save_and_restart_speculation_log(
@contextlib.contextmanager
def temporarely_allow_writes_to_output_graph(
-tx: "InstructionTranslatorBase",
+tx: InstructionTranslatorBase,
) -> Generator[None, None, None]:
try:
tmp = tx.output.should_exit
@ -420,7 +424,7 @@ class BlockStackEntry:
else:
return ReenterWith(self.stack_index - 1)
-def exit(self, tx: "InstructionTranslatorBase", is_graph_break: bool) -> None:
+def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None:
assert self.with_context is not None
if (
is_graph_break and self.with_context.exit_on_graph_break()
@ -448,7 +452,7 @@ def stack_op(fn: Callable[..., object]) -> Callable[..., Any]:
fn_var = BuiltinVariable(fn)
@functools.wraps(fn)
-def impl(self: "InstructionTranslator", inst: Instruction) -> None:
+def impl(self: InstructionTranslator, inst: Instruction) -> None:
self.push(fn_var.call_function(self, self.popn(nargs), {}))
return impl
@ -464,7 +468,7 @@ def is_stdlib(mod: object) -> bool:
def _detect_and_normalize_assert_statement(
-self: "InstructionTranslatorBase",
+self: InstructionTranslatorBase,
truth_fn: Callable[[object], bool],
push: bool,
) -> bool:
@ -615,7 +619,7 @@ def log_graph_break(
def generic_jump(
truth_fn: Callable[[object], bool], push: bool
-) -> Callable[["InstructionTranslatorBase", Instruction], None]:
+) -> Callable[[InstructionTranslatorBase, Instruction], None]:
# graph break message fields for data dependent branching
_gb_type = "Data-dependent branching"
_explanation = (
@ -628,7 +632,7 @@ def generic_jump(
]
def jump_graph_break(
-self: "InstructionTranslatorBase",
+self: InstructionTranslatorBase,
inst: Instruction,
value: VariableTracker,
extra_msg: str = "",
@ -679,7 +683,7 @@ def generic_jump(
jump_inst.copy_positions(inst)
self.output.add_output_instructions([jump_inst] + if_next + if_jump)
-def inner(self: "InstructionTranslatorBase", inst: Instruction) -> None:
+def inner(self: InstructionTranslatorBase, inst: Instruction) -> None:
value: VariableTracker = self.pop()
if (
config.rewrite_assert_with_torch_assert
@ -877,13 +881,13 @@ def generic_jump(
def break_graph_if_unsupported(
*, push: int
) -> Callable[
-[Callable[..., None]], Callable[["InstructionTranslatorBase", Instruction], None]
+[Callable[..., None]], Callable[[InstructionTranslatorBase, Instruction], None]
]:
def decorator(
inner_fn: Callable[..., None],
-) -> Callable[["InstructionTranslatorBase", Instruction], None]:
+) -> Callable[[InstructionTranslatorBase, Instruction], None]:
@functools.wraps(inner_fn)
-def wrapper(self: "InstructionTranslatorBase", inst: Instruction) -> None:
+def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None:
speculation = self.speculate()
if speculation.failed(self):
assert speculation.reason is not None
@ -933,7 +937,7 @@ def break_graph_if_unsupported(
speculation.fail_and_restart_analysis(self.error_on_graph_break)
def handle_graph_break(
-self: "InstructionTranslatorBase",
+self: InstructionTranslatorBase,
inst: Instruction,
reason: GraphCompileReason,
) -> None:
@ -1159,9 +1163,9 @@ class InstructionTranslatorBase(
strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
start_point: Optional[int]
is_leaf_tracer: bool
-parent: Optional["InstructionTranslatorBase"]
+parent: Optional[InstructionTranslatorBase]
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
-package: Optional["CompilePackage"]
+package: Optional[CompilePackage]
def mark_inconsistent_side_effects(self) -> None:
"""
@ -3298,7 +3302,7 @@ class InstructionTranslatorBase(
distributed_state: Optional[DistributedState],
# This determines whether to use the execution recorder.
closure: Optional[tuple[types.CellType]] = None,
-package: Optional["CompilePackage"] = None,
+package: Optional[CompilePackage] = None,
) -> None:
super().__init__()
self.speculation_log = speculation_log
@ -3398,7 +3402,7 @@ class InstructionTranslatorBase(
class InstructionTranslator(InstructionTranslatorBase):
@staticmethod
-def current_tx() -> "InstructionTranslator":
+def current_tx() -> InstructionTranslator:
return tls.current_tx
@contextlib.contextmanager
@ -3428,7 +3432,7 @@ class InstructionTranslator(InstructionTranslatorBase):
speculation_log: SpeculationLog,
exn_vt_stack: ExceptionStack,
distributed_state: Optional[DistributedState],
-package: Optional["CompilePackage"],
+package: Optional[CompilePackage],
) -> None:
_step_logger()(
logging.INFO,
@ -3886,7 +3890,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
func: VariableTracker,
args: list[VariableTracker],
kwargs: Any,
-) -> "InliningInstructionTranslator":
+) -> InliningInstructionTranslator:
assert isinstance(
func,
(