Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Optimize dynamo typing (#147499)
Optimize dynamo method type annotations. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147499 Approved by: https://github.com/anijain2305
committed by PyTorch MergeBot
parent ab7787fb82
commit 510825e5fe
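The change applies the same typing cleanup across several torch/_dynamo modules: enable postponed evaluation of annotations (from __future__ import annotations), move imports that are needed only for annotations under an if TYPE_CHECKING: block, and drop the string quotes from forward references. A minimal sketch of the resulting shape; the function name here is illustrative, not from the PR:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type checkers; no runtime import cost.
    from torch.storage import UntypedStorage


def read_storage(name: str) -> UntypedStorage:  # unquoted forward reference
    ...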
@@ -16,6 +16,8 @@ Key classes:
 - BuckTargetWriter: Manages Buck build system integration
 """
 
+from __future__ import annotations
+
 import atexit
 import copy
 import cProfile
@@ -31,9 +33,8 @@ import sys
 import tempfile
 import textwrap
 from collections import Counter
-from collections.abc import Sequence
 from importlib import import_module
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
 
 import torch
 import torch._prims_common as utils
@@ -42,15 +43,20 @@ from torch import Tensor
 from torch._dynamo.testing import rand_strided
 from torch._inductor.cpp_builder import normalize_path_separator
 from torch._prims_common import is_float_dtype
-from torch.hub import tqdm
 from torch.multiprocessing.reductions import StorageWeakRef
-from torch.storage import UntypedStorage
 from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
 
 from . import config
 from .utils import clone_inputs, get_debug_dir
 
 
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from torch.hub import tqdm
+    from torch.storage import UntypedStorage
+
+
 log = logging.getLogger(__name__)
 
 T = TypeVar("T")
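The hunk above moves tqdm, UntypedStorage, and Sequence behind the TYPE_CHECKING guard, which is only safe for names that are never evaluated at runtime; anything the module actually calls must remain a real import. A stand-alone sketch of that split, with illustrative names rather than the module's real code:

from __future__ import annotations

from collections import Counter  # evaluated at runtime below, so it must stay
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Sequence  # annotation-only, free at runtime


def count_tokens(tokens: Sequence[str]) -> Counter[str]:
    # Counter is called here; Sequence never is, because the annotations
    # above are just strings until a type checker evaluates them.
    return Counter(tokens)


print(count_tokens(["a", "b", "a"]))  # Counter({'a': 2, 'b': 1})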
@@ -534,10 +540,10 @@ def backend_accuracy_fails(
 
 
 def _stride_or_default(
-    stride: Optional["torch._prims_common.StrideType"],
+    stride: Optional[torch._prims_common.StrideType],
     *,
-    shape: "torch._prims_common.ShapeType",
-) -> "torch._prims_common.StrideType":
+    shape: torch._prims_common.ShapeType,
+) -> torch._prims_common.StrideType:
     return stride if stride is not None else utils.make_contiguous_strides_for(shape)
 
 
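For context, _stride_or_default falls back to contiguous (row-major) strides. A pure-Python stand-in, with the helper and alias types simplified and only illustrative of what utils.make_contiguous_strides_for computes:

from __future__ import annotations

from typing import Optional


def make_contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    # Row-major strides: the last dimension varies fastest.
    strides: list[int] = []
    acc = 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= max(dim, 1)
    return tuple(reversed(strides))


def stride_or_default(
    stride: Optional[tuple[int, ...]], *, shape: tuple[int, ...]
) -> tuple[int, ...]:
    # Same shape as _stride_or_default above, minus the torch type aliases.
    return stride if stride is not None else make_contiguous_strides(shape)


print(stride_or_default(None, shape=(2, 3, 4)))  # (12, 4, 1)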
@@ -561,7 +567,7 @@ class NopInputReader:
         storage_hash: Optional[str],
         nbytes: int,
         *,
-        device: Optional["torch._prims_common.DeviceLikeType"] = None,
+        device: Optional[torch._prims_common.DeviceLikeType] = None,
         dtype_hint: Optional[torch.dtype] = None,
     ) -> None:
         self.total += 1
@@ -592,7 +598,7 @@ class InputReader:
         storage_hash: Optional[str],
         nbytes: int,
         *,
-        device: Optional["torch._prims_common.DeviceLikeType"] = None,
+        device: Optional[torch._prims_common.DeviceLikeType] = None,
         dtype_hint: Optional[torch.dtype] = None,
     ) -> UntypedStorage:
         if self.pbar is not None:
@@ -619,8 +625,8 @@ class InputReader:
     def tensor(
         self,
         storage: UntypedStorage,
-        shape: "torch._prims_common.ShapeType",
-        stride: Optional["torch._prims_common.StrideType"] = None,
+        shape: torch._prims_common.ShapeType,
+        stride: Optional[torch._prims_common.StrideType] = None,
         *,
         storage_offset: Optional[int] = None,
         dtype: Optional[torch.dtype] = None,
@@ -698,7 +704,7 @@ class InputWriter:
         self,
         untyped_storage: UntypedStorage,
         *,
-        device_hint: Optional["torch._prims_common.DeviceLikeType"] = None,
+        device_hint: Optional[torch._prims_common.DeviceLikeType] = None,
         dtype_hint: Optional[torch.dtype] = None,
     ) -> str:
         ws = StorageWeakRef(untyped_storage)
@@ -841,9 +847,7 @@ def aot_graph_input_parser(
         )
         return sym_shapes_dict.get(symint, default_sym_shape)  # type: ignore[return-value]
 
-    def gen_tensor(
-        shape: "torch._prims_common.ShapeType", dtype: torch.dtype
-    ) -> Tensor:
+    def gen_tensor(shape: torch._prims_common.ShapeType, dtype: torch.dtype) -> Tensor:
         # Resolve symbolic shapes to concrete values
         resolved_shape = []
         dynamic_dims = []
@@ -13,6 +13,8 @@ mappings between nodes and their duplicates, enabling efficient graph analysis a
 optimization operations.
 """
 
+from __future__ import annotations
+
 import copyreg
 import io
 import logging
@@ -163,7 +165,7 @@ class BackwardBfsArgIter:
         self._queue: deque[Optional[Node]] = deque()
 
     @staticmethod
-    def create(origin: Node) -> "BackwardBfsArgIter":
+    def create(origin: Node) -> BackwardBfsArgIter:
         it = BackwardBfsArgIter(origin)
         it.add_children(origin)
         # pop the origin node, since it is the origin of
@@ -238,7 +240,7 @@ class GraphRegionTracker:
             and n0 is not n1
         )
 
-    def track_node(self, tx: "InstructionTranslatorBase", node: Node) -> None:
+    def track_node(self, tx: InstructionTranslatorBase, node: Node) -> None:
         """
         The main entry point for tracking a node. This function will hash the node argument and group
         nodes with the same hash together. It updates the hash_to_duplicates and node_to_duplicates dictionaries
@@ -13,12 +13,17 @@ The metrics system enables comprehensive monitoring and analysis of both compila
 execution performance.
 """
 
+from __future__ import annotations
+
 import heapq
 import logging
 import time
-from collections.abc import Iterator
-from typing import Any, Callable, Optional
-from typing_extensions import TypeAlias
+from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing_extensions import Self, TypeAlias
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
 
-from torch.utils._traceback import CapturedTraceback
+    from torch.utils._traceback import CapturedTraceback
 
@@ -67,7 +72,7 @@ class MetricsContext:
         self._level: int = 0
         self._edits: list[tuple[CapturedTraceback, set[str]]] = []
 
-    def __enter__(self) -> "MetricsContext":
+    def __enter__(self) -> Self:
         """
         Initialize metrics recording.
         """
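The MetricsContext hunk above swaps the hard-coded "MetricsContext" return annotation on __enter__ for Self, imported from typing_extensions in the same hunk. A sketch of why that matters for subclasses; the classes here are made up, not the real MetricsContext:

from __future__ import annotations

from typing_extensions import Self  # typing.Self on Python 3.11+


class BaseContext:
    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc: object) -> None:
        pass


class TimerContext(BaseContext):
    def elapsed(self) -> float:
        return 0.0


with TimerContext() as ctx:
    # Type checkers see TimerContext here; a literal "BaseContext" return
    # annotation on __enter__ would have lost the subclass type.
    ctx.elapsed()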
@@ -12,6 +12,8 @@ The profiler helps measure and optimize the performance of Dynamo-compiled code
 by tracking both captured and total operations, timing, and graph statistics.
 """
 
+from __future__ import annotations
+
 import dataclasses
 import os
 from typing import Any
@@ -35,7 +37,7 @@ class ProfileMetrics:
         self.fusions += other.fusions
         return self
 
-    def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics":
+    def __add__(self, other: ProfileMetrics) -> ProfileMetrics:
         assert isinstance(other, ProfileMetrics)
         return ProfileMetrics(
             self.microseconds + other.microseconds,
@@ -43,7 +45,7 @@ class ProfileMetrics:
             self.fusions + other.fusions,
         )
 
-    def __truediv__(self, other: Any) -> "ProfileMetrics":
+    def __truediv__(self, other: Any) -> ProfileMetrics:
         if isinstance(other, int):
             other = ProfileMetrics(other, other, other)
         return ProfileMetrics(
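With postponed evaluation enabled at the top of the file, ProfileMetrics can name itself unquoted in __add__ and __truediv__. A self-contained sketch of the same pattern using an illustrative dataclass rather than the real class:

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class Metrics:
    microseconds: int = 0
    operators: int = 0

    # The class can name itself without quotes because the annotation is
    # only a string until a type checker evaluates it.
    def __add__(self, other: Metrics) -> Metrics:
        return Metrics(
            self.microseconds + other.microseconds,
            self.operators + other.operators,
        )

    def __truediv__(self, other: int) -> Metrics:
        return Metrics(self.microseconds // other, self.operators // other)


print(Metrics(10, 2) + Metrics(5, 1))  # Metrics(microseconds=15, operators=3)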
@@ -17,6 +17,8 @@ This is primarily used by PyTorch developers and researchers to debug issues in
 the Dynamo AOT compilation pipeline, particularly for the Inductor backend.
 """
 
+from __future__ import annotations
+
 import argparse
 import copy
 import functools
@@ -28,7 +30,6 @@ import subprocess
 import sys
 import textwrap
 import uuid
-from collections.abc import Sequence
 from importlib import import_module
 from tempfile import TemporaryFile
 from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
@@ -76,7 +77,6 @@ from torch._dynamo.utils import clone_inputs, counters, same
 from torch._environment import is_fbcode
 from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
 from torch._inductor.cpp_builder import normalize_path_separator
-from torch._inductor.output_code import OutputCode
 from torch._library.fake_class_registry import FakeScriptObject
 from torch._ops import OpOverload
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -90,7 +90,10 @@ from .. import config
 
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs
+    from torch._inductor.output_code import OutputCode
     from torch._inductor.utils import InputType
 
 
@@ -106,9 +109,9 @@ use_buck = is_fbcode()
 
 
 def wrap_compiler_debug(
-    unconfigured_compiler_fn: "_CompileFxCallable",
+    unconfigured_compiler_fn: _CompileFxCallable,
     compiler_name: str,
-) -> "_CompileFxCallable":
+) -> _CompileFxCallable:
     """
     Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
     forward and backward call separately with the backend compiler_fn - like
@@ -120,8 +123,8 @@ def wrap_compiler_debug(
     @functools.wraps(unconfigured_compiler_fn)
     def debug_wrapper(
         gm: torch.fx.GraphModule,
-        example_inputs: Sequence["InputType"],
-        **kwargs: Unpack["_CompileFxKwargs"],
+        example_inputs: Sequence[InputType],
+        **kwargs: Unpack[_CompileFxKwargs],
     ) -> OutputCode:
         from torch._subclasses import FakeTensorMode
 
@@ -161,7 +164,7 @@ def wrap_compiler_debug(
         # We may run regular PyTorch compute that may trigger Dynamo, do NOT
        # recursively attempt to accuracy minify in that case!
         def deferred_for_real_inputs(
-            real_inputs: Sequence["InputType"], **_kwargs: object
+            real_inputs: Sequence[InputType], **_kwargs: object
         ) -> Any:
             # This is a bit obscure: if we recursively try to accuracy minify
             # the SAME function, this would trigger. But most of the time
@@ -173,7 +176,7 @@ def wrap_compiler_debug(
             with config.patch(repro_after=None):
                 return inner_debug_fn(real_inputs)
 
-        def inner_debug_fn(real_inputs: Sequence["InputType"]) -> Any:
+        def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any:
             """
             Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
             example_inputs can be fake tensors. We can call compiler_fn (which is
@@ -22,6 +22,8 @@ This is a core part of TorchDynamo's tracing system that enables ahead-of-time
 optimization of PyTorch programs.
 """
 
+from __future__ import annotations
+
 import collections
 import collections.abc
 import contextlib
@@ -41,7 +43,6 @@ import threading
 import traceback
 import types
 import weakref
-from collections.abc import Generator, Sequence
 from traceback import StackSummary
 from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
 from typing_extensions import TypeAlias, TypeIs
@@ -52,7 +53,6 @@ import torch._logging
 from torch._dynamo.exc import ObservedException, TensorifyScalarRestartAnalysis
 from torch._guards import tracing, TracingContext
 from torch._logging.structured import dump_file
-from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import guard_bool
 from torch.utils._functools import cache_method
 
@@ -177,6 +177,10 @@ from .variables.user_defined import (
 
 
 if TYPE_CHECKING:
+    from collections.abc import Generator, Sequence
+
+    from torch._subclasses.fake_tensor import FakeTensorMode
+
     from .package import CompilePackage
 
 log = logging.getLogger(__name__)
@@ -238,7 +242,7 @@ class SpeculationEntry:
                 restart_reason = "Unknown fail_and_restart_analysis"
         raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
 
-    def failed(self, tx: "InstructionTranslatorBase") -> bool:
+    def failed(self, tx: InstructionTranslatorBase) -> bool:
         if self._failed:
             assert self.error_on_graph_break is not None
             tx.error_on_graph_break = self.error_on_graph_break
@@ -364,7 +368,7 @@ def _step_logger() -> Callable[..., None]:
 
 @contextlib.contextmanager
 def save_and_restart_speculation_log(
-    tx: "InstructionTranslatorBase",
+    tx: InstructionTranslatorBase,
 ) -> Generator[None, None, None]:
     # When reconstructing a generator after a graph break, we advance it until
     # it is fully exhausted. This process adds new entries to the speculation
@@ -384,7 +388,7 @@ def save_and_restart_speculation_log(
 
 @contextlib.contextmanager
 def temporarely_allow_writes_to_output_graph(
-    tx: "InstructionTranslatorBase",
+    tx: InstructionTranslatorBase,
 ) -> Generator[None, None, None]:
     try:
         tmp = tx.output.should_exit
@@ -420,7 +424,7 @@ class BlockStackEntry:
         else:
             return ReenterWith(self.stack_index - 1)
 
-    def exit(self, tx: "InstructionTranslatorBase", is_graph_break: bool) -> None:
+    def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None:
         assert self.with_context is not None
         if (
             is_graph_break and self.with_context.exit_on_graph_break()
@@ -448,7 +452,7 @@ def stack_op(fn: Callable[..., object]) -> Callable[..., Any]:
     fn_var = BuiltinVariable(fn)
 
     @functools.wraps(fn)
-    def impl(self: "InstructionTranslator", inst: Instruction) -> None:
+    def impl(self: InstructionTranslator, inst: Instruction) -> None:
         self.push(fn_var.call_function(self, self.popn(nargs), {}))
 
     return impl
@@ -464,7 +468,7 @@ def is_stdlib(mod: object) -> bool:
 
 
 def _detect_and_normalize_assert_statement(
-    self: "InstructionTranslatorBase",
+    self: InstructionTranslatorBase,
     truth_fn: Callable[[object], bool],
     push: bool,
 ) -> bool:
@@ -615,7 +619,7 @@ def log_graph_break(
 
 def generic_jump(
     truth_fn: Callable[[object], bool], push: bool
-) -> Callable[["InstructionTranslatorBase", Instruction], None]:
+) -> Callable[[InstructionTranslatorBase, Instruction], None]:
     # graph break message fields for data dependent branching
     _gb_type = "Data-dependent branching"
     _explanation = (
@@ -628,7 +632,7 @@ def generic_jump(
     ]
 
     def jump_graph_break(
-        self: "InstructionTranslatorBase",
+        self: InstructionTranslatorBase,
         inst: Instruction,
         value: VariableTracker,
         extra_msg: str = "",
@@ -679,7 +683,7 @@ def generic_jump(
         jump_inst.copy_positions(inst)
         self.output.add_output_instructions([jump_inst] + if_next + if_jump)
 
-    def inner(self: "InstructionTranslatorBase", inst: Instruction) -> None:
+    def inner(self: InstructionTranslatorBase, inst: Instruction) -> None:
         value: VariableTracker = self.pop()
         if (
             config.rewrite_assert_with_torch_assert
@@ -877,13 +881,13 @@ def generic_jump(
 def break_graph_if_unsupported(
     *, push: int
 ) -> Callable[
-    [Callable[..., None]], Callable[["InstructionTranslatorBase", Instruction], None]
+    [Callable[..., None]], Callable[[InstructionTranslatorBase, Instruction], None]
 ]:
     def decorator(
         inner_fn: Callable[..., None],
-    ) -> Callable[["InstructionTranslatorBase", Instruction], None]:
+    ) -> Callable[[InstructionTranslatorBase, Instruction], None]:
         @functools.wraps(inner_fn)
-        def wrapper(self: "InstructionTranslatorBase", inst: Instruction) -> None:
+        def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None:
             speculation = self.speculate()
             if speculation.failed(self):
                 assert speculation.reason is not None
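The Callable[[InstructionTranslatorBase, Instruction], None] signatures above describe bytecode handlers and the decorators that wrap them. A toy sketch of that decorator shape, with stand-in Translator and Instruction classes and a body that is purely illustrative, not the real implementation:

from __future__ import annotations

import functools
from typing import Callable


class Translator:
    pass


class Instruction:
    opname = "NOP"


Handler = Callable[[Translator, Instruction], None]


def break_if_unsupported(*, push: int) -> Callable[[Handler], Handler]:
    # Decorator factory: takes a handler and returns another handler with
    # the same (translator, instruction) -> None shape; `push` is unused here.
    def decorator(inner_fn: Handler) -> Handler:
        @functools.wraps(inner_fn)
        def wrapper(tx: Translator, inst: Instruction) -> None:
            inner_fn(tx, inst)

        return wrapper

    return decorator


@break_if_unsupported(push=1)
def handle(tx: Translator, inst: Instruction) -> None:
    print(f"handled {inst.opname}")


handle(Translator(), Instruction())  # prints "handled NOP"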
@@ -933,7 +937,7 @@ def break_graph_if_unsupported(
                 speculation.fail_and_restart_analysis(self.error_on_graph_break)
 
         def handle_graph_break(
-            self: "InstructionTranslatorBase",
+            self: InstructionTranslatorBase,
             inst: Instruction,
             reason: GraphCompileReason,
         ) -> None:
@@ -1159,9 +1163,9 @@ class InstructionTranslatorBase(
     strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
     start_point: Optional[int]
     is_leaf_tracer: bool
-    parent: Optional["InstructionTranslatorBase"]
+    parent: Optional[InstructionTranslatorBase]
     debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
-    package: Optional["CompilePackage"]
+    package: Optional[CompilePackage]
 
     def mark_inconsistent_side_effects(self) -> None:
         """
@@ -3298,7 +3302,7 @@ class InstructionTranslatorBase(
         distributed_state: Optional[DistributedState],
         # This determines whether to use the execution recorder.
         closure: Optional[tuple[types.CellType]] = None,
-        package: Optional["CompilePackage"] = None,
+        package: Optional[CompilePackage] = None,
     ) -> None:
         super().__init__()
         self.speculation_log = speculation_log
@@ -3398,7 +3402,7 @@ class InstructionTranslatorBase(
 
 class InstructionTranslator(InstructionTranslatorBase):
     @staticmethod
-    def current_tx() -> "InstructionTranslator":
+    def current_tx() -> InstructionTranslator:
         return tls.current_tx
 
     @contextlib.contextmanager
@@ -3428,7 +3432,7 @@ class InstructionTranslator(InstructionTranslatorBase):
         speculation_log: SpeculationLog,
         exn_vt_stack: ExceptionStack,
         distributed_state: Optional[DistributedState],
-        package: Optional["CompilePackage"],
+        package: Optional[CompilePackage],
     ) -> None:
         _step_logger()(
             logging.INFO,
@@ -3886,7 +3890,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         func: VariableTracker,
         args: list[VariableTracker],
         kwargs: Any,
-    ) -> "InliningInstructionTranslator":
+    ) -> InliningInstructionTranslator:
         assert isinstance(
             func,
             (