mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo][Better Engineering] Add typing for comptime, cache, and convert_frame (#158379)
As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a critical tracing point for dynamo, primarily for`comptime.py` but also `cache_size.py` and `convert_frame.py`. Running ``` mypy torch/_dynamo/comptime.py torch/_dynamo/cache_size.py torch/_dynamo/convert_frame.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 1837 | 2215 | 82.93% | 45 | 82 | 54.88% | | This PR | 2230 | 2230 | 100.00% | 82 | 82 | 100.00% | | Delta | +393 | +15 | +17.07% | +37 | 0 | +45.12% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158379 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
6fd6fc418d
commit
583138d170
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch._guards import CompileId
|
||||
|
||||
@ -9,7 +9,7 @@ from . import config
|
||||
from .types import DynamoFrameType
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
"""
|
||||
[Note on cache size limit]
|
||||
|
||||
@ -99,7 +99,9 @@ class CacheSizeRelevantForFrame:
|
||||
return self.num_cache_entries_with_same_id_matched_objs >= limit
|
||||
|
||||
|
||||
def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str):
|
||||
def _get_weakref_from_f_locals(
|
||||
frame: DynamoFrameType, local_name: str
|
||||
) -> Optional[weakref.ref[Any]]:
|
||||
obj = frame.f_locals.get(local_name, None)
|
||||
weak_id = None
|
||||
try:
|
||||
@ -109,7 +111,7 @@ def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str):
|
||||
return weak_id
|
||||
|
||||
|
||||
def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool:
|
||||
def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry: Any) -> bool:
|
||||
"""
|
||||
Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones
|
||||
in frame.f_locals.
|
||||
@ -131,7 +133,7 @@ def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool:
|
||||
|
||||
|
||||
def compute_cache_size(
|
||||
frame: DynamoFrameType, cache_entry
|
||||
frame: DynamoFrameType, cache_entry: Any
|
||||
) -> CacheSizeRelevantForFrame:
|
||||
# Walk the linked list to calculate the cache size
|
||||
num_cache_entries = 0
|
||||
|
@ -1,5 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
"""
|
||||
This module provides the public comptime interface to TorchDynamo, enabling users to execute
|
||||
arbitrary Python code during symbolic evaluation of their programs.
|
||||
@ -40,9 +38,13 @@ import builtins
|
||||
import dis
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, TextIO, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
|
||||
from torch._dynamo.variables.base import VariableTracker
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.fx.experimental.symbolic_shapes import free_symbols
|
||||
|
||||
from .exc import unimplemented_v2
|
||||
@ -62,10 +64,10 @@ class ComptimeVar:
|
||||
actual data in the Tensor is.)
|
||||
"""
|
||||
|
||||
def __init__(self, v) -> None:
|
||||
def __init__(self, v: VariableTracker) -> None:
|
||||
self.__variable = v
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Union[VariableTracker, Sequence[VariableTracker]]:
|
||||
"""
|
||||
Returns an fx.Proxy (or tuple/list of fx.Proxy) representing
|
||||
this variable in the FX graph we are assembling to pass
|
||||
@ -79,13 +81,13 @@ class ComptimeVar:
|
||||
"""
|
||||
return self.__variable.as_proxy()
|
||||
|
||||
def is_proxy(self):
|
||||
def is_proxy(self) -> bool:
|
||||
"""
|
||||
Returns True if as_proxy() would succeed.
|
||||
"""
|
||||
return self.__variable.is_proxy()
|
||||
|
||||
def as_fake(self):
|
||||
def as_fake(self) -> Union[FakeTensor, torch.SymInt]:
|
||||
"""
|
||||
Returns a "fake" value (either a FakeTensor or a SymInt)
|
||||
representing the variable in question. This only works
|
||||
@ -102,16 +104,16 @@ class ComptimeVar:
|
||||
Returns the size of the tensor (if dim is None) or the size
|
||||
at the dimension dim. The returned size may be a SymInt.
|
||||
"""
|
||||
return self.as_fake().size(dim)
|
||||
return self.as_fake().size(dim) # type: ignore[union-attr, return-value]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
"""
|
||||
Returns what type(v) would have returned for the variable
|
||||
at compile time.
|
||||
"""
|
||||
return self.__variable.python_type()
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
"""
|
||||
Returns the Python value this variable would have, but only if it is
|
||||
completely known at compile-time (e.g., it is constant).
|
||||
@ -123,19 +125,19 @@ class ComptimeVar:
|
||||
"""
|
||||
return self.__variable.as_python_constant()
|
||||
|
||||
def is_python_constant(self):
|
||||
def is_python_constant(self) -> bool:
|
||||
"""
|
||||
Returns True if as_python_constant would succeed.
|
||||
"""
|
||||
return self.__variable.is_python_constant()
|
||||
|
||||
def is_dynamic(self):
|
||||
def is_dynamic(self) -> bool:
|
||||
if isinstance(self.__variable, SymNodeVariable):
|
||||
fs = free_symbols(self.__variable.sym_num)
|
||||
return bool(fs)
|
||||
return False
|
||||
|
||||
def force_static(self):
|
||||
def force_static(self) -> None:
|
||||
"""
|
||||
Forces that a value is static, inducing a guard on its specific value
|
||||
"""
|
||||
@ -149,7 +151,7 @@ class ComptimeVar:
|
||||
f"cannot force {self.__variable} ({type(self.__variable)}) static"
|
||||
)
|
||||
|
||||
def _i_will_not_complain_if_bc_breaks_VariableTracker(self):
|
||||
def _i_will_not_complain_if_bc_breaks_VariableTracker(self) -> VariableTracker:
|
||||
"""
|
||||
Returns the internal data structure VariableTracker that Dynamo uses
|
||||
to represent variables at compile time. There are no BC guarantees on
|
||||
@ -171,10 +173,10 @@ class ComptimeContext:
|
||||
file a feature request at https://github.com/pytorch/pytorch/
|
||||
"""
|
||||
|
||||
def __init__(self, tx) -> None:
|
||||
def __init__(self, tx: InstructionTranslatorBase) -> None:
|
||||
self.__tx = tx
|
||||
|
||||
def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:
|
||||
def get_local(self, name: str, *, stacklevel: int = 0) -> ComptimeVar:
|
||||
"""
|
||||
Retrieve the compile-time known information about a local.
|
||||
"""
|
||||
@ -187,7 +189,7 @@ class ComptimeContext:
|
||||
|
||||
return ComptimeVar(var)
|
||||
|
||||
def graph_break(self, msg="ComptimeContext.graph_break"):
|
||||
def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None:
|
||||
"""
|
||||
Manually trigger a graph break
|
||||
"""
|
||||
@ -198,14 +200,14 @@ class ComptimeContext:
|
||||
hints=[],
|
||||
)
|
||||
|
||||
def graph(self):
|
||||
def graph(self) -> torch.fx.Graph:
|
||||
"""
|
||||
Retrieve the partially constructed FX graph that would be
|
||||
passed to the user compiler after compilation.
|
||||
"""
|
||||
return self.__tx.output.graph
|
||||
|
||||
def assert_static(self, val):
|
||||
def assert_static(self, val: ComptimeVar) -> None:
|
||||
"""
|
||||
Asserts that the int is static (and not dynamic, per dynamic shapes)
|
||||
"""
|
||||
@ -213,7 +215,9 @@ class ComptimeContext:
|
||||
"expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)"
|
||||
)
|
||||
|
||||
def print_graph(self, *, verbose=True, file=None):
|
||||
def print_graph(
|
||||
self, *, verbose: bool = True, file: Optional[TextIO] = None
|
||||
) -> None:
|
||||
"""
|
||||
Print the partially constructed FX graph that would be passed
|
||||
to the user compiler after compilation.
|
||||
@ -222,19 +226,21 @@ class ComptimeContext:
|
||||
self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file
|
||||
)
|
||||
|
||||
def parent(self):
|
||||
return ComptimeContext(self.__tx.parent)
|
||||
def parent(self) -> "ComptimeContext":
|
||||
return ComptimeContext(self.__tx.parent) # type: ignore[arg-type]
|
||||
|
||||
def __get_tx(self, stacklevel):
|
||||
def __get_tx(self, stacklevel: int) -> Any:
|
||||
tx = self.__tx
|
||||
for _ in range(stacklevel):
|
||||
tx = tx.parent
|
||||
tx = tx.parent # type: ignore[assignment]
|
||||
return tx
|
||||
|
||||
def print(self, val, *, file=None):
|
||||
def print(self, val: Any, *, file: Optional[TextIO] = None) -> None:
|
||||
print(repr(val), file=file)
|
||||
|
||||
def print_disas(self, *, file=None, stacklevel=0):
|
||||
def print_disas(
|
||||
self, *, file: Optional[TextIO] = None, stacklevel: int = 0
|
||||
) -> None:
|
||||
"""
|
||||
Print the current series of opcodes being executed (not including
|
||||
parent frames), including where you are in the particular opcode
|
||||
@ -249,7 +255,9 @@ class ComptimeContext:
|
||||
file=file,
|
||||
)
|
||||
|
||||
def print_value_stack(self, *, file=None, stacklevel=0):
|
||||
def print_value_stack(
|
||||
self, *, file: Optional[TextIO] = None, stacklevel: int = 0
|
||||
) -> None:
|
||||
"""
|
||||
Print the current Python value stack. Note that this is NOT the same
|
||||
as the traceback; use print_bt() to print that. Note that at
|
||||
@ -264,7 +272,9 @@ class ComptimeContext:
|
||||
for s in tx.stack:
|
||||
print(f"- {s.debug_repr()}", file=file)
|
||||
|
||||
def print_locals(self, *, file=None, stacklevel=0):
|
||||
def print_locals(
|
||||
self, *, file: Optional[TextIO] = None, stacklevel: int = 0
|
||||
) -> None:
|
||||
"""
|
||||
Print all of the locals available in the current context.
|
||||
By default this view is very limited; you can get more information
|
||||
@ -274,7 +284,7 @@ class ComptimeContext:
|
||||
for k, v in tx.symbolic_locals.items():
|
||||
print(f"{k} = {v.debug_repr()}", file=file)
|
||||
|
||||
def print_bt(self, *, file=None, stacklevel=0):
|
||||
def print_bt(self, *, file: Optional[TextIO] = None, stacklevel: int = 0) -> None:
|
||||
"""
|
||||
Print the user code backtrace, starting at the beginning of the
|
||||
frame Dynamo started evaluating. Note that this MAY NOT go all
|
||||
@ -293,7 +303,7 @@ class ComptimeContext:
|
||||
file=file,
|
||||
)
|
||||
|
||||
def print_guards(self, *, file=None):
|
||||
def print_guards(self, *, file: Optional[TextIO] = None) -> None:
|
||||
"""
|
||||
Print the currently installed guards for the Dynamo context.
|
||||
This does NOT include guards associated with variables that
|
||||
@ -307,7 +317,9 @@ class ComptimeContext:
|
||||
file=file,
|
||||
)
|
||||
|
||||
def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self):
|
||||
def _i_will_not_complain_if_bc_breaks_InstructionTranslator(
|
||||
self,
|
||||
) -> InstructionTranslatorBase:
|
||||
"""
|
||||
Returns the internal data structure InstructionTranslator that Dynamo
|
||||
uses to track state of symbolic evaluation. There are no BC
|
||||
@ -316,32 +328,35 @@ class ComptimeContext:
|
||||
"""
|
||||
return self.__tx
|
||||
|
||||
def sleep(self, sec):
|
||||
def sleep(self, sec: Union[int, float]) -> None:
|
||||
time.sleep(sec)
|
||||
|
||||
|
||||
class _Comptime:
|
||||
@staticmethod
|
||||
def __call__(fn, fallback_fn=lambda: None):
|
||||
def __call__(
|
||||
fn: Callable[[ComptimeContext], Any],
|
||||
fallback_fn: Callable[[], Any] = lambda: None,
|
||||
) -> Any:
|
||||
"""fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise"""
|
||||
fallback_fn()
|
||||
|
||||
# Convenience wrappers that are more compact to use
|
||||
|
||||
@staticmethod
|
||||
def graph_break():
|
||||
def graph_break() -> None:
|
||||
comptime(lambda ctx: ctx.graph_break())
|
||||
|
||||
@staticmethod
|
||||
def print(e):
|
||||
def print(e: Any) -> None:
|
||||
comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e))
|
||||
|
||||
@staticmethod
|
||||
def print_graph():
|
||||
def print_graph() -> None:
|
||||
comptime(lambda ctx: ctx.print_graph())
|
||||
|
||||
@staticmethod
|
||||
def print_disas(*, stacklevel=0):
|
||||
def print_disas(*, stacklevel: int = 0) -> None:
|
||||
comptime(
|
||||
lambda ctx: ctx.print_disas(
|
||||
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
|
||||
@ -349,7 +364,7 @@ class _Comptime:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def print_value_stack(*, stacklevel=0):
|
||||
def print_value_stack(*, stacklevel: int = 0) -> None:
|
||||
comptime(
|
||||
lambda ctx: ctx.print_value_stack(
|
||||
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
|
||||
@ -360,7 +375,7 @@ class _Comptime:
|
||||
# in an expression context; e.g., x + print_value_stack_and_return(y + z),
|
||||
# you will see x on the stack prior to the addition operation
|
||||
@staticmethod
|
||||
def print_value_stack_and_return(e, *, stacklevel=0):
|
||||
def print_value_stack_and_return(e: Any, *, stacklevel: int = 0) -> Any:
|
||||
comptime(
|
||||
lambda ctx: ctx.print_value_stack(
|
||||
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
|
||||
@ -369,7 +384,7 @@ class _Comptime:
|
||||
return e
|
||||
|
||||
@staticmethod
|
||||
def print_locals(*, stacklevel=0):
|
||||
def print_locals(*, stacklevel: int = 0) -> None:
|
||||
comptime(
|
||||
lambda ctx: ctx.print_locals(
|
||||
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
|
||||
@ -377,7 +392,7 @@ class _Comptime:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def print_bt(*, stacklevel=0):
|
||||
def print_bt(*, stacklevel: int = 0) -> None:
|
||||
comptime(
|
||||
lambda ctx: ctx.print_bt(
|
||||
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
|
||||
@ -385,19 +400,19 @@ class _Comptime:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def print_guards():
|
||||
def print_guards() -> None:
|
||||
comptime(lambda ctx: ctx.print_guards())
|
||||
|
||||
@staticmethod
|
||||
def assert_static(val):
|
||||
def assert_static(val: Any) -> None:
|
||||
comptime(lambda ctx: ctx.assert_static(ctx.get_local("val")))
|
||||
|
||||
@staticmethod
|
||||
def force_static(val):
|
||||
def force_static(val: Any) -> None:
|
||||
comptime(lambda ctx: ctx.get_local("val").force_static())
|
||||
|
||||
@staticmethod
|
||||
def breakpoint():
|
||||
def breakpoint() -> None:
|
||||
"""
|
||||
Like pdb breakpoint(), but drop into pdb whenever this line
|
||||
of code is compiled by dynamo. Use it by putting
|
||||
@ -415,14 +430,14 @@ class _Comptime:
|
||||
(Pdb) p ctx.get_local("attention").as_fake()
|
||||
"""
|
||||
|
||||
def inner(inner_ctx):
|
||||
def inner(inner_ctx: ComptimeContext) -> None:
|
||||
ctx = inner_ctx.parent() # noqa: F841
|
||||
builtins.breakpoint()
|
||||
|
||||
comptime(inner)
|
||||
|
||||
@staticmethod
|
||||
def sleep(sec):
|
||||
def sleep(sec: Union[int, float]) -> None:
|
||||
comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant()))
|
||||
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
|
||||
"""
|
||||
This module implements TorchDynamo's core frame conversion functionality, transforming Python
|
||||
frames into FX graphs. It handles:
|
||||
|
@ -1705,7 +1705,7 @@ def export(
|
||||
_log_export_usage: bool = True,
|
||||
constraints: Optional[list[Constraint]] = None,
|
||||
**extra_kwargs: Any,
|
||||
) -> Callable[[tuple[Any, Any]], ExportResult]:
|
||||
) -> Callable[..., ExportResult]:
|
||||
"""
|
||||
Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
|
||||
|
||||
|
Reference in New Issue
Block a user