From a7f3bdf550635c796e53442375477efe98fe5447 Mon Sep 17 00:00:00 2001
From: Lucas Kabela
Date: Mon, 4 Aug 2025 21:51:48 +0000
Subject: [PATCH] [Dynamo][Better Engineering] Type coverage for `torch/_dynamo/utils.py` (#159580)

As part of the better engineering effort, we would like to improve our type support to improve the dev experience in Dynamo.

This PR adds strict typing support to `torch/_dynamo/utils.py`.

Running
```
mypy torch/_dynamo/utils.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main | 2163 | 4792 | 45.14% | 121 | 268 | 45.15% |
| This PR | 4818 | 4818 | 100.00% | 268 | 268 | 100.00% |
| Delta | +2655 | +26 | +54.84% | +147 | 0 | +54.85% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159580
Approved by: https://github.com/williamwen42
---
 torch/_dynamo/replay_record.py |   7 +-
 torch/_dynamo/utils.py         | 573 ++++++++++++++++++---------
 2 files changed, 313 insertions(+), 267 deletions(-)

diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py
index b131160db25e..5d01217fdbb6 100644
--- a/torch/_dynamo/replay_record.py
+++ b/torch/_dynamo/replay_record.py
@@ -15,8 +15,9 @@ and recreate specific program states.

 import dataclasses
 from dataclasses import field
+from io import BufferedReader, BufferedWriter
 from types import CellType, CodeType, ModuleType
-from typing import Any, IO
+from typing import Any, IO, Union
 from typing_extensions import Self

 from torch.utils._import_utils import import_dill
@@ -51,12 +52,12 @@ class ExecutionRecord:
     builtins: dict[str, Any] = field(default_factory=dict)
     code_options: dict[str, Any] = field(default_factory=dict)

-    def dump(self, f: IO[str]) -> None:
+    def dump(self, f: Union[IO[str], BufferedWriter]) -> None:
         assert dill is not None, "replay_record requires `pip install dill`"
         dill.dump(self, f)

     @classmethod
-    def load(cls, f: IO[bytes]) -> Self:
+    def load(cls, f: Union[IO[bytes], BufferedReader]) -> Self:
         assert dill is not None, "replay_record requires `pip install dill`"
         return dill.load(f)

diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 575fe901fc15..588f1ddb99a1 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1,5 +1,3 @@
-# mypy: allow-untyped-defs
-
 """
 Utility functions and classes used throughout the TorchDynamo system.

@@ -62,7 +60,7 @@ from typing import (
     TypeVar,
     Union,
 )
-from typing_extensions import Literal, TypeAlias, TypeGuard, TypeIs
+from typing_extensions import Literal, ParamSpec, TypeAlias, TypeGuard, TypeIs

 import torch
 import torch._functorch.config
@@ -106,6 +104,14 @@ if typing.TYPE_CHECKING:
         ValuesView,
     )

+    from torch._dynamo.replay_record import ExecutionRecord
+    from torch._dynamo.symbolic_convert import (
+        InstructionTranslator,
+        InstructionTranslatorBase,
+    )
+    from torch._dynamo.variables.base import VariableTracker
+    from torch._prims_common import DeviceLikeType
+

 try:
     import numpy as np
@@ -145,6 +151,7 @@ except ImportError:


 T = TypeVar("T")
+_P = ParamSpec("_P")

 unpatched_nn_module_getattr = torch.nn.Module.__getattr__
 unpatched_nn_module_call = torch.nn.Module.__call__
@@ -184,43 +191,43 @@ class ReinplaceCounters:
     # Track sizes of known not re-inplaced tensors (exclude dynamic shapes).
@classmethod - def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int): + def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int) -> None: if bytes != 0: cls._values[f"missed_bytes_{trigger.name}"] += bytes # Track number of not re-inplaced tensors. @classmethod - def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int): + def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int) -> None: if count != 0: cls._values[f"missed_tensors_{trigger}"] += count @classmethod - def clear(cls): + def clear(cls) -> None: cls._values.clear() @classmethod - def get_total_missed(cls): + def get_total_missed(cls) -> int: sum = 0 for trigger in ReInplaceTrigger: sum += cls._values.get(f"missed_tensors_{trigger}", 0) return sum @classmethod - def get_total_missed_bytes(cls): + def get_total_missed_bytes(cls) -> int: sum = 0 for trigger in ReInplaceTrigger: sum += cls._values.get(f"missed_bytes_{trigger.name}", 0) return sum @classmethod - def log(cls): + def log(cls) -> None: # if not empty log. if cls._values: signpost_event("inductor", "reinplace_counters", cls._values) def tabulate( - rows: Union[list[tuple[str, object]], list[list[object]]], + rows: Union[list[tuple[str, Any]], list[list[Any]]], headers: Union[tuple[str, ...], list[str]], ) -> str: try: @@ -385,7 +392,7 @@ class CompileEventLogger: metadata: dict[str, Any], time_ns: Optional[int] = None, log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM, - ): + ) -> None: if time_ns is None: time_ns = time.time_ns() chromium_log = get_chromium_event_logger() @@ -407,7 +414,7 @@ class CompileEventLogger: log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object, - ): + ) -> None: """ Centralized API for adding data to various events Log an event to a toplevel "dynamo" event or metrics context @@ -450,7 +457,7 @@ class CompileEventLogger: @staticmethod def add_toplevel( log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object - ): + ) -> None: """ Syntactic sugar for logging to the toplevel event """ @@ -464,7 +471,7 @@ class CompileEventLogger: @staticmethod def increment( event_name: str, log_level: CompileEventLogLevel, key: str, value: int - ): + ) -> None: """ Increments an existing field, or adds it """ @@ -497,7 +504,7 @@ class CompileEventLogger: key: str, value: int = 1, log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, - ): + ) -> None: """ Increments a value on the toplevel metric. By default, logs to metric. """ @@ -512,7 +519,7 @@ class CompileEventLogger: @staticmethod def add_to_set( event_name: str, log_level: CompileEventLogLevel, key: str, value: Any - ): + ) -> None: """ Add metadata to a set of values with key . Creates a set if it doesn't exist. """ @@ -545,7 +552,7 @@ class CompileEventLogger: key: str, value: Any, log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, - ): + ) -> None: """ Same as add to set, just does it automatically to the toplevel event instead of having to explicitly name it. Defaults to COMPILATION_METRIC log level. @@ -561,7 +568,7 @@ class CompileEventLogger: # Helper functions that are syntactic sugar @staticmethod - def chromium(event_name: str, **metadata: object): + def chromium(event_name: str, **metadata: object) -> None: """ Add to in chromium. Each key/value of metadata will appear in the chromium trace. should be the name of a timed event span passed to `dynamo_timed`. 
@@ -571,7 +578,7 @@ class CompileEventLogger: ) @staticmethod - def pt2_compile(event_name: str, **metadata: object): + def pt2_compile(event_name: str, **metadata: object) -> None: """ Add to in chromium and PT2 Compile Events. Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes @@ -584,7 +591,7 @@ class CompileEventLogger: ) @staticmethod - def compilation_metric(overwrite: bool = False, **metadata: object): + def compilation_metric(overwrite: bool = False, **metadata: object) -> None: """ Add to the CompilationMetrics context. Also logs to PT2 Compile Events and chromium. @@ -598,7 +605,7 @@ class CompileEventLogger: @staticmethod def instant( event_name: str, metadata: dict[str, Any], time_ns: Optional[int] = None - ): + ) -> None: """ Log an instant event to chromium logs with name at time . The `args` field in Perfetto will point to metadata. should be a value obtained from time.time_ns(). @@ -608,7 +615,7 @@ class CompileEventLogger: ) @staticmethod - def try_add_pt2_compile(event_name: str, **metadata: object): + def try_add_pt2_compile(event_name: str, **metadata: object) -> None: """ Adds to an existing pt2_compile event, but silently returns if the event doesn't exist or ChromiumEventLogger is not initialized. @@ -620,7 +627,7 @@ class CompileEventLogger: chromium_log.try_add_event_data(event_name, **metadata) @staticmethod - def try_(method_fn, *args, **kwargs): + def try_(method_fn: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs) -> None: """ Special function that quietly runs a given method, returning if CHROMIUM_EVENT_LOG is None or metrics context is not set """ @@ -791,7 +798,9 @@ def compile_times( ) -> tuple[list[str], list[object]]: ... -def compile_times(repr="str", aggregate: bool = False): +def compile_times( # type: ignore[misc] + repr: str = "str", aggregate: bool = False +) -> Union[str, None, tuple[list[str], list[str]]]: """ Get metrics about torchdynamo frontend/backend compilation times. @@ -805,7 +814,7 @@ def compile_times(repr="str", aggregate: bool = False): per metric. 
""" - def fmt_fn(values, item_fn=lambda x: x): + def fmt_fn(values: list[float], item_fn: Callable[[float], str] = str) -> str: if aggregate: return item_fn(sum(values)) return ", ".join(map(item_fn, values)) @@ -852,8 +861,8 @@ class DuplicateWarningChecker: self.maxsize = maxsize self.reset() - def reset(self): - self.set = OrderedDict() + def reset(self) -> None: + self.set: OrderedDict[Any, Any] = OrderedDict() def add(self, key: Union[str, tuple[object, object]]) -> bool: if key in self.set: @@ -870,7 +879,7 @@ class DuplicateWarningChecker: graph_break_dup_warning_checker = DuplicateWarningChecker() -def setup_compile_debug(): +def setup_compile_debug() -> contextlib.ExitStack: compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" if compile_debug: @@ -883,7 +892,7 @@ def reset_graph_break_dup_checker() -> None: graph_break_dup_warning_checker.reset() -def add_file_handler(): +def add_file_handler() -> contextlib.ExitStack: log_path = os.path.join(get_debug_dir(), "torchdynamo") os.makedirs(log_path, exist_ok=True) @@ -896,7 +905,7 @@ def add_file_handler(): return exitstack -def setup_log_file(): +def setup_log_file() -> contextlib.ExitStack: exitstack = contextlib.ExitStack() if config.log_file_name is not None: log_file_handler = logging.FileHandler(config.log_file_name) @@ -908,12 +917,12 @@ def setup_log_file(): return exitstack -def gen_record_file_name(exc, code) -> str: +def gen_record_file_name(exc: Exception, code: CodeType) -> str: return f"{get_debug_dir()}/error_recordings/\ {code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" -def write_record_to_file(filename: str, exec_record) -> None: +def write_record_to_file(filename: str, exec_record: ExecutionRecord) -> None: try: if os.path.exists(filename): log.warning( @@ -939,7 +948,7 @@ def identity(x: T) -> T: return x -def hashable(x): +def hashable(x: Any) -> bool: try: hash(x) return True @@ -950,39 +959,39 @@ def hashable(x): return False -def nothing(*args, **kwargs): +def nothing(*args: Any, **kwargs: Any) -> None: pass class ExactWeakKeyDictionary: """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" - def __init__(self): - self.values = {} - self.refs = {} + def __init__(self) -> None: + self.values: dict[int, Any] = {} + self.refs: dict[int, weakref.ReferenceType[Any]] = {} - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: return self.values[id(key)] - def get(self, key, default=None): + def get(self, key: Any, default: Any = None) -> Any: return self.values.get(id(key), default) - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: return id(key) in self.values - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: idx = id(key) if idx not in self.refs: self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) self.values[idx] = value - def _remove_id(self, idx): + def _remove_id(self, idx: int) -> None: if idx in self.values: del self.values[idx] if idx in self.refs: del self.refs[idx] - def clear(self): + def clear(self) -> None: self.refs.clear() self.values.clear() @@ -1001,7 +1010,7 @@ def istype( def istype(obj: object, allowed_types: Iterable[type]) -> bool: ... 
-def istype(obj, allowed_types): +def istype(obj: object, allowed_types: Any) -> bool: """isinstance() without subclasses""" if isinstance(allowed_types, (tuple, list, set)): return type(obj) in allowed_types @@ -1021,7 +1030,7 @@ if sys.version_info >= (3, 12): ) -def is_typing(value): +def is_typing(value: Any) -> bool: # _Final catches most of typing classes: # - Any # - Callable @@ -1035,7 +1044,7 @@ def is_typing(value): return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] -def is_numpy_int_type(value): +def is_numpy_int_type(value: Any) -> bool: if not np: return False @@ -1054,7 +1063,7 @@ def is_numpy_int_type(value): ) -def is_numpy_float_type(value): +def is_numpy_float_type(value: Any) -> bool: if not np: return False @@ -1166,11 +1175,11 @@ def is_wrapper_or_member_descriptor( ) -def unwrap_if_wrapper(fn): +def unwrap_if_wrapper(fn: Any) -> Any: return unwrap_with_attr_name_if_wrapper(fn)[0] -def unwrap_with_attr_name_if_wrapper(fn): +def unwrap_with_attr_name_if_wrapper(fn: Any) -> tuple[Any, Optional[str]]: # TODO(anijain2305) - Investigate if we can get rid of this function # unpack @torch._dynamo.optimize()(fn) wrapped function if is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): @@ -1181,14 +1190,14 @@ def unwrap_with_attr_name_if_wrapper(fn): return fn, attr_name -def is_numpy_ndarray(value): +def is_numpy_ndarray(value: Any) -> TypeGuard[np.ndarray]: # type: ignore[type-arg] if not np: return False return istype(value, np.ndarray) -def istensor(obj): +def istensor(obj: Any) -> bool: """Check of obj is a tensor""" tensor_list: tuple[type, ...] = ( torch.Tensor, @@ -1199,27 +1208,27 @@ def istensor(obj): return istype(obj, tensor_list) -def is_lazy_module(mod): +def is_lazy_module(mod: Any) -> bool: return isinstance(mod, LazyModuleMixin) @functools.lru_cache(4096) -def print_once(*args): +def print_once(*args: Any) -> None: print(*args) -def make_cell(val=None): +def make_cell(val: Any = None) -> types.CellType: """Some black magic to create a cell object that usually only exists in a closure""" x = val - def f(): + def f() -> Any: return x assert f.__closure__ is not None and len(f.__closure__) == 1 return f.__closure__[0] -def proxy_args_kwargs(args, kwargs): +def proxy_args_kwargs(args: Any, kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: try: proxy_args = tuple(arg.as_proxy() for arg in args) proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} @@ -1350,7 +1359,7 @@ class CompilationMetrics: recompile_user_contexts: Optional[set[str]] = None @classmethod - def create(cls, metrics: dict[str, Any]): + def create(cls, metrics: dict[str, Any]) -> CompilationMetrics: """ Factory method to create a CompilationMetrics from a dict of fields. 
Includes the logic to add legacy fields and any pre-processing, e.g., @@ -1475,15 +1484,15 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: fail_user_frame_filename=c.fail_user_frame_filename, fail_user_frame_lineno=c.fail_user_frame_lineno, # Sets aren't JSON serializable - non_compliant_ops=list(c.non_compliant_ops) - if c.non_compliant_ops is not None - else None, - compliant_custom_ops=list(c.compliant_custom_ops) - if c.compliant_custom_ops is not None - else None, - restart_reasons=list(c.restart_reasons) - if c.restart_reasons is not None - else None, + non_compliant_ops=( + list(c.non_compliant_ops) if c.non_compliant_ops is not None else None + ), + compliant_custom_ops=( + list(c.compliant_custom_ops) if c.compliant_custom_ops is not None else None + ), + restart_reasons=( + list(c.restart_reasons) if c.restart_reasons is not None else None + ), dynamo_time_before_restart_s=c.dynamo_time_before_restart_s, has_guarded_code=c.has_guarded_code, dynamo_config=c.dynamo_config, @@ -1533,7 +1542,7 @@ def _scrubbed_inductor_config_for_logging() -> Optional[str]: # TypeSafeSerializer for json.dumps() # Skips complex types as values in config dict class TypeSafeSerializer(json.JSONEncoder): - def default(self, o): + def default(self, o: Any) -> Any: try: return super().default(o) except Exception: @@ -1574,7 +1583,7 @@ def record_compilation_metrics( metrics: dict[str, Any], exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], -): +) -> None: if torch._inductor.utils.should_use_remote_fx_graph_cache(): try: from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION @@ -1696,7 +1705,7 @@ class ChromiumEventLogger: stack = self.get_stack() return stack[0] if stack else None - def get_pt2_compile_substack(self): + def get_pt2_compile_substack(self) -> list[str]: """ A smaller subset of the main stack that gets used to log PT2 Compile Events internally. @@ -1712,7 +1721,7 @@ class ChromiumEventLogger: self.tls.event_data = {} return self.tls.event_data - def __init__(self): + def __init__(self) -> None: self.tls = threading.local() from . import config @@ -1727,7 +1736,7 @@ class ChromiumEventLogger: # TODO: log to init/id tlparse after I add support for it log.info("ChromiumEventLogger initialized with id %s", self.id_) - def try_add_event_data(self, event_name: str, **kwargs) -> None: + def try_add_event_data(self, event_name: str, **kwargs: Any) -> None: """ Same as add_event_data, but will silently not log if the event isn't in the stack. """ @@ -1738,7 +1747,7 @@ class ChromiumEventLogger: def add_event_data( self, event_name: str, - **kwargs, + **kwargs: Any, ) -> None: """ Adds additional metadata info to an in-progress event @@ -1755,7 +1764,7 @@ class ChromiumEventLogger: event_data[event_name] = {} event_data[event_name].update(kwargs) - def increment(self, event_name: str, key: str, value: int): + def increment(self, event_name: str, key: str, value: int) -> None: """ Increment an integer event data field by the given amount """ @@ -1778,7 +1787,7 @@ class ChromiumEventLogger: event_name: str, key: str, value: Any, - ): + ) -> None: """ Add a value to a set within a event_name's metadata if it exists """ @@ -1874,7 +1883,7 @@ class ChromiumEventLogger: event_metadata, ) - def pop_stack(stack): + def pop_stack(stack: list[str]) -> None: while event_name != stack[-1]: # If the event isn't the most recent one to end, pop # off the stack until it is. 
@@ -2035,14 +2044,14 @@ class CleanupHook: scope: dict[str, Any] name: str - def __call__(self, *args): + def __call__(self, *args: Any) -> None: # Make sure we're not shutting down if CleanupManager is not None: CleanupManager.count -= 1 del self.scope[self.name] @staticmethod - def create(scope, name, val): + def create(scope: dict[str, Any], name: str, val: Any) -> CleanupHook: assert name not in scope CleanupManager.count += 1 scope[name] = val @@ -2053,7 +2062,7 @@ class CleanupManager(ExactWeakKeyDictionary): count = 0 instance: ClassVar[CleanupManager] - def _remove_id(self, idx): + def _remove_id(self, idx: int) -> None: for hook in self.values[idx]: hook() super()._remove_id(idx) @@ -2062,7 +2071,7 @@ class CleanupManager(ExactWeakKeyDictionary): CleanupManager.instance = CleanupManager() -def clone_tensor(x): +def clone_tensor(x: torch.Tensor) -> torch.Tensor: """Clone the tensor and its gradient""" y = x.clone().requires_grad_(x.requires_grad) if x.is_leaf and x.grad is not None: @@ -2070,14 +2079,16 @@ def clone_tensor(x): return y -def clone_input(x, *, dtype=None): +def clone_input( + x: torch.Tensor, *, dtype: Optional[torch.dtype] = None +) -> torch.Tensor: """copy while preserving strides""" # TODO: this is questionable if is_fake(x): # this func fails on fake tensors in __torch_dispatch__ return x - def torch_clone(x): + def torch_clone(x: torch.Tensor) -> torch.Tensor: y = torch.clone(x) if x.is_leaf: y.requires_grad_(x.requires_grad) @@ -2154,7 +2165,7 @@ def clone_inputs( def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ... -def clone_inputs(example_inputs): +def clone_inputs(example_inputs: Any) -> Any: res: Union[dict[str, Any], list[Any]] if type(example_inputs) is dict: res = dict(example_inputs) @@ -2173,7 +2184,7 @@ def clone_inputs(example_inputs): return res -def skip_frame_if_in_functorch_mode(val: torch.Tensor): +def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None: try: val.data_ptr() # will throw for functorch tensors except RuntimeError as e: @@ -2187,7 +2198,7 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor): @contextmanager -def preserve_rng_state(): +def preserve_rng_state() -> Generator[None, None, None]: disable_functorch = torch._C._DisableFuncTorch disable_current_modes = torch.utils._python_dispatch._disable_current_modes with disable_current_modes(), disable_functorch(): @@ -2205,8 +2216,8 @@ def preserve_rng_state(): def is_jit_model( - model0, -): + model0: Any, +) -> bool: return isinstance( model0, ( @@ -2218,7 +2229,7 @@ def is_jit_model( ) -def torchscript(model, example_inputs, verbose=False): +def torchscript(model: Any, example_inputs: Any, verbose: bool = False) -> Any: if is_jit_model(model): # already done? 
return model @@ -2243,12 +2254,12 @@ def getfile(obj: Any) -> Optional[str]: return None -def is_namedtuple(obj): +def is_namedtuple(obj: Any) -> bool: """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" return is_namedtuple_cls(type(obj)) -def is_namedtuple_cls(cls): +def is_namedtuple_cls(cls: Any) -> bool: """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple""" try: if issubclass(cls, tuple): @@ -2279,7 +2290,7 @@ def is_namedtuple_cls(cls): @functools.lru_cache(1) -def namedtuple_fields(cls) -> tuple[str, ...]: +def namedtuple_fields(cls: type) -> tuple[str, ...]: """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" if cls is slice: return ("start", "stop", "step") @@ -2295,16 +2306,16 @@ def namedtuple_fields(cls) -> tuple[str, ...]: # frustrating ones e.g. torch.return_types.max assert cls.__module__ == "torch.return_types" - obj = cls(map(Marker, range(cls.n_fields))) + obj = cls(map(Marker, range(cls.n_fields))) # type: ignore[attr-defined] fields: dict[str, int] = {} for name in dir(obj): if name[0] != "_" and isinstance(getattr(obj, name), Marker): fields[name] = getattr(obj, name).index - assert len(fields) == cls.n_fields + assert len(fields) == cls.n_fields # type: ignore[attr-defined] return tuple(sorted(fields, key=fields.get)) # type: ignore[arg-type] -def checkpoint_params(gm): +def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]: with torch.no_grad(): rng_state = torch.clone(torch.random.get_rng_state()) if torch.cuda.is_available(): @@ -2314,7 +2325,7 @@ def checkpoint_params(gm): for param in itertools.chain(gm.parameters(), gm.buffers()) ] - def restore(): + def restore() -> None: with torch.no_grad(): torch.random.set_rng_state(rng_state) if torch.cuda.is_available(): @@ -2326,7 +2337,7 @@ def checkpoint_params(gm): return restore -def timed(model, example_inputs, times=1): +def timed(model: Any, example_inputs: Any, times: int = 1) -> tuple[Any, float]: if torch.cuda.is_available(): synchronize = torch.cuda.synchronize else: @@ -2343,12 +2354,12 @@ def timed(model, example_inputs, times=1): return result, t1 - t0 # type: ignore[possibly-undefined] -def check_is_cuda(gm, example_inputs): +def check_is_cuda(gm: torch.fx.GraphModule, example_inputs: Any) -> bool: return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) @lru_cache(32) -def rot_n_helper(n): +def rot_n_helper(n: int) -> Callable[..., Any]: assert n > 1 vars = [f"v{i}" for i in range(n)] rotated = reversed(vars[-1:] + vars[:-1]) @@ -2392,7 +2403,7 @@ if has_triton_package(): """ -def is_safe_constant(v): +def is_safe_constant(v: Any) -> bool: if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) return isinstance( @@ -2411,7 +2422,7 @@ def is_safe_constant(v): @functools.cache -def common_constants(): +def common_constants() -> set[int]: return { # We zero-one specialize shapes, so specialize these constants # too @@ -2426,7 +2437,7 @@ def is_torch_sym(value: Any) -> TypeGuard[Union[torch.SymBool, torch.SymInt]]: ) -def is_int_specialization_case(value, source): +def is_int_specialization_case(value: Any, source: Any) -> bool: from .source import is_from_defaults return not TracingContext.get().force_unspec_int_unbacked_size_like and ( @@ -2457,7 +2468,7 @@ def is_int_specialization_case(value, source): ) -def specialize_symnode(arg): +def specialize_symnode(arg: Any) -> Any: from .variables import ConstantVariable, 
LazyVariableTracker, SymNodeVariable # Guard and specialize @@ -2482,7 +2493,7 @@ def specialize_symnode(arg): return arg -def guard_if_dyn(arg): +def guard_if_dyn(arg: Any) -> Any: from .variables import ConstantVariable arg = specialize_symnode(arg) @@ -2493,11 +2504,11 @@ def guard_if_dyn(arg): return arg -def check_constant_args(args, kwargs): +def check_constant_args(args: Any, kwargs: Any) -> bool: return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) -def check_unspec_python_args(args, kwargs): +def check_unspec_python_args(args: Any, kwargs: Any) -> bool: from .variables.constant import ConstantVariable from .variables.tensor import UnspecializedPythonVariable @@ -2510,7 +2521,7 @@ def check_unspec_python_args(args, kwargs): return unspec_count > 0 -def check_unspec_or_constant_args(args, kwargs): +def check_unspec_or_constant_args(args: Any, kwargs: Any) -> bool: # A fused version of: # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs) from .variables.tensor import UnspecializedPythonVariable @@ -2521,7 +2532,7 @@ def check_unspec_or_constant_args(args, kwargs): return True -def check_numpy_ndarray_args(args, kwargs): +def check_numpy_ndarray_args(args: Any, kwargs: Any) -> bool: from .variables.tensor import NumpyNdarrayVariable return any( @@ -2557,13 +2568,13 @@ list_getitem = list.__getitem__ str_methods = {method for method in str.__dict__.values() if callable(method)} -def builtin_dict_keys(d): +def builtin_dict_keys(d: dict[Any, Any]) -> KeysView[Any]: # Avoids overridden keys method of the dictionary assert isinstance(d, dict) return dict.keys(d) -def get_items_from_dict(obj): +def get_items_from_dict(obj: dict[Any, Any]) -> Any: # Get items without calling the user defined __getitem__ or keys method. assert isinstance(obj, dict) if istype(obj, (dict, OrderedDict)): @@ -2574,29 +2585,29 @@ def get_items_from_dict(obj): return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] -def nn_module_new(cls): +def nn_module_new(cls: Any) -> Any: obj = object_new(cls) torch.nn.Module.__init__(obj) return obj -def product(it): +def product(it: Iterable[Any]) -> Any: return functools.reduce(operator.mul, it, 1) -def tuple_iterator_getitem(it, index): +def tuple_iterator_getitem(it: Any, index: int) -> Any: _, (obj,), start = it.__reduce__() return obj[start + index] -def dataclass_fields(cls): +def dataclass_fields(cls: Any) -> Any: return torch._dynamo.disable(dataclasses.fields)(cls) iter_next = next -def normalize_range_iter(range_iter) -> tuple[int, int, int]: +def normalize_range_iter(range_iter: Any) -> tuple[int, int, int]: _, (range_obj,), maybe_idx = range_iter.__reduce__() # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been # already incremented by the current index. 
@@ -2606,14 +2617,14 @@ def normalize_range_iter(range_iter) -> tuple[int, int, int]: return (start, stop, step) -def to_subclass(t, cls): +def to_subclass(t: Any, cls: type) -> Any: return t.as_subclass(cls) dict_getitem = dict.__getitem__ -def dict_keys_getitem(d, n): +def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: # Call dict(d) to prevent calling overridden __iter__/keys dict_class = dict if isinstance(d, OrderedDict): @@ -2621,12 +2632,12 @@ def dict_keys_getitem(d, n): return next(itertools.islice(dict_class.keys(d), n, n + 1)) -def set_getitem(s, n): +def set_getitem(s: set[Any], n: int) -> Any: # Set ordering might not be stable return list(s)[n] -def enum_repr(value, local): +def enum_repr(value: Any, local: bool) -> str: # enum class can override __str__ method. Use __class__ and name attribute # to extract the class name and key name. name = value.__class__.__name__ @@ -2636,7 +2647,7 @@ def enum_repr(value, local): return local_name -def set_example_value(node, example_value): +def set_example_value(node: torch.fx.Node, example_value: Any) -> None: # NB: example_value is a bit of a misnomer, because this is always a fake # tensor of some sort. Furthermore, these example values serve as the # runtime state of Dynamo tracing, which means if metadata mutation @@ -2656,7 +2667,7 @@ def set_example_value(node, example_value): node.meta["unbacked_bindings"] = symbol_to_path -def _get_fake_tensor(vt): +def _get_fake_tensor(vt: VariableTracker) -> Any: fake_tensor = vt.as_proxy().node.meta.get("example_value") if not is_fake(fake_tensor): from . import graph_break_hints @@ -2676,7 +2687,7 @@ def slice_length(s: slice, seq_len: int) -> int: return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step) -def raise_args_mismatch(tx, name): +def raise_args_mismatch(tx: InstructionTranslatorBase, name: str) -> None: from torch._dynamo.exc import raise_observed_exception from torch._dynamo.variables import ConstantVariable @@ -2687,13 +2698,13 @@ def raise_args_mismatch(tx, name): ) -def iter_contains(items, search, tx, check_tensor_identity=False): - from .variables import ( - BuiltinVariable, - ConstantVariable, - TensorVariable, - VariableTracker, - ) +def iter_contains( + items: Any, + search: Any, + tx: InstructionTranslator, + check_tensor_identity: bool = False, +) -> Any: + from .variables import BuiltinVariable, ConstantVariable, TensorVariable if search.is_python_constant(): found_const = any( @@ -2735,11 +2746,11 @@ def key_is_id( return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) -def key_to_id(value): +def key_to_id(value: Any) -> list[Any]: return [id(k) if key_is_id(k) else k for k in value.keys()] -def const_repr(x, *, local) -> str: +def const_repr(x: Any, *, local: Any) -> str: from .trace_rules import is_builtin_callable if isinstance(x, (list, tuple)): @@ -2760,7 +2771,7 @@ def const_repr(x, *, local) -> str: return x.__name__ elif isinstance(x, type): - def fullname(o): + def fullname(o: Any) -> str: klass = o.__class__ module = klass.__module__ if module == "builtins": @@ -2772,7 +2783,7 @@ def const_repr(x, *, local) -> str: return f"{x!r}" -def dict_keys_repr(const_keys, *, local) -> str: +def dict_keys_repr(const_keys: Any, *, local: Any) -> str: keys_str = ",".join(const_repr(s, local=local) for s in const_keys) return "[" + keys_str + "]" @@ -2783,7 +2794,7 @@ GLOBAL_KEY_PREFIX = "__dict_key" from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 -def get_safe_global_name(tx, root, obj): +def 
get_safe_global_name(tx: InstructionTranslatorBase, root: str, obj: Any) -> str: # The global_mangled_class_name should be different for different # invocations of torch.compile. Otherwise, we can run into a situation # where multiple torch.compile invocations reuse the same global name, @@ -2793,14 +2804,16 @@ def get_safe_global_name(tx, root, obj): return f"{root}_{id(obj)}_c{tx.output.compile_id}" -def is_in(item: Any, *containers) -> bool: +def is_in(item: str, *containers: Any) -> bool: for container in containers: if item in container: return True return False -def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str: +def get_unique_name_wrt( + prefix: str, *containers: Any, requires_suffix: bool = False +) -> str: """ Return a name that starts with `prefix` and is not in any of the `containers` (e.g., map, set). @@ -2816,7 +2829,7 @@ def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str: raise AssertionError("unreachable") -def wrap_fake_exception(fn): +def wrap_fake_exception(fn: Callable[[], Any]) -> Any: try: return fn() except UnsupportedFakeTensorException as e: @@ -2833,12 +2846,14 @@ def wrap_fake_exception(fn): ) -def deepcopy_to_fake_tensor(obj, fake_mode): +def deepcopy_to_fake_tensor( + obj: Any, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode +) -> Any: with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): return wrap_fake_exception(lambda: copy.deepcopy(obj)) -def rmse(ref, res): +def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor: """ Calculate root mean squared error """ @@ -2846,19 +2861,19 @@ def rmse(ref, res): def same( - ref, - res, - fp64_ref=None, - cos_similarity=False, - tol=1e-4, - equal_nan=False, - exact_dtype=True, - relax_numpy_equality=False, - ignore_non_fp=False, - log_error=log.error, - use_larger_multiplier_for_smaller_tensor=False, + ref: Any, + res: Any, + fp64_ref: Any = None, + cos_similarity: bool = False, + tol: float = 1e-4, + equal_nan: bool = False, + exact_dtype: bool = True, + relax_numpy_equality: bool = False, + ignore_non_fp: bool = False, + log_error: Callable[..., None] = log.error, + use_larger_multiplier_for_smaller_tensor: bool = False, force_max_multiplier: bool = False, -): +) -> bool: """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref @@ -2939,7 +2954,7 @@ def same( assert not isinstance(ref, torch._subclasses.FakeTensor) assert not isinstance(res, torch._subclasses.FakeTensor) - def to_tensor(t): + def to_tensor(t: Any) -> Any: return t if isinstance(t, torch.Tensor) else torch.tensor(t) ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) @@ -2978,7 +2993,7 @@ def same( score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) if score < 0.99: log.warning("Similarity score=%s", score.detach().cpu().item()) - return score >= 0.99 + return bool(score >= 0.99) else: if not exact_dtype: ref = ref.to(res.dtype) @@ -3018,7 +3033,7 @@ def same( res_error = rmse(fp64_ref, res).item() - def get_multiplier(): + def get_multiplier() -> float: # In some particular cases, we expect high difference in results. # At the moment one of this cases is inductor freezing bfloat16 convolution const folding. # In case of it the res_error is at least one order of magnitude higher. 
@@ -3149,13 +3164,13 @@ def same( raise RuntimeError(f"unsupported type: {type(ref).__name__}") -def format_func_info(code): +def format_func_info(code: CodeType) -> str: short_filename = code.co_filename.split("/")[-1] return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" @contextlib.contextmanager -def disable_cache_limit(): +def disable_cache_limit() -> Generator[None, None, None]: prior = config.recompile_limit config.recompile_limit = sys.maxsize prior_acc_limit = config.accumulated_recompile_limit @@ -3184,7 +3199,7 @@ seen_code_map = ExactWeakKeyDictionary() # return same dir unless user changes config between calls @functools.cache -def _get_debug_dir(root_dir): +def _get_debug_dir(root_dir: str) -> str: dir_name = ( "run_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") @@ -3195,12 +3210,12 @@ def _get_debug_dir(root_dir): return os.path.join(root_dir, dir_name) -def get_debug_dir(): +def get_debug_dir() -> str: debug_root = config.debug_dir_root return _get_debug_dir(debug_root) -def extract_fake_example_value(node, required=True): +def extract_fake_example_value(node: torch.fx.Node, required: bool = True) -> Any: if "example_value" in node.meta and is_fake(node.meta["example_value"]): return node.meta["example_value"] elif required: @@ -3218,13 +3233,15 @@ def extract_fake_example_value(node, required=True): return None -def ensure_graph_fake(e, tx): +def ensure_graph_fake(e: Any, tx: InstructionTranslatorBase) -> Any: assert maybe_get_fake_mode(e) is tx.fake_mode return e -def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): - def visit(n: torch.fx.Node): +def get_fake_values_from_nodes( + tx: InstructionTranslatorBase, nodes: Any, allow_non_graph_fake: bool +) -> Any: + def visit(n: torch.fx.Node) -> Any: if n.op == "call_function" and "example_value" not in n.meta: # fake tensor validity is checked inside get_fake_value using # ensure_graph_fake @@ -3232,7 +3249,7 @@ def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): elif n.op == "get_attr" and "example_value" not in n.meta: assert n.target in tx.output.nn_modules - gm = tx.output.nn_modules[n.target] + gm = tx.output.nn_modules[n.target] # type: ignore[index] assert isinstance(gm, torch.fx.GraphModule) return gm @@ -3244,7 +3261,11 @@ def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): return torch.fx.node.map_arg(nodes, visit) -def get_fake_value(node, tx, allow_non_graph_fake=False): +def get_fake_value( + node: torch.fx.Node, + tx: InstructionTranslatorBase, + allow_non_graph_fake: bool = False, +) -> Any: """ Run the computation represented by `node` using fake tensors and return the result. @@ -3293,7 +3314,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) if op == "call_module": - nnmodule = tx.output.nn_modules[node.target] + nnmodule = tx.output.nn_modules[node.target] # type: ignore[index] if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): # In the case of a lazy module, we want to run @@ -3310,9 +3331,11 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): ): # We need to specialize symfloats for now. Eventually we should do a tensorify pass in dynamo. 
args = tuple( - float(arg) - if isinstance(arg, torch.SymFloat) and arg.node.hint is not None - else arg + ( + float(arg) + if isinstance(arg, torch.SymFloat) and arg.node.hint is not None + else arg + ) for arg in args ) @@ -3379,7 +3402,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): elif isinstance( cause, torch._subclasses.fake_tensor.UnsupportedOperatorException ): - op = cause.func + op = cause.func # type: ignore[assignment] import_suggestion = "" if isinstance(op, torch._ops.OpOverload): maybe_pystub = torch._C._dispatch_pystub( @@ -3443,12 +3466,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): _current_node = threading.local() -def get_current_node(): +def get_current_node() -> Optional[torch.fx.Node]: return getattr(_current_node, "value", None) @contextmanager -def set_current_node(node): +def set_current_node(node: torch.fx.Node) -> Generator[None, None, None]: old = get_current_node() _current_node.value = node try: @@ -3457,7 +3480,9 @@ def set_current_node(node): _current_node.value = old -def run_node(tracer, node, args, kwargs, nnmodule): +def run_node( + tracer: Any, node: torch.fx.Node, args: Any, kwargs: Any, nnmodule: Any +) -> Any: """ Runs a given node, with the given args and kwargs. @@ -3476,7 +3501,7 @@ def run_node(tracer, node, args, kwargs, nnmodule): with set_current_node(node): - def make_error_message(e): + def make_error_message(e: Any) -> str: return ( f"Dynamo failed to run FX node with fake tensors: {op} {node.target}(*{args}, **{kwargs}): got " + repr(e) @@ -3486,9 +3511,9 @@ def run_node(tracer, node, args, kwargs, nnmodule): try: if op == "call_function": - return node.target(*args, **kwargs) + return node.target(*args, **kwargs) # type: ignore[operator] elif op == "call_method": - if not hasattr(args[0], node.target): + if not hasattr(args[0], node.target): # type: ignore[arg-type] from .exc import unimplemented_v2 unimplemented_v2( @@ -3497,7 +3522,7 @@ def run_node(tracer, node, args, kwargs, nnmodule): explanation=make_error_message("attribute not defined"), hints=[], ) - return getattr(args[0], node.target)(*args[1:], **kwargs) + return getattr(args[0], node.target)(*args[1:], **kwargs) # type: ignore[arg-type] elif op == "call_module": assert nnmodule is not None return nnmodule(*args, **kwargs) @@ -3534,7 +3559,7 @@ def run_node(tracer, node, args, kwargs, nnmodule): raise AssertionError(op) -def get_real_value(node, tracer): +def get_real_value(node: torch.fx.Node, tracer: Any) -> Any: """ Run the actual computation represented by `node` and return the result. This will execute any dependent nodes in the graph as well. @@ -3573,10 +3598,10 @@ def get_real_value(node, tracer): return real_value -def assert_no_fake_params_or_buffers(gm): +def assert_no_fake_params_or_buffers(gm: torch.fx.GraphModule) -> None: from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake - def stack_or_hint(t): + def stack_or_hint(t: Any) -> str: if FakeTensorConfig.debug: import traceback @@ -3594,21 +3619,21 @@ def assert_no_fake_params_or_buffers(gm): ) -def fqn(obj: Any): +def fqn(obj: Any) -> str: """ Returns the fully qualified name of the object. 
""" return f"{obj.__module__}.{obj.__qualname__}" -def ifdynstaticdefault(count1, count2): +def ifdynstaticdefault(count1: Any, count2: Any) -> Any: if torch._dynamo.config.assume_static_by_default: return count1 else: return count2 -def import_submodule(mod: types.ModuleType): +def import_submodule(mod: types.ModuleType) -> None: """ Ensure all the files in a given submodule are imported """ @@ -3617,17 +3642,17 @@ def import_submodule(mod: types.ModuleType): importlib.import_module(f"{mod.__name__}.{filename[:-3]}") -def object_has_getattribute(value: Any): +def object_has_getattribute(value: Any) -> bool: return class_has_getattribute(type(value)) -def object_setattr_ignore_descriptor(obj, name, value): +def object_setattr_ignore_descriptor(obj: Any, name: str, value: Any) -> None: # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1286-L1335 d = object.__getattribute__(obj, "__dict__") d[name] = value -def class_has_getattribute(cls: type): +def class_has_getattribute(cls: type) -> bool: try: if isinstance( inspect.getattr_static(cls, "__getattribute__"), @@ -3639,7 +3664,9 @@ def class_has_getattribute(cls: type): return False -def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): +def get_custom_getattr( + value: Any, ignore_nn_module_getattr: bool = False +) -> Optional[Any]: try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: @@ -3656,7 +3683,7 @@ class TensorStaticReason(enum.Enum): NN_MODULE_PROPERTY = 5 -def tensor_static_reason_to_message(reason: TensorStaticReason): +def tensor_static_reason_to_message(reason: TensorStaticReason) -> str: if reason == TensorStaticReason.PARAMETER: return "mark_dynamic on parameter, parameters are always static today." if reason == TensorStaticReason.NOT_TENSOR: @@ -3700,8 +3727,8 @@ def tensor_always_has_static_shape( return False, None -def lazy_format_graph_tabular(fn_name, gm): - def inner(): +def lazy_format_graph_tabular(fn_name: str, gm: torch.fx.GraphModule) -> Any: + def inner() -> str: try: from tabulate import tabulate # TODO: Check that this is installed except ImportError: @@ -3721,7 +3748,9 @@ def lazy_format_graph_tabular(fn_name, gm): return LazyString(inner) -def format_bytecode(prefix, name, filename, line_no, code): +def format_bytecode( + prefix: str, name: str, filename: str, line_no: int, code: Any +) -> str: return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" @@ -3736,20 +3765,21 @@ state_dict_hook_names = [ all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names -def nn_module_has_global_hooks(): +def nn_module_has_global_hooks() -> bool: # This is limited to backward hooks for now because NNModuleVariable # supports fwd hooks underneath. - return len(torch.nn.modules.module._global_backward_hooks) or len( - torch.nn.modules.module._global_backward_pre_hooks + return bool( + len(torch.nn.modules.module._global_backward_hooks) + or len(torch.nn.modules.module._global_backward_pre_hooks) ) def nn_module_get_all_hooks( - mod, - check_forward_hooks=False, - check_backward_hooks=False, - check_state_dict_hooks=False, -): + mod: torch.nn.Module, + check_forward_hooks: bool = False, + check_backward_hooks: bool = False, + check_state_dict_hooks: bool = False, +) -> list[Any]: """ Sometimes its useful to differentiate between types of hooks such as forward/backward/pre hooks executed during module.__call__, and state_dict hooks which are executed separately. 
@@ -3778,11 +3808,11 @@ def nn_module_get_all_hooks( def nnmodule_has_hooks( - mod, - check_forward_hooks=False, - check_backward_hooks=False, - check_state_dict_hooks=False, -): + mod: torch.nn.Module, + check_forward_hooks: bool = False, + check_backward_hooks: bool = False, + check_state_dict_hooks: bool = False, +) -> bool: """ Helper function to check if a module has any hooks attached to it. """ @@ -3795,7 +3825,7 @@ def nnmodule_has_hooks( return bool(hooks) -def to_numpy_helper(value): +def to_numpy_helper(value: Any) -> Any: """Convert tensor and tnp.ndarray to numpy.ndarray.""" if is_fake(value): return value @@ -3809,7 +3839,7 @@ def to_numpy_helper(value): return value -def numpy_to_tensor(value): +def numpy_to_tensor(value: Any) -> Any: """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert.""" assert np is not None if isinstance(value, np.ndarray): @@ -3823,19 +3853,19 @@ def numpy_to_tensor(value): class numpy_to_tensor_wrapper: - def __init__(self, f): + def __init__(self, f: Any) -> None: self.f = f self.__name__ = "wrapped_" + self.f.__name__ def __repr__(self) -> str: return f">" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: out = self.f(*args, **kwargs) return numpy_to_tensor(out) -def numpy_attr_wrapper(obj, name): +def numpy_attr_wrapper(obj: Any, name: str) -> Any: if isinstance(obj, tnp.ndarray): out = getattr(obj, name) return numpy_to_tensor(out) @@ -3847,14 +3877,14 @@ def numpy_attr_wrapper(obj, name): class numpy_method_wrapper: """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor.""" - def __init__(self, method: str): + def __init__(self, method: str) -> None: self.method = method self.__name__ = "wrapped_" + self.method def __repr__(self) -> str: return f">" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: obj = args[0] if isinstance(obj, torch.Tensor): obj = tnp.ndarray(obj) @@ -3866,14 +3896,14 @@ class numpy_method_wrapper: class numpy_operator_wrapper: """Implements dunder methods for tnp.ndarray via functions from the operator library""" - def __init__(self, op: Callable[..., Any]): + def __init__(self, op: Callable[..., Any]) -> None: self.op = op self.__name__ = f"wrapped_{op.__name__}" def __repr__(self) -> str: return f">" - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: assert not kwargs args = ( @@ -3883,7 +3913,7 @@ class numpy_operator_wrapper: return numpy_to_tensor(out) -def defake(x): +def defake(x: Any) -> Any: if not isinstance(x, FakeTensor): return x size: torch._prims_common.ShapeType @@ -3915,24 +3945,26 @@ def defake(x): return y -def _disable_side_effect_safety_checks_for_current_subtracer(fn, *args, **kwargs): +def _disable_side_effect_safety_checks_for_current_subtracer( + fn: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs +) -> Any: return fn(*args, **kwargs) -def is_utils_checkpoint(obj): +def is_utils_checkpoint(obj: Any) -> bool: # Lazy import to avoid circular dependencies import torch.utils.checkpoint return obj is torch.utils.checkpoint.checkpoint -def is_invoke_subgraph(obj): +def is_invoke_subgraph(obj: Any) -> bool: from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_placeholder return obj is invoke_subgraph_placeholder -def build_invoke_subgraph_variable(**options): +def build_invoke_subgraph_variable(**options: Any) -> Any: from .variables.higher_order_ops 
import TorchHigherOrderOperatorVariable return TorchHigherOrderOperatorVariable.make( @@ -3941,7 +3973,7 @@ def build_invoke_subgraph_variable(**options): ) -def build_checkpoint_variable(**options): +def build_checkpoint_variable(**options: Any) -> Any: import torch._higher_order_ops.wrap as higher_order_ops from .variables.higher_order_ops import TorchHigherOrderOperatorVariable @@ -3960,7 +3992,7 @@ def build_checkpoint_variable(**options): ) -def is_compile_supported(device_type): +def is_compile_supported(device_type: DeviceLikeType) -> Any: from .eval_frame import is_dynamo_supported type = torch.device(device_type).type @@ -4026,12 +4058,12 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: lines = segment.split("\n") # get character index given byte offset - def normalize(lineno, offset): + def normalize(lineno: int, offset: int) -> int: return _fix_offset(lines[lineno], offset) # Gets the next valid character index in `lines`, if # the current location is not valid. Handles empty lines. - def next_valid_char(lineno, col): + def next_valid_char(lineno: int, col: int) -> tuple[int, int]: while lineno < len(lines) and col >= len(lines[lineno]): col = 0 lineno += 1 @@ -4039,14 +4071,14 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: return lineno, col # Get the next valid character index in `lines`. - def increment(lineno, col): + def increment(lineno: int, col: int) -> tuple[int, int]: col += 1 lineno, col = next_valid_char(lineno, col) assert lineno < len(lines) and col < len(lines[lineno]) return lineno, col # Get the next valid character at least on the next line - def nextline(lineno, col): + def nextline(lineno: int, col: int) -> tuple[int, int]: col = 0 lineno += 1 lineno, col = next_valid_char(lineno, col) @@ -4063,6 +4095,7 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: # -2 since end_lineno is 1-indexed and because we added an extra # bracket to `segment` when calling ast.parse cur_lineno = cast(int, expr.left.end_lineno) - 2 + assert expr.left.end_col_offset is not None cur_col = normalize(cur_lineno, expr.left.end_col_offset) cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) @@ -4095,12 +4128,14 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: # subscript^^^^^^^^^^^^^^^^^^^^ # find left bracket (first '[' after value) left_lineno = cast(int, expr.value.end_lineno) - 2 + assert expr.value.end_col_offset is not None left_col = normalize(left_lineno, expr.value.end_col_offset) left_lineno, left_col = next_valid_char(left_lineno, left_col) while lines[left_lineno][left_col] != "[": left_lineno, left_col = increment(left_lineno, left_col) # find right bracket (final character of expression) right_lineno = cast(int, expr.end_lineno) - 2 + assert expr.end_col_offset is not None right_col = normalize(right_lineno, expr.end_col_offset) return _Anchors(left_lineno, left_col, right_lineno, right_col) elif isinstance(expr, ast.Call): @@ -4109,12 +4144,14 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: # call^^^^^^^^^^^^^^^^^^^^^^^^ # find left bracket (first '(' after func) left_lineno = cast(int, expr.func.end_lineno) - 2 + assert expr.func.end_col_offset is not None left_col = normalize(left_lineno, expr.func.end_col_offset) left_lineno, left_col = next_valid_char(left_lineno, left_col) while lines[left_lineno][left_col] != "(": left_lineno, left_col = increment(left_lineno, left_col) # find right bracket (final character of expression) right_lineno = cast(int, 
expr.end_lineno) - 2 + assert expr.end_col_offset is not None right_col = normalize(right_lineno, expr.end_col_offset) return _Anchors(left_lineno, left_col, right_lineno, right_col) @@ -4253,14 +4290,14 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s return result -def get_static_address_type(t): +def get_static_address_type(t: Any) -> Any: if isinstance(t, torch.Tensor): return getattr(t, "_dynamo_static_input_type", None) return None -def is_rng_state_getter_or_setter(value): +def is_rng_state_getter_or_setter(value: Any) -> bool: getters = ( # The following two functions are not identical, so don't remove anyone! torch._C.Generator.get_state, @@ -4277,7 +4314,7 @@ def is_rng_state_getter_or_setter(value): return value in (*setters, *getters) -def is_tensor_base_attr_getter(value): +def is_tensor_base_attr_getter(value: Any) -> bool: return ( isinstance(value, types.MethodWrapperType) and value.__name__ == "__get__" @@ -4285,7 +4322,7 @@ def is_tensor_base_attr_getter(value): ) -def is_tensor_getset_descriptor(name): +def is_tensor_getset_descriptor(name: str) -> bool: try: attr = inspect.getattr_static(torch.Tensor, name) return type(attr) is types.GetSetDescriptorType @@ -4293,11 +4330,11 @@ def is_tensor_getset_descriptor(name): return False -def is_torch_function_object(value): +def is_torch_function_object(value: Any) -> bool: return hasattr(value, "__torch_function__") -def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool: +def has_torch_function(vt: VariableTracker) -> bool: # This emulates # https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/disable_torch_function.cpp#L315-L323 from torch._dynamo.variables import UserDefinedObjectVariable @@ -4327,7 +4364,9 @@ def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool # see note [Tensor Fakification and Symbol Caching] -def to_fake_tensor(t, fake_mode): +def to_fake_tensor( + t: torch.Tensor, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode +) -> Any: symbolic_context = None source = None if tracing_context := torch._guards.TracingContext.try_get(): @@ -4341,7 +4380,7 @@ def to_fake_tensor(t, fake_mode): # NB: this works for both classes and instances -def is_frozen_dataclass(value): +def is_frozen_dataclass(value: Any) -> bool: return ( not object_has_getattribute(value) and not class_has_getattribute(value) @@ -4352,7 +4391,7 @@ def is_frozen_dataclass(value): ) -def get_first_attr(obj, *attrs): +def get_first_attr(obj: Any, *attrs: str) -> Any: """ Return the first available attribute or throw an exception if none is present. 
""" @@ -4364,13 +4403,15 @@ def get_first_attr(obj, *attrs): @contextlib.contextmanager -def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True): +def maybe_enable_compiled_autograd( + should_enable: bool, fullgraph: bool = True, dynamic: bool = True +) -> Generator[Any, None, None]: if not should_enable: yield else: - def compiler_fn(gm): - def inner_compiler(gm_, example_inputs_): + def compiler_fn(gm: Any) -> Any: + def inner_compiler(gm_: Any, example_inputs_: Any) -> Any: torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1 return torch._inductor.compile(gm_, example_inputs_) @@ -4382,7 +4423,7 @@ def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True): yield ctx -def invalid_removeable_handle(): +def invalid_removeable_handle() -> RemovableHandle: # need a subclass so weakref works class Invalid(dict): # type: ignore[type-arg] pass @@ -4394,7 +4435,7 @@ def invalid_removeable_handle(): # Attribute changes to the original object/proxy will be reflected in the other. # This is useful for cases where we want a keep-alive reference to a module without increasing # its reference count. -def nn_module_proxy(mod): +def nn_module_proxy(mod: Any) -> Any: if not isinstance(mod, torch.nn.Module): return mod if isinstance(mod, torch.fx.GraphModule): @@ -4406,17 +4447,21 @@ def nn_module_proxy(mod): class GmWrapper(torch.nn.Module): - def __init__(self, gm, unflatten_fn): + def __init__( + self, gm: torch.fx.GraphModule, unflatten_fn: Callable[[list[Any]], Any] + ) -> None: super().__init__() self.gm = gm self.unflatten_fn = unflatten_fn - def forward(self, *args): + def forward(self, *args: Any) -> Any: args: list[Any] = list(args) return self.gm(*self.unflatten_fn(args)) -def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): +def flatten_graph_inputs( + gm: torch.fx.GraphModule, inputs: Any, compile_gm: Callable[[Any, Any], Any] +) -> Callable[..., Any]: """ Mutate inputs so that they are flat and wrap gm such that it accepts those inputs. 
This is needed for graphs that take @@ -4435,10 +4480,10 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): assert isinstance(inputs[0], list) boxed_inputs_count = len(inputs[0]) - def flatten_fn(args): + def flatten_fn(args: Any) -> Any: return args[0] + list(args[1:]) - def unflatten_fn(flat_args): + def unflatten_fn(flat_args: Any) -> Any: return (flat_args[:boxed_inputs_count], *flat_args[boxed_inputs_count:]) compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flatten_fn(inputs)) @@ -4450,7 +4495,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): # note this doesn't check the spec, assuming it is the same flatten_fn = pytree.arg_tree_leaves - def wrapper(*args): + def wrapper(*args: Any) -> Any: flat_args = flatten_fn(args) # flat_args is a new list, so we need to clear references from the old list @@ -4463,18 +4508,18 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): return wrapper -def get_locals_to_steal(maybe_gm): +def get_locals_to_steal(maybe_gm: Any) -> list[Any]: if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"): return [] return maybe_gm.meta.get("locals_to_steal", []) -def set_locals_to_steal(gm, locals_to_steal): +def set_locals_to_steal(gm: torch.fx.GraphModule, locals_to_steal: list[Any]) -> None: gm.meta["locals_to_steal"] = locals_to_steal class Lit: - def __init__(self, s): + def __init__(self, s: str) -> None: self.s = s def __repr__(self) -> str: @@ -4484,7 +4529,7 @@ class Lit: warn_once_cache: set[str] = set() -def warn_once(msg, stacklevel=1): +def warn_once(msg: str, stacklevel: int = 1) -> None: # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. # https://github.com/pytorch/pytorch/issues/128427. 
# warn_once is a workaround: if the msg has been warned on before, then we will not @@ -4496,14 +4541,14 @@ def warn_once(msg, stacklevel=1): warnings.warn(msg, stacklevel=stacklevel + 1) -def strip_color_from_string(text): +def strip_color_from_string(text: str) -> str: # This regular expression matches ANSI escape codes ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") return ansi_escape.sub("", text) @contextlib.contextmanager -def _disable_saved_tensors_hooks_during_tracing(): +def _disable_saved_tensors_hooks_during_tracing() -> Generator[None, None, None]: # See NOTE: [Deferring tensor pack/unpack hooks until runtime] try: prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True) @@ -4512,22 +4557,22 @@ def _disable_saved_tensors_hooks_during_tracing(): torch._C._autograd._saved_tensors_hooks_set_tracing(prior) -def is_parameter_freezing(): +def is_parameter_freezing() -> bool: return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(): +def get_torch_function_mode_stack() -> list[Any]: return [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] -def get_torch_function_mode_stack_at(ind): +def get_torch_function_mode_stack_at(ind: int) -> Any: assert ind < _len_torch_function_stack() and ind >= 0 return torch._C._get_function_stack_at(ind) -def set_torch_function_mode_stack(stack): +def set_torch_function_mode_stack(stack: list[Any]) -> None: for _ in range(_len_torch_function_stack()): _pop_torch_function_stack() @@ -4535,17 +4580,17 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) -def clear_torch_function_mode_stack(): +def clear_torch_function_mode_stack() -> None: for _ in range(_len_torch_function_stack()): _pop_torch_function_stack() # call from C dynamo in order to inspect values in pdb -def _breakpoint_for_c_dynamo(*args): +def _breakpoint_for_c_dynamo(*args: Any) -> None: breakpoint() -def verify_guard_fn_signature(value): +def verify_guard_fn_signature(value: Any) -> None: fn = value.__metadata_guard__ sig = inspect.signature(fn) if len(sig.parameters) != 2: @@ -4562,7 +4607,7 @@ def verify_guard_fn_signature(value): ) -def does_not_override_dict_iter_methods(user_cls): +def does_not_override_dict_iter_methods(user_cls: Any) -> bool: return ( user_cls.items in (dict.items, OrderedDict.items) and user_cls.values in (dict.values, OrderedDict.values) @@ -4575,23 +4620,23 @@ def does_not_override_dict_iter_methods(user_cls): # __torch_function__ calls triggered on tensor properties in the pre graph # bytecode. @torch._disable_dynamo -def call_size(x, i): +def call_size(x: Any, i: int) -> int: return x.size(i) @torch._disable_dynamo -def call_stride(x, i): +def call_stride(x: Any, i: int) -> int: return x.stride(i) @torch._disable_dynamo -def call_storage_offset(x): +def call_storage_offset(x: Any) -> int: return x.storage_offset() # Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. # To avoid ref cycles, it's important that no tensors are present here, so leave those out. 
-def _extract_tensor_dict(t): +def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]: KEYS_TO_COPY = [ "_dynamo_static_input_type", "tag", @@ -4610,13 +4655,13 @@ def _extract_tensor_dict(t): user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {} -def get_user_object_from_id(obj_id): +def get_user_object_from_id(obj_id: int) -> Any: obj = user_obj_id_to_weakref[obj_id]() assert obj is not None, "User object is no longer alive" return obj -def store_user_object_weakref(obj): +def store_user_object_weakref(obj: object) -> None: obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) @@ -4649,7 +4694,7 @@ class CompileTimeInstructionCounter: @classmethod @contextmanager - def record(cls): + def record(cls) -> Generator[None, None, None]: try: if config.record_compile_time_instruction_count: cls.start() @@ -4659,7 +4704,7 @@ class CompileTimeInstructionCounter: cls.end() -def set_feature_use(feature: str, usage: bool): +def set_feature_use(feature: str, usage: bool) -> None: """ Records whether we are using a feature Generally a feature is a JK. @@ -4677,7 +4722,7 @@ _ddp_optimization_mode: tuple[str, ...] = ( ) -def get_optimize_ddp_mode(): +def get_optimize_ddp_mode() -> str: optimize_ddp = config.optimize_ddp if isinstance(optimize_ddp, bool): mode = "ddp_optimizer" if optimize_ddp else "no_optimization"