[Dynamo][Better Engineering] Add typing for comptime, cache, and convert_frame (#158379)

As part of Better Engineering week, we would like to improve our type support to improve the developer experience in Dynamo.

This PR adds strict typing support to critical tracing points for Dynamo, primarily `comptime.py` but also `cache_size.py` and `convert_frame.py`.

Running:
```
mypy torch/_dynamo/comptime.py torch/_dynamo/cache_size.py torch/_dynamo/convert_frame.py --linecount-report /tmp/coverage_log
```
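
(The `--linecount-report` flag makes mypy write per-module counts of annotated lines and functions to the given directory; the table below summarizes those counts for the three files.)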

|          | Lines Annotated | Lines Total | % Lines Covered | Funcs Annotated | Funcs Total | % Funcs Covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main    | 1837 | 2215 | 82.93% | 45 | 82 | 54.88% |
| This PR | 2230 | 2230 | 100.00% | 82 | 82 | 100.00% |
| Delta   | +393 | +15 | +17.07% | +37 | 0 | +45.12% |
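
As a sanity check, the percentage columns are simply annotated divided by total; a minimal standalone sketch (not part of the PR) reproducing the figures above:

```
def coverage(annotated: int, total: int) -> float:
    # Percent of lines (or functions) that mypy counts as annotated.
    return 100.0 * annotated / total

assert f"{coverage(1837, 2215):.2f}" == "82.93"   # Main, lines
assert f"{coverage(45, 82):.2f}" == "54.88"       # Main, funcs
assert f"{coverage(2230, 2230):.2f}" == "100.00"  # This PR
```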

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158379
Approved by: https://github.com/mlazos
Author: Lucas Kabela
Date: 2025-07-18 02:11:52 +00:00
Committed by: PyTorch MergeBot
Parent: 6fd6fc418d
Commit: 583138d170
4 changed files with 70 additions and 55 deletions

torch/_dynamo/cache_size.py

@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
import weakref
from dataclasses import dataclass
from typing import Any, Optional
from torch._guards import CompileId
@@ -9,7 +9,7 @@ from . import config
from .types import DynamoFrameType
log = logging.getLogger(__name__)
log: logging.Logger = logging.getLogger(__name__)
"""
[Note on cache size limit]
@@ -99,7 +99,9 @@ class CacheSizeRelevantForFrame:
return self.num_cache_entries_with_same_id_matched_objs >= limit
def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str):
def _get_weakref_from_f_locals(
frame: DynamoFrameType, local_name: str
) -> Optional[weakref.ref[Any]]:
obj = frame.f_locals.get(local_name, None)
weak_id = None
try:
@@ -109,7 +111,7 @@ def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str):
return weak_id
def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool:
def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry: Any) -> bool:
"""
Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones
in frame.f_locals.
@@ -131,7 +133,7 @@ def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool:
def compute_cache_size(
frame: DynamoFrameType, cache_entry
frame: DynamoFrameType, cache_entry: Any
) -> CacheSizeRelevantForFrame:
# Walk the linked list to calculate the cache size
num_cache_entries = 0

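A note on the `Optional[weakref.ref[Any]]` return type above: not every local is weak-referenceable, which is why the helper can return `None`. A minimal standalone sketch of the same pattern (hypothetical name `get_weakref`, plain dict instead of a Dynamo frame):

```
import weakref
from typing import Any, Optional

def get_weakref(f_locals: dict[str, Any], local_name: str) -> Optional[weakref.ref[Any]]:
    # Objects such as ints and strs reject weakref.ref with a TypeError,
    # so the helper degrades to None instead of raising.
    obj = f_locals.get(local_name, None)
    try:
        return weakref.ref(obj)
    except TypeError:
        return None
```
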
torch/_dynamo/comptime.py

@@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""
This module provides the public comptime interface to TorchDynamo, enabling users to execute
arbitrary Python code during symbolic evaluation of their programs.
@@ -40,9 +38,13 @@ import builtins
import dis
import time
import traceback
from typing import Optional, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, TextIO, Union
import torch
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
from torch._dynamo.variables.base import VariableTracker
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import free_symbols
from .exc import unimplemented_v2
@@ -62,10 +64,10 @@ class ComptimeVar:
actual data in the Tensor is.)
"""
def __init__(self, v) -> None:
def __init__(self, v: VariableTracker) -> None:
self.__variable = v
def as_proxy(self):
def as_proxy(self) -> Union[VariableTracker, Sequence[VariableTracker]]:
"""
Returns an fx.Proxy (or tuple/list of fx.Proxy) representing
this variable in the FX graph we are assembling to pass
@@ -79,13 +81,13 @@ class ComptimeVar:
"""
return self.__variable.as_proxy()
def is_proxy(self):
def is_proxy(self) -> bool:
"""
Returns True if as_proxy() would succeed.
"""
return self.__variable.is_proxy()
def as_fake(self):
def as_fake(self) -> Union[FakeTensor, torch.SymInt]:
"""
Returns a "fake" value (either a FakeTensor or a SymInt)
representing the variable in question. This only works
@@ -102,16 +104,16 @@ class ComptimeVar:
Returns the size of the tensor (if dim is None) or the size
at the dimension dim. The returned size may be a SymInt.
"""
return self.as_fake().size(dim)
return self.as_fake().size(dim) # type: ignore[union-attr, return-value]
def python_type(self):
def python_type(self) -> type:
"""
Returns what type(v) would have returned for the variable
at compile time.
"""
return self.__variable.python_type()
def as_python_constant(self):
def as_python_constant(self) -> Any:
"""
Returns the Python value this variable would have, but only if it is
completely known at compile-time (e.g., it is constant).
@@ -123,19 +125,19 @@ class ComptimeVar:
"""
return self.__variable.as_python_constant()
def is_python_constant(self):
def is_python_constant(self) -> bool:
"""
Returns True if as_python_constant would succeed.
"""
return self.__variable.is_python_constant()
def is_dynamic(self):
def is_dynamic(self) -> bool:
if isinstance(self.__variable, SymNodeVariable):
fs = free_symbols(self.__variable.sym_num)
return bool(fs)
return False
def force_static(self):
def force_static(self) -> None:
"""
Forces that a value is static, inducing a guard on its specific value
"""
@@ -149,7 +151,7 @@ class ComptimeVar:
f"cannot force {self.__variable} ({type(self.__variable)}) static"
)
def _i_will_not_complain_if_bc_breaks_VariableTracker(self):
def _i_will_not_complain_if_bc_breaks_VariableTracker(self) -> VariableTracker:
"""
Returns the internal data structure VariableTracker that Dynamo uses
to represent variables at compile time. There are no BC guarantees on
@@ -171,10 +173,10 @@ class ComptimeContext:
file a feature request at https://github.com/pytorch/pytorch/
"""
def __init__(self, tx) -> None:
def __init__(self, tx: InstructionTranslatorBase) -> None:
self.__tx = tx
def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:
def get_local(self, name: str, *, stacklevel: int = 0) -> ComptimeVar:
"""
Retrieve the compile-time known information about a local.
"""
@@ -187,7 +189,7 @@ class ComptimeContext:
return ComptimeVar(var)
def graph_break(self, msg="ComptimeContext.graph_break"):
def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None:
"""
Manually trigger a graph break
"""
@@ -198,14 +200,14 @@ class ComptimeContext:
hints=[],
)
def graph(self):
def graph(self) -> torch.fx.Graph:
"""
Retrieve the partially constructed FX graph that would be
passed to the user compiler after compilation.
"""
return self.__tx.output.graph
def assert_static(self, val):
def assert_static(self, val: ComptimeVar) -> None:
"""
Asserts that the int is static (and not dynamic, per dynamic shapes)
"""
@@ -213,7 +215,9 @@ class ComptimeContext:
"expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)"
)
def print_graph(self, *, verbose=True, file=None):
def print_graph(
self, *, verbose: bool = True, file: Optional[TextIO] = None
) -> None:
"""
Print the partially constructed FX graph that would be passed
to the user compiler after compilation.
@@ -222,19 +226,21 @@ class ComptimeContext:
self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file
)
def parent(self):
return ComptimeContext(self.__tx.parent)
def parent(self) -> "ComptimeContext":
return ComptimeContext(self.__tx.parent) # type: ignore[arg-type]
def __get_tx(self, stacklevel):
def __get_tx(self, stacklevel: int) -> Any:
tx = self.__tx
for _ in range(stacklevel):
tx = tx.parent
tx = tx.parent # type: ignore[assignment]
return tx
def print(self, val, *, file=None):
def print(self, val: Any, *, file: Optional[TextIO] = None) -> None:
print(repr(val), file=file)
def print_disas(self, *, file=None, stacklevel=0):
def print_disas(
self, *, file: Optional[TextIO] = None, stacklevel: int = 0
) -> None:
"""
Print the current series of opcodes being executed (not including
parent frames), including where you are in the particular opcode
@@ -249,7 +255,9 @@ class ComptimeContext:
file=file,
)
def print_value_stack(self, *, file=None, stacklevel=0):
def print_value_stack(
self, *, file: Optional[TextIO] = None, stacklevel: int = 0
) -> None:
"""
Print the current Python value stack. Note that this is NOT the same
as the traceback; use print_bt() to print that. Note that at
@@ -264,7 +272,9 @@ class ComptimeContext:
for s in tx.stack:
print(f"- {s.debug_repr()}", file=file)
def print_locals(self, *, file=None, stacklevel=0):
def print_locals(
self, *, file: Optional[TextIO] = None, stacklevel: int = 0
) -> None:
"""
Print all of the locals available in the current context.
By default this view is very limited; you can get more information
@@ -274,7 +284,7 @@ class ComptimeContext:
for k, v in tx.symbolic_locals.items():
print(f"{k} = {v.debug_repr()}", file=file)
def print_bt(self, *, file=None, stacklevel=0):
def print_bt(self, *, file: Optional[TextIO] = None, stacklevel: int = 0) -> None:
"""
Print the user code backtrace, starting at the beginning of the
frame Dynamo started evaluating. Note that this MAY NOT go all
@@ -293,7 +303,7 @@ class ComptimeContext:
file=file,
)
def print_guards(self, *, file=None):
def print_guards(self, *, file: Optional[TextIO] = None) -> None:
"""
Print the currently installed guards for the Dynamo context.
This does NOT include guards associated with variables that
@@ -307,7 +317,9 @@ class ComptimeContext:
file=file,
)
def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self):
def _i_will_not_complain_if_bc_breaks_InstructionTranslator(
self,
) -> InstructionTranslatorBase:
"""
Returns the internal data structure InstructionTranslator that Dynamo
uses to track state of symbolic evaluation. There are no BC
@@ -316,32 +328,35 @@ class ComptimeContext:
"""
return self.__tx
def sleep(self, sec):
def sleep(self, sec: Union[int, float]) -> None:
time.sleep(sec)
class _Comptime:
@staticmethod
def __call__(fn, fallback_fn=lambda: None):
def __call__(
fn: Callable[[ComptimeContext], Any],
fallback_fn: Callable[[], Any] = lambda: None,
) -> Any:
"""fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise"""
fallback_fn()
# Convenience wrappers that are more compact to use
@staticmethod
def graph_break():
def graph_break() -> None:
comptime(lambda ctx: ctx.graph_break())
@staticmethod
def print(e):
def print(e: Any) -> None:
comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e))
@staticmethod
def print_graph():
def print_graph() -> None:
comptime(lambda ctx: ctx.print_graph())
@staticmethod
def print_disas(*, stacklevel=0):
def print_disas(*, stacklevel: int = 0) -> None:
comptime(
lambda ctx: ctx.print_disas(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
@@ -349,7 +364,7 @@ class _Comptime:
)
@staticmethod
def print_value_stack(*, stacklevel=0):
def print_value_stack(*, stacklevel: int = 0) -> None:
comptime(
lambda ctx: ctx.print_value_stack(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
@@ -360,7 +375,7 @@ class _Comptime:
# in an expression context; e.g., x + print_value_stack_and_return(y + z),
# you will see x on the stack prior to the addition operation
@staticmethod
def print_value_stack_and_return(e, *, stacklevel=0):
def print_value_stack_and_return(e: Any, *, stacklevel: int = 0) -> Any:
comptime(
lambda ctx: ctx.print_value_stack(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
@@ -369,7 +384,7 @@ class _Comptime:
return e
@staticmethod
def print_locals(*, stacklevel=0):
def print_locals(*, stacklevel: int = 0) -> None:
comptime(
lambda ctx: ctx.print_locals(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
@@ -377,7 +392,7 @@ class _Comptime:
)
@staticmethod
def print_bt(*, stacklevel=0):
def print_bt(*, stacklevel: int = 0) -> None:
comptime(
lambda ctx: ctx.print_bt(
stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
@@ -385,19 +400,19 @@ class _Comptime:
)
@staticmethod
def print_guards():
def print_guards() -> None:
comptime(lambda ctx: ctx.print_guards())
@staticmethod
def assert_static(val):
def assert_static(val: Any) -> None:
comptime(lambda ctx: ctx.assert_static(ctx.get_local("val")))
@staticmethod
def force_static(val):
def force_static(val: Any) -> None:
comptime(lambda ctx: ctx.get_local("val").force_static())
@staticmethod
def breakpoint():
def breakpoint() -> None:
"""
Like pdb breakpoint(), but drop into pdb whenever this line
of code is compiled by dynamo. Use it by putting
@@ -415,14 +430,14 @@ class _Comptime:
(Pdb) p ctx.get_local("attention").as_fake()
"""
def inner(inner_ctx):
def inner(inner_ctx: ComptimeContext) -> None:
ctx = inner_ctx.parent() # noqa: F841
builtins.breakpoint()
comptime(inner)
@staticmethod
def sleep(sec):
def sleep(sec: Union[int, float]) -> None:
comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant()))

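For context, the convenience wrappers typed above are used like this; a small usage sketch of the existing public API (behavior is unchanged by this PR):

```
import torch
from torch._dynamo.comptime import comptime

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    # Runs at compile time inside Dynamo: prints the FX graph built so far.
    comptime.print_graph()
    # The callable form receives a ComptimeContext (the `ctx` typed in this PR).
    comptime(lambda ctx: ctx.print_locals())
    return x + 1

f(torch.randn(3))
```
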
torch/_dynamo/convert_frame.py

@@ -1,5 +1,3 @@
# mypy: allow-untyped-decorators
"""
This module implements TorchDynamo's core frame conversion functionality, transforming Python
frames into FX graphs. It handles:

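The only change to `convert_frame.py` shown here is dropping the file-level suppression. With `# mypy: allow-untyped-decorators` removed, mypy's `disallow_untyped_decorators` check applies to the module; a minimal standalone illustration (not from the PR):

```
# Running `mypy --disallow-untyped-decorators` on this file flags the
# decorated function, because the decorator itself is unannotated.

def untyped_decorator(fn):
    return fn

@untyped_decorator  # error: Untyped decorator makes function "g" untyped
def g(x: int) -> int:
    return x
```
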
torch/_dynamo/eval_frame.py

@@ -1705,7 +1705,7 @@ def export(
_log_export_usage: bool = True,
constraints: Optional[list[Constraint]] = None,
**extra_kwargs: Any,
) -> Callable[[tuple[Any, Any]], ExportResult]:
) -> Callable[..., ExportResult]:
"""
Export an input function f to a format that can be executed outside of PyTorch using the FX graph.
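
The signature fix above corrects a previously inaccurate type: `Callable[[tuple[Any, Any]], ExportResult]` describes a callable taking exactly one tuple argument, whereas `Callable[..., ExportResult]` accepts any call signature. A minimal standalone illustration with hypothetical names:

```
from typing import Any, Callable

def make_runner() -> Callable[..., int]:
    # The inner function accepts arbitrary args/kwargs, so Callable[..., int]
    # (rather than a fixed single-argument signature) is the accurate type.
    def runner(*args: Any, **kwargs: Any) -> int:
        return len(args) + len(kwargs)
    return runner

run = make_runner()
assert run(1, 2, key=3) == 3
```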