[Dynamo][Better Engineering] Add typing annotations to guard and source (#158397) (#159491)

Summary:
X-link: https://github.com/pytorch/executorch/pull/12986

As part of better engineering week, we would like to improve our type support to improve the dev experience in dynamo

This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py`

Running
```
mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  1227 | 2208 | 55.57% | 207 | 362 | 57.18% |
| This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% |
| Delta    | +990 | +9 | +44.43% | +155 | 0 | +42.82% |

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Rollback Plan:

Reviewed By: JacobSzwejbka, yangw-dev

Differential Revision: D79199389

Pulled By: Lucaskabela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159491
Approved by: https://github.com/anijain2305, https://github.com/yangw-dev
This commit is contained in:
Lucas Kabela
2025-07-30 22:57:46 +00:00
committed by PyTorch MergeBot
parent 1293405c8d
commit 2b1ae29960
22 changed files with 335 additions and 278 deletions

View File

@ -1848,7 +1848,7 @@ def export(
ignore_fresh_unbacked = null_context()
assert ambient_fake_mode is not None
if shape_env := ambient_fake_mode.shape_env:
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() # type: ignore[assignment]
with (
ambient_fake_mode,
@ -1900,7 +1900,9 @@ def export(
fakify_with_ambient, graph_inputs
)
graph_captured_result = torch.func.functional_call(
graph, fake_params_buffers, fake_graph_inputs
graph,
fake_params_buffers, # type: ignore[arg-type]
fake_graph_inputs, # type: ignore[arg-type]
)
return graph_captured_result

View File

@ -3685,7 +3685,7 @@ def strip_local_scope(s: str) -> str:
def get_guard_fail_reason_helper(
guard_manager: GuardFn,
f_locals: dict[str, object],
compile_id: CompileId,
compile_id: Optional[CompileId],
) -> str:
"""
Return the reason why `guard_manager` failed.

View File

@ -809,6 +809,7 @@ class OutputGraph(OutputGraphGuardsState):
@property
def shape_env(self):
assert self.tracing_context.fake_mode is not None
return self.tracing_context.fake_mode.shape_env
@property
@ -1691,6 +1692,7 @@ class OutputGraph(OutputGraphGuardsState):
)
self.call_cleanup_hooks()
old_fake_mode = self.tracing_context.fake_mode
assert old_fake_mode is not None
if not self.export:
import torch._functorch.config as _config
@ -1738,6 +1740,7 @@ class OutputGraph(OutputGraphGuardsState):
)
counters["stats"]["unique_graphs"] += 1
assert old_fake_mode.shape_env is not None
if specializations := old_fake_mode.shape_env.specializations:
specialization_guards = []
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""
This module provides Source classes that track the origins of values in PyTorch Dynamo.
Sources represent where values come from (e.g. local variables, globals, attributes) and
@ -22,9 +20,9 @@ the code needed to recreate values.
import dataclasses
import enum
import functools
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from torch._guards import ChainedSource, GuardSource, Source
from torch._guards import ChainedSource, Guard, GuardSource, Source
from . import utils
from .bytecode_transformation import create_call_function, create_instruction
@ -96,7 +94,7 @@ _GUARD_SOURCE_FSDP_MODULE = {
}
def is_constant_source(source):
def is_constant_source(source: Source) -> bool:
if isinstance(source, ConstantSource):
return True
try:
@ -124,16 +122,16 @@ class LocalSource(Source):
# or `co_freevars`.
is_derefed_cell_contents: bool = False
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
if self.is_derefed_cell_contents:
codegen.load_deref(self.local_name)
else:
codegen.append_output(codegen.create_load(self.local_name))
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.LOCAL
def name(self):
def name(self) -> str:
return f"L[{repr(self.local_name)}]"
@ -141,13 +139,13 @@ class LocalSource(Source):
class SyntheticLocalSource(Source):
local_name: str
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load(self.local_name))
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.SYNTHETIC_LOCAL
def name(self):
def name(self) -> str:
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
@ -155,15 +153,15 @@ class SyntheticLocalSource(Source):
class RandomValueSource(Source):
random_call_index: int
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.RANDOM_VALUE
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
codegen.append_output(codegen.create_load_const(self.random_call_index))
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def name(self):
def name(self) -> str:
return f"random_value_{self.random_call_index}"
@ -171,13 +169,13 @@ class RandomValueSource(Source):
class GlobalSource(Source):
global_name: str
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
def name(self):
def name(self) -> str:
return f"G[{repr(self.global_name)}]"
@ -185,7 +183,7 @@ class GlobalSource(Source):
class GlobalWeakRefSource(Source):
global_name: str
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_global(self.global_name, add=True)
@ -193,23 +191,23 @@ class GlobalWeakRefSource(Source):
)
codegen.extend_output(create_call_function(0, False))
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
def name(self):
def name(self) -> str:
return f"G[{repr(self.global_name)}]()"
@dataclasses.dataclass(frozen=True)
class WeakRefCallSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(lambda: codegen(self.base))
codegen.extend_output(create_call_function(0, False))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"{self.base.name()}()"
@ -222,7 +220,7 @@ class CallFunctionNoArgsSource(WeakRefCallSource):
class AttrSource(ChainedSource):
member: str
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base, "Can't construct an AttrSource without a valid base source"
if "." in self.member:
member_parts = self.member.split(".")
@ -231,14 +229,14 @@ class AttrSource(ChainedSource):
)
object.__setattr__(self, "member", member_parts[-1])
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
if not self.member.isidentifier():
return f"getattr({self.base.name()}, {self.member!r})"
return f"{self.base.name()}.{self.member}"
@ -248,7 +246,7 @@ class AttrSource(ChainedSource):
class GenericAttrSource(ChainedSource):
member: str
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base, "Can't construct an AttrSource without a valid base source"
if "." in self.member:
member_parts = self.member.split(".")
@ -257,14 +255,14 @@ class GenericAttrSource(ChainedSource):
)
object.__setattr__(self, "member", member_parts[-1])
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
@ -277,7 +275,7 @@ class LocalCellSource(Source):
local_name: str
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
# Dynamo's bytecode transformation differentiates them slightly, so we
# always emit `LOAD_CLOSURE` here.
@ -295,20 +293,20 @@ class LocalCellSource(Source):
class GradSource(ChainedSource):
member: str = "grad"
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"{self.base.name()}.{self.member}"
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
def guard_source(self):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
@ -331,16 +329,16 @@ class UnspecializedParamBufferSource(AttrSource):
class EphemeralSource(Source):
desc: Optional[str] = None
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.EPHEMERAL
def name(self):
def name(self) -> str:
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
def make_guard(self, fn):
def make_guard(self, fn: Callable[..., Any]) -> Guard:
raise NotImplementedError
def is_ephemeral(self):
def is_ephemeral(self) -> bool:
return True
@ -349,13 +347,15 @@ class TensorProperty(enum.Enum):
STRIDE = 1
STORAGE_OFFSET = 2
def method_name(self):
def method_name(self) -> str:
if self is TensorProperty.SIZE:
return "size"
elif self is TensorProperty.STRIDE:
return "stride"
elif self is TensorProperty.STORAGE_OFFSET:
return "storage_offset"
else:
raise AssertionError(f"unhandled {self}")
@dataclasses.dataclass(frozen=True)
@ -363,14 +363,14 @@ class TensorPropertySource(ChainedSource):
prop: TensorProperty
idx: Optional[int] = None # None for STORAGE_OFFSET
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
if self.prop is TensorProperty.STORAGE_OFFSET:
assert self.idx is None
else:
assert self.idx is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(
utils.__name__, f"call_{self.prop.method_name()}"
@ -384,10 +384,10 @@ class TensorPropertySource(ChainedSource):
create_call_function(2 if self.idx is not None else 1, False)
)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
if self.prop is TensorProperty.SIZE:
return f"{self.base.name()}.size()[{self.idx}]"
elif self.prop is TensorProperty.STRIDE:
@ -403,88 +403,88 @@ class TensorPropertySource(ChainedSource):
class IndexedSource(ChainedSource):
idx: int
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
raise NotImplementedError
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"({self.idx}, {self.base.name()})"
@dataclasses.dataclass(frozen=True)
class NegateSource(ChainedSource):
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
raise NotImplementedError
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
# NB: use method call so that function stripping regexes work
return f"{self.base.name()}.__neg__()"
@dataclasses.dataclass(frozen=True)
class ConvertIntSource(ChainedSource):
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"cast_symbool_to_symint_guardless({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource):
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"{self.base.name()}.__obj_flatten__()"
@dataclasses.dataclass(frozen=True)
class ScriptObjectQualifiedNameSource(ChainedSource):
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"{self.base.name()}._type().qualified_name()"
class AttrProxySource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"{self.base.name()}.get_base()"
@ -495,7 +495,7 @@ class DefaultsSource(ChainedSource):
field: str = dataclasses.field(init=False, repr=False, compare=False)
_name: str = dataclasses.field(init=False, repr=False, compare=False)
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base, (
"Base must be a valid source in order to properly track and guard this Defaults to its origin."
)
@ -512,16 +512,16 @@ class DefaultsSource(ChainedSource):
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.field))
codegen.append_output(codegen.create_load_const(self.idx_key))
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return self._name
@ -530,14 +530,14 @@ class GetItemSource(ChainedSource):
index: Any
index_is_slice: bool = False
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
if isinstance(self.index, slice):
# store the hashable version of the slice so the whole GetItemSource is hashable
super().__setattr__("index", self.index.__reduce__())
super().__setattr__("index_is_slice", True)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
if self.index_is_slice:
codegen.append_output(codegen.create_load_const(self.unpack_slice()))
@ -545,15 +545,15 @@ class GetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def unpack_slice(self):
def unpack_slice(self) -> slice:
assert self.index_is_slice
slice_class, slice_args = self.index
return slice_class(*slice_args)
def name(self):
def name(self) -> str:
# Index can be of following types
# 1) index is a slice - example 1:4
# 2) index is a constant - example string, integer
@ -568,10 +568,10 @@ class GetItemSource(ChainedSource):
class ConstDictKeySource(ChainedSource):
index: Any
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
)
@ -579,11 +579,11 @@ class ConstDictKeySource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
def name(self):
def name(self) -> str:
# The list creation will be CSE'd by PyExprCSEPass
return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
def is_dict_key(self):
def is_dict_key(self) -> bool:
return True
@ -591,15 +591,15 @@ class ConstDictKeySource(ChainedSource):
class NonSerializableSetGetItemSource(ChainedSource):
index: int
def __post_init__(self):
def __post_init__(self) -> None:
from .variables import ConstantVariable
assert ConstantVariable.is_literal(self.index)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "set_getitem")
)
@ -607,11 +607,11 @@ class NonSerializableSetGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
def name(self):
def name(self) -> str:
# set ordering might not be stable
return f"list({self.base.name()})[{self.index!r}]"
def is_dict_key(self):
def is_dict_key(self) -> bool:
return False
@ -623,17 +623,17 @@ class DictGetItemSource(ChainedSource):
# 2) constant - like string, integer
index: Any
def __post_init__(self):
def __post_init__(self) -> None:
from .variables import ConstantVariable
assert isinstance(
self.index, ConstDictKeySource
) or ConstantVariable.is_literal(self.index)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# Load dict
codegen(self.base)
@ -644,7 +644,7 @@ class DictGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def name(self):
def name(self) -> str:
if isinstance(self.index, ConstDictKeySource):
return f"{self.base.name()}[{self.index.name()}]"
else:
@ -660,17 +660,17 @@ class DictSubclassGetItemSource(ChainedSource):
# 2) constant - like string, integer
index: Any
def __post_init__(self):
def __post_init__(self) -> None:
from .variables import ConstantVariable
assert isinstance(
self.index, ConstDictKeySource
) or ConstantVariable.is_literal(self.index)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# reconstruct dict.__getitem__(dct, key)
# Load dict.__getitem__
@ -689,7 +689,7 @@ class DictSubclassGetItemSource(ChainedSource):
codegen.extend_output(create_call_function(2, False))
def name(self):
def name(self) -> str:
if isinstance(self.index, ConstDictKeySource):
return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
else:
@ -702,7 +702,7 @@ class ListGetItemSource(GetItemSource):
Same as GetItemSource with reconstruct and name overridden to be list specific.
"""
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# Reconstruct list.__getitem__(lst, index) to avoid any side effects
# from possibly overridden __getitem__.
@ -724,7 +724,7 @@ class ListGetItemSource(GetItemSource):
codegen.extend_output(create_call_function(2, False))
def name(self):
def name(self) -> str:
# Index can be of following types
# 1) index is a slice - example 1:4
# 2) index is a constant - example string, integer
@ -739,7 +739,7 @@ class ListGetItemSource(GetItemSource):
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
)
@ -747,91 +747,91 @@ class TupleIteratorGetItemSource(GetItemSource):
codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False))
def name(self):
def name(self) -> str:
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass(frozen=True)
class DataclassFieldsSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
)
codegen(self.base)
codegen.extend_output(create_call_function(1, False))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"___dataclass_fields({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class TypeSource(ChainedSource):
def __post_init__(self):
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
codegen(self.base)
codegen.extend_output(create_call_function(1, False))
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return f"type({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class OptimizerSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self):
def name(self) -> str:
return self.base.name()
@dataclasses.dataclass(frozen=True)
class NNModuleSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base)
def guard_source(self):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
def name(self):
def name(self) -> str:
return self.base.name()
@dataclasses.dataclass(frozen=True)
class UnspecializedNNModuleSource(NNModuleSource):
def guard_source(self):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True)
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
def guard_source(self):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
def guard_source(self):
def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
def name(self):
def name(self) -> str:
return ""
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@ -840,16 +840,16 @@ class TorchSource(Source):
"""Points to the actual `torch` module - used instead of GlobalSource
in case the user has overridden `torch` in their local namespace"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
from .guards import GuardBuilder, install_guard
install_guard(self.make_guard(GuardBuilder.ID_MATCH))
def name(self):
def name(self) -> str:
return "__import__('torch')"
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.extend_output(
[
codegen.create_load_const(0), # level
@ -858,7 +858,7 @@ class TorchSource(Source):
]
)
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@ -866,15 +866,15 @@ class TorchSource(Source):
class TorchFunctionModeStackSource(Source):
ind: int
def name(self):
def name(self) -> str:
return f"___get_torch_function_mode_stack_at({self._get_index()})"
def _get_index(self):
def _get_index(self) -> int:
from .variables.torch_function import TorchFunctionModeStackVariable
return TorchFunctionModeStackVariable.get_mode_index(self.ind)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.load_import_from(
utils.__name__, "get_torch_function_mode_stack_at"
@ -883,7 +883,7 @@ class TorchFunctionModeStackSource(Source):
codegen.extend_output([codegen.create_load_const(self._get_index())])
codegen.extend_output(create_call_function(1, False))
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL
@ -891,16 +891,16 @@ class TorchFunctionModeStackSource(Source):
class ConstantSource(Source):
source_name: str
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.CONSTANT
def name(self):
def name(self) -> str:
return self.source_name
def make_guard(self, fn):
def make_guard(self, fn: Any) -> Any:
raise NotImplementedError
@ -909,10 +909,10 @@ class NumpyTensorSource(ChainedSource):
def name(self) -> str:
return f"___from_numpy({self.base.name()})"
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
codegen(self.base)
codegen.extend_output(create_call_function(1, False))
@ -923,7 +923,7 @@ class SubclassAttrListSource(ChainedSource):
def name(self) -> str:
return f"{self.base.name()}.__tensor_flatten__()[0]"
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -934,7 +934,7 @@ class FloatTensorSource(ChainedSource):
def name(self) -> str:
return f"___as_tensor({self.base.name()})"
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -943,7 +943,7 @@ class CallMethodItemSource(ChainedSource):
def name(self) -> str:
return f"{self.base.name()}.item()"
def guard_source(self):
def guard_source(self) -> GuardSource:
return self.base.guard_source()
@ -952,23 +952,25 @@ class CallMethodItemSource(ChainedSource):
# guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source):
def name(self):
def name(self) -> str:
return ""
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
def name(self):
def name(self) -> str:
return ""
def guard_source(self):
def guard_source(self) -> GuardSource:
return GuardSource.BACKWARD_STATE
def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]:
def get_local_source_name(
source: Source, *, only_allow_input: bool = False
) -> Optional[str]:
if isinstance(source, ChainedSource):
return get_local_source_name(source.base, only_allow_input=only_allow_input)
if not isinstance(source, LocalSource):
@ -978,7 +980,7 @@ def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional
return source.local_name
def is_from_local_source(source: Source, *, only_allow_input=False):
def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool:
return get_local_source_name(source, only_allow_input=only_allow_input) is not None
@ -994,7 +996,7 @@ def get_global_source_name(source: Source) -> Optional[str]:
return source.global_name
def is_from_nonlocal_source(source: Source):
def is_from_nonlocal_source(source: Source) -> bool:
if isinstance(source, ChainedSource):
return is_from_nonlocal_source(source.base)
return (
@ -1004,14 +1006,14 @@ def is_from_nonlocal_source(source: Source):
)
def is_from_source(source: Source, target: Source):
def is_from_source(source: Source, target: Source) -> bool:
if isinstance(source, ChainedSource):
return is_from_source(source.base, target)
return source == target
@functools.lru_cache
def is_from_unspecialized_nn_module_source(source: Source):
def is_from_unspecialized_nn_module_source(source: Source) -> bool:
if isinstance(source, UnspecializedNNModuleSource):
return True
if isinstance(source, ChainedSource):
@ -1020,7 +1022,7 @@ def is_from_unspecialized_nn_module_source(source: Source):
@functools.lru_cache
def is_from_unspecialized_builtin_nn_module_source(source: Source):
def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool:
if isinstance(source, UnspecializedBuiltinNNModuleSource):
return True
if isinstance(source, ChainedSource):
@ -1029,7 +1031,7 @@ def is_from_unspecialized_builtin_nn_module_source(source: Source):
@functools.lru_cache
def is_from_unspecialized_param_buffer_source(source: Source):
def is_from_unspecialized_param_buffer_source(source: Source) -> bool:
if isinstance(source, UnspecializedParamBufferSource):
return True
if isinstance(source, ChainedSource):
@ -1038,7 +1040,7 @@ def is_from_unspecialized_param_buffer_source(source: Source):
@functools.lru_cache
def is_from_flatten_script_object_source(source: Source):
def is_from_flatten_script_object_source(source: Source) -> bool:
if isinstance(source, FlattenScriptObjectSource):
return True
elif isinstance(source, ChainedSource):
@ -1047,7 +1049,7 @@ def is_from_flatten_script_object_source(source: Source):
@functools.lru_cache
def is_from_optimizer_source(source: Source):
def is_from_optimizer_source(source: Source) -> bool:
if isinstance(source, OptimizerSource):
return True
if isinstance(source, ChainedSource):
@ -1058,7 +1060,7 @@ def is_from_optimizer_source(source: Source):
# TODO: can probably write a generic "test this on everything in the chain"
# helper
@functools.lru_cache
def is_from_defaults(source: Source):
def is_from_defaults(source: Source) -> bool:
if isinstance(source, DefaultsSource):
return True

View File

@ -2644,7 +2644,9 @@ def set_example_value(node, example_value):
# this to accurately reflect what the state of the value was at the time
# the program was traced).
node.meta["example_value"] = example_value
shape_env = TracingContext.get().fake_mode.shape_env
fake_mode = TracingContext.get().fake_mode
assert fake_mode is not None
shape_env = fake_mode.shape_env
if (
symbol_to_path
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
@ -4765,7 +4767,7 @@ def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
# Returns a set of code objects present traced in the current TracingContext, or None
# if there is no current TracingContext.
def get_traced_code() -> list[CodeType]:
def get_traced_code() -> Optional[list[CodeType]]:
from torch._guards import TracingContext
return TracingContext.get_traced_code()

View File

@ -365,6 +365,7 @@ def make_fake_inputs(
# a toplevel TracingContext with a fake mode, so we do not want to
# create another fake mode.
fake_mode = context.fake_mode
assert fake_mode is not None
else:
if isinstance(nn_module.forward, functools.partial):
# functools handles nesting by itself, no need to recurse
@ -852,7 +853,7 @@ def _fakify_script_objects(
mod: torch.nn.Module,
args: Sequence[Any],
kwargs: dict[Any, Any],
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode],
):
# This context manager is used to fakify script objects into FakeScriptObject.
# Inputs:

View File

@ -1129,7 +1129,7 @@ def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
def _detect_fake_mode_from_gm(
gm: torch.fx.GraphModule,
) -> torch._subclasses.fake_tensor.FakeTensorMode:
) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]:
"""
For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs.
Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes.

View File

@ -12,7 +12,7 @@ It does so by:
"""
import warnings
from contextlib import contextmanager, ExitStack, nullcontext
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, Callable, cast, Optional, TypeVar, Union
from unittest.mock import patch
@ -441,9 +441,10 @@ def create_functionalized_rng_ops_wrapper(
# It goes from (primals, tangents) to (seed, offset, primals, tangents)
# At runtime, we pass on the current seed and offset. This is hidden from
# the user.
fake_mode = detect_fake_mode()
if fake_mode is None:
fake_mode = nullcontext()
fake_mode_det = detect_fake_mode()
fake_mode: AbstractContextManager[Any] = nullcontext()
if fake_mode_det is not None:
fake_mode = fake_mode_det
def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
out = PhiloxStateTracker.get_state_as_tensor()
@ -1343,7 +1344,9 @@ def create_functional_call(
"ignore", "Anomaly Detection has been enabled."
)
with torch.autograd.detect_anomaly(check_nan=False):
detect_fake_mode().epoch += 1
fake_mode = detect_fake_mode()
assert fake_mode is not None
fake_mode.epoch += 1
out = PropagateUnbackedSymInts(mod).run(
*args[params_len:], **kwargs
)

View File

@ -306,11 +306,12 @@ def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices):
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None:
assert tracing_context.fake_mode is not None
shape_env = tracing_context.fake_mode.shape_env
# Check whether we can actually get the dynamo sources from within AOTAutograd.
if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None:
maybe_suppress_guards = shape_env.suppress_guards
maybe_suppress_guards = shape_env.suppress_guards # type: ignore[assignment]
# Check whether there are any symbolic values being used.
# We do this for 2 reasons:

View File

@ -495,6 +495,7 @@ class FunctionalizedRngRuntimeWrapper(InductorWrapper):
if config.functionalize_rng_ops:
# Update example inputs for the fw_compiler
fake_mode = detect_fake_mode()
assert fake_mode is not None
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
flat_args.extend([seed, offset])
# We are not clearing flat_args here because

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
@ -37,10 +36,15 @@ log = logging.getLogger(__name__)
if TYPE_CHECKING:
from collections.abc import Generator, Iterator
from types import CodeType
import sympy
from torch._dynamo.codegen import PyCodegen
from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
from torch._subclasses.fake_tensor import FakeTensorMode
"""
torch._guards is the definitional source of truth for general purpose guard structures.
@ -83,7 +87,7 @@ class CompileId:
# TODO: consider also tracking the recompilation count
# See Note: Updating CompileId
def __str__(self):
def __str__(self) -> str:
# NOTE: Keep this in sync with both from_string and the tlparse repo
if self.compiled_autograd_id is not None:
assert (self.frame_id is None) == (self.frame_compile_id is None)
@ -97,7 +101,7 @@ class CompileId:
return f"{self.frame_id}/{self.frame_compile_id}"
@classmethod
def from_string(cls, compile_id: Optional[str]):
def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]:
"""
Factory method that creates a CompileId from its string representation.
Keep this in sync with the __str__ method.
@ -125,7 +129,7 @@ class TraceId(NamedTuple):
# up by one
attempt: int
def __str__(self):
def __str__(self) -> str:
# Keep this in sync with tlparse repo
if self.attempt == 0:
return str(self.compile_id)
@ -185,7 +189,7 @@ class GuardSource(enum.Enum):
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
)
def is_local(self):
def is_local(self) -> bool:
return self in (
GuardSource.LOCAL,
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
@ -218,7 +222,7 @@ class SLoc:
framework_loc: Optional[Union[traceback.FrameSummary, str]]
maybe_user_loc: Optional[str]
def __str__(self):
def __str__(self) -> str:
floc = (
self.framework_loc
if isinstance(self.framework_loc, str)
@ -257,7 +261,7 @@ class Guard:
# it is meaningless. Example create_fns that are like this include
# GRAD_MODE and SHAPE_ENV.
originating_source: Source
create_fn: Callable[[GuardBuilderBase, Guard], None]
create_fn: Callable[[GuardBuilderBase, Guard], Any]
# Export only. These values are written to at time of guard check_fn creation.
guard_types: Optional[list[str]] = None
@ -269,12 +273,12 @@ class Guard:
user_stack: Optional[traceback.StackSummary] = None
_hash: Optional[int] = None
def __hash__(self):
def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self.name, self.source, id(self.create_fn)))
return self._hash
def sort_key(self):
def sort_key(self) -> tuple[bool, int, int, str, int]:
# Put the duplicate input guards at the end. The duplicate guards have
# two sources while guard.name only considers one source.
@ -290,10 +294,10 @@ class Guard:
self.inner_create_fn().__code__.co_firstlineno,
)
def __lt__(self, other):
def __lt__(self, other: Guard) -> bool:
return self.sort_key() < other.sort_key()
def inner_create_fn(self):
def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
if isinstance(self.create_fn, functools.partial):
return self.create_fn.func
else:
@ -308,7 +312,7 @@ class Guard:
return self.originating_source.guard_source()
@staticmethod
def weakref_to_str(obj_weakref):
def weakref_to_str(obj_weakref: object) -> str:
"""
This is a workaround of a Python weakref bug.
@ -332,7 +336,7 @@ class Guard:
else:
return str(obj_weakref)
def __repr__(self):
def __repr__(self) -> str:
s = f"""
{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
{{
@ -344,7 +348,7 @@ class Guard:
"""
return s
def __str__(self):
def __str__(self) -> str:
output = f"Name: {repr(self.name)}\n"
source = self.source.name.lower() if self.source else ""
output += f" Source: {source}\n"
@ -355,7 +359,7 @@ class Guard:
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
return output
def create(self, builder: GuardBuilderBase):
def create(self, builder: GuardBuilderBase) -> Any:
try:
return self.create_fn(builder, self)
except Exception:
@ -364,16 +368,22 @@ class Guard:
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
raise
def is_specialized_nn_module(self):
def is_specialized_nn_module(self) -> bool:
return self.source.is_specialized_nn_module()
def is_fsdp_module(self):
def is_fsdp_module(self) -> bool:
return self.source.is_fsdp_module()
def is_local(self):
def is_local(self) -> bool:
return self.source.is_local()
def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
def set_export_info(
self,
guard_type: str,
guarded_class: Optional[type],
code_list: list[str],
obj_weakref: object,
) -> None:
if not self.guard_types:
self.guard_types = []
@ -428,7 +438,7 @@ class DuplicateInputs(GuardEnvExpr):
input_source_a: Source
input_source_b: Source
def __post_init__(self):
def __post_init__(self) -> None:
assert self.input_source_a != self.input_source_b
@ -470,7 +480,7 @@ class Checkpointable(Generic[T]):
def copy_graphstate(self) -> T: ...
@abstractmethod
def restore_graphstate(self, state: T): ...
def restore_graphstate(self, state: T) -> None: ...
class GuardsCheckpointState:
@ -480,10 +490,10 @@ class GuardsCheckpointState:
dynamo_guards: set[Guard] = set()
def __init__(self, dynamo_guards):
def __init__(self, dynamo_guards: set[Guard]) -> None:
self.dynamo_guards = dynamo_guards
def diff(self, other):
def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]:
"""
Produces a delta against another GuardsCheckpointState.
@ -495,17 +505,19 @@ class GuardsCheckpointState:
return None
return r
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, GuardsCheckpointState):
return False
return self.diff(other) is None
class ModuleContextCheckpointState:
nn_modules: dict[str, torch.nn.Module] = {}
def __init__(self, nn_modules):
def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
self.nn_modules = nn_modules
def diff(self, other):
def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]:
"""
Produces a delta against another ModuleContextCheckpointState.
@ -517,7 +529,9 @@ class ModuleContextCheckpointState:
return None
return r
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, ModuleContextCheckpointState):
return False
return self.diff(other) is None
@ -525,21 +539,21 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
def __init__(self) -> None:
self.nn_modules: dict[str, Any] = {}
def copy_graphstate(self):
def copy_graphstate(self) -> ModuleContextCheckpointState:
return ModuleContextCheckpointState(dict(self.nn_modules))
def restore_graphstate(self, state):
def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
assert isinstance(state, ModuleContextCheckpointState)
self.nn_modules = state.nn_modules
class GlobalContextCheckpointState:
global_state: dict[str, tuple[Callable, ...]] = {}
global_state: dict[str, tuple[Callable, Any]] = {}
def __init__(self, global_states):
def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
self.global_state = global_states
def diff(self, other):
def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]:
"""
Produces a delta against another GlobalContextCheckpointState.
@ -551,7 +565,9 @@ class GlobalContextCheckpointState:
return None
return r
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, GlobalContextCheckpointState):
return False
return self.diff(other) is None
@ -571,12 +587,12 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
}
def __init__(self) -> None:
self.global_state: dict[str, tuple[Callable, ...]] = {}
self.global_state: dict[str, tuple[Callable, Any]] = {}
def copy_graphstate(self):
return GlobalContextCheckpointState(dict(self.global_state))
def copy_graphstate(self) -> GlobalContextCheckpointState:
return GlobalContextCheckpointState(self.global_state)
def restore_graphstate(self, state):
def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
assert isinstance(state, GlobalContextCheckpointState)
self.global_state = state.global_state
assert (
@ -590,26 +606,28 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
# Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination
class GuardsSet:
def __init__(self, inner=None):
def __init__(self, inner: Optional[set[Guard]] = None) -> None:
if inner is None:
inner = set()
self.inner = inner
def __iter__(self):
def __iter__(self) -> Iterator[Guard]:
return iter(self.inner)
def __len__(self):
def __len__(self) -> int:
return len(self.inner)
# Subtraction along with bool is typically used to determine the delta of
# added guards between checkpoints for higher order ops
def __sub__(self, other):
def __sub__(self, other: GuardsSet) -> GuardsSet:
return GuardsSet(self.inner - other.inner)
def __bool__(self):
def __bool__(self) -> bool:
return bool(self.inner)
def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
def add(
self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
) -> None:
if guard in self.inner:
return
if collect_debug_stack:
@ -619,12 +637,12 @@ class GuardsSet:
guard.user_stack = TracingContext.extract_stack()
self.inner.add(guard)
def update(self, *others: set[Guard]):
def update(self, *others: set[Guard]) -> None:
for o in others:
for g in o:
self.add(g, skip=1)
def remove_guards_with_source(self, source):
def remove_guards_with_source(self, source: Source) -> None:
"""Delete all guards that contains a given source"""
from ._dynamo.source import is_from_source
@ -646,10 +664,10 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: list[GuardEnvExpr] = []
def copy_graphstate(self):
def copy_graphstate(self) -> GuardsCheckpointState:
return GuardsCheckpointState(set(self.dynamo_guards.inner))
def restore_graphstate(self, state):
def restore_graphstate(self, state: GuardsCheckpointState) -> None:
# NB: "steals" the passed in state
assert isinstance(state, GuardsCheckpointState)
self.dynamo_guards = GuardsSet(state.dynamo_guards)
@ -657,22 +675,22 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
class HopSubgraphCache:
@abstractmethod
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): ...
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...
@abstractmethod
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
@abstractmethod
def add_autograd_key_entry(self, identifier: str, key: Callable): ...
def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...
@abstractmethod
def get_autograd_key_entry(self, identifier: str): ...
def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ...
@abstractmethod
def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ...
def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...
@abstractmethod
def get_proxy_dispatch_entry(self, identifier: str): ...
def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ...
@abstractmethod
def add_lazy_bwd_entry(
@ -680,12 +698,12 @@ class HopSubgraphCache:
identifier: str,
tangent_metadata: tuple[object],
gmod: torch.fx.GraphModule,
): ...
) -> int: ...
@abstractmethod
def get_lazy_bwd_entry(
self, identifier: str, tangent_metadata: tuple[object]
) -> int: ...
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ...
class InvokeSubgraphCache(HopSubgraphCache):
@ -697,22 +715,22 @@ class InvokeSubgraphCache(HopSubgraphCache):
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
] = defaultdict(dict)
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str):
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
self.dynamo_installed_submodules[fn_id].append(identifier)
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
return self.dynamo_installed_submodules.get(fn_id, [])
def add_autograd_key_entry(self, identifier: str, key: Callable):
def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
self.autograd_cache[identifier] = key
def get_autograd_key_entry(self, identifier: str):
def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]:
return self.autograd_cache.get(identifier, None)
def add_proxy_dispatch_entry(self, identifier: str, key: Callable):
def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
self.proxy_dispatch_cache[identifier] = key
def get_proxy_dispatch_entry(self, identifier: str):
def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]:
return self.proxy_dispatch_cache.get(identifier, None)
def add_lazy_bwd_entry(
@ -720,13 +738,15 @@ class InvokeSubgraphCache(HopSubgraphCache):
identifier: str,
tangent_metadata: tuple[object],
gmod: torch.fx.GraphModule,
):
) -> int:
# Save the number of existing graph modules in the dictionary to get the suffix
num_gmods = len(self.lazy_bwd_cache[identifier])
self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
return num_gmods
def get_lazy_bwd_entry(self, identifier: str, tangent_metadata: tuple[object]):
def get_lazy_bwd_entry(
self, identifier: str, tangent_metadata: tuple[object]
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]:
if identifier not in self.lazy_bwd_cache:
return (None, None)
@ -779,7 +799,7 @@ class CompileContext:
def try_get() -> Optional[CompileContext]:
return getattr(_TLS, "compile_context", None)
def __init__(self, compile_id):
def __init__(self, compile_id: Optional[CompileId]) -> None:
assert compile_id is None or isinstance(compile_id, CompileId)
self.compile_id: Optional[CompileId] = compile_id
self.attempt = 0
@ -787,14 +807,14 @@ class CompileContext:
self.shape_env_guards: list[str] = []
@staticmethod
def current_compile_id():
def current_compile_id() -> Optional[CompileId]:
self = CompileContext.try_get()
if self is None:
return None
return self.compile_id
@staticmethod
def current_trace_id():
def current_trace_id() -> Optional[TraceId]:
self = CompileContext.try_get()
if self is None:
return None
@ -823,28 +843,28 @@ class TracingContext:
"TracingContext.get() must be called within an ongoing trace."
)
def __init__(self, fake_mode):
def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None:
self.guards_context = GuardsContext()
self.module_context = ModuleContext()
self.global_context = GlobalContext()
self.previously_inlined_functions = dict()
self.previously_cleaned_instructions = dict()
self.fake_mode = fake_mode
self.frame_summary_stack = []
self.previously_inlined_functions: dict[Any, Any] = dict()
self.previously_cleaned_instructions: dict[Any, Any] = dict()
self.fake_mode: Optional[FakeTensorMode] = fake_mode
self.frame_summary_stack: list[traceback.FrameSummary] = []
# This is morally part of frame_summary_stack, but it is kept separate
# for clarity. As we process a frame, this variable gets updated
# to keep track of what line we are in the function. We make a
# function call, this gets cleared and the frame location is pushed
# to frame_summary_stack (prepping this variable for the inner frame's
# progress)
self.loc_in_frame = None
self.loc_in_frame: Optional[tuple[str, int, str]] = None
# this is only set after aot_autograd
self.fw_metadata = None
self.fw_metadata: Optional[ViewAndMutationMeta] = None
# this is only set after aot_autograd
self.aot_graph_name = None
self.params_flat = None
self.params_flat_unwrap_subclasses = None
self.params_unwrapped_to_flat_index = None
self.aot_graph_name: Optional[list[str]] = None
self.params_flat: Optional[list[Any]] = None
self.params_flat_unwrap_subclasses: Optional[list[Any]] = None
self.params_unwrapped_to_flat_index: Optional[list[Any]] = None
# this is for extended return calling convention from backend
# compiler to aot_autograd
# Per output, what the compiler specified stride of the output is,
@ -872,7 +892,7 @@ class TracingContext:
# list of code objects for inlined functions
self.traced_code: list[CodeType] = []
def clear(self):
def clear(self) -> None:
# Look at the note in output_graph.py in function `save_global_state`
# for the context on clearing global context.
self.global_context.global_state = {}
@ -881,7 +901,7 @@ class TracingContext:
@staticmethod
@contextmanager
def patch(**kwargs):
def patch(**kwargs: Any) -> Generator[None, None, None]:
prior = {}
ctx = TracingContext.get()
@ -897,7 +917,7 @@ class TracingContext:
setattr(ctx, key, val)
@staticmethod
def extract_stack():
def extract_stack() -> traceback.StackSummary:
self = TracingContext.try_get()
if self is None:
return traceback.StackSummary()
@ -906,7 +926,7 @@ class TracingContext:
stack = stack + [self._populate_loc_in_frame_summary()]
return traceback.StackSummary.from_list(stack)
def _populate_loc_in_frame_summary(self):
def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
assert self.loc_in_frame is not None
filename, lineno, frame_name = self.loc_in_frame
return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
@ -915,7 +935,7 @@ class TracingContext:
# associated with the current frame state
@staticmethod
@contextlib.contextmanager
def clear_frame():
def clear_frame() -> Generator[None, None, None]:
tc = TracingContext.get()
with (
unittest.mock.patch.object(tc, "frame_summary_stack", []),
@ -947,7 +967,9 @@ class TracingContext:
@staticmethod
@contextlib.contextmanager
def current_frame(frame_summary):
def current_frame(
frame_summary: Optional[traceback.FrameSummary],
) -> Generator[None, None, None]:
# frame_summary can be None to solely take advantage of real_stack
# attachment to thrown exceptions
tc = TracingContext.get()
@ -968,7 +990,9 @@ class TracingContext:
@staticmethod
@contextlib.contextmanager
def report_output_strides():
def report_output_strides() -> Generator[
Optional[list[Optional[tuple[int, ...]]]], None, None
]:
tc = TracingContext.try_get()
if tc is None:
yield None
@ -981,13 +1005,13 @@ class TracingContext:
tc.output_strides = old_output_strides
@staticmethod
def set_current_loc(filename, lineno, frame_name):
def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
# Save the current location in the frame. Lazily generate the
# framesummary.
TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
@staticmethod
def get_traced_code():
def get_traced_code() -> Optional[list[CodeType]]:
tc = TracingContext.try_get()
if tc is None:
return None
@ -995,7 +1019,9 @@ class TracingContext:
@contextmanager
def compile_context(context: Optional[CompileContext]):
def compile_context(
context: Optional[CompileContext],
) -> Generator[Optional[CompileContext], None, None]:
old_context = getattr(_TLS, "compile_context", None)
_TLS.compile_context = context
try:
@ -1005,7 +1031,9 @@ def compile_context(context: Optional[CompileContext]):
@contextmanager
def tracing(context: Optional[TracingContext]):
def tracing(
context: Optional[TracingContext],
) -> Generator[Optional[TracingContext], None, None]:
"""
This function installs the passed in tracing context as a dynamic scoped
global variable.
@ -1035,13 +1063,13 @@ def tracing(context: Optional[TracingContext]):
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
def is_dict_key(self):
def is_dict_key(self) -> bool:
return False
def is_ephemeral(self):
def is_ephemeral(self) -> bool:
return False
def reconstruct(self, codegen):
def reconstruct(self, codegen: PyCodegen) -> None:
raise NotImplementedError
def guard_source(self) -> GuardSource:
@ -1050,7 +1078,7 @@ class Source:
def name(self) -> str:
raise NotImplementedError
def make_guard(self, fn) -> Guard:
def make_guard(self, fn: Callable[..., Any]) -> Guard:
if self.guard_source() is GuardSource.CONSTANT:
raise NotImplementedError
return Guard(self, fn)
@ -1058,7 +1086,7 @@ class Source:
def is_specialized_nn_module(self) -> bool:
return self.guard_source().is_specialized_nn_module()
def subguards_allowed(self):
def subguards_allowed(self) -> bool:
"""True if you can guard on attributes of this"""
return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
@ -1068,11 +1096,11 @@ class Source:
class ChainedSource(Source):
base: Source
def is_dict_key(self):
def is_dict_key(self) -> bool:
# Recurse until you either hit a ConstDictKey or a Source
return self.base.is_dict_key()
def is_ephemeral(self):
def is_ephemeral(self) -> bool:
return self.base.is_ephemeral()
def get_base(self) -> Source:
@ -1082,7 +1110,7 @@ class ChainedSource(Source):
return current
def detect_fake_mode(inputs: Any = None):
def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
"""
Attempts to "detect" what the current fake mode is. If there is one ambiently
available from TracingContext, we preferentially use that. Otherwise, we
@ -1126,7 +1154,7 @@ def detect_fake_mode(inputs: Any = None):
return None
def active_fake_mode():
def active_fake_mode() -> Optional[FakeTensorMode]:
"""
Inspects the dispatch mode stack for an active fake mode and returns it.
Returns None if no fake mode is active.

View File

@ -465,6 +465,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
from torch._subclasses.fake_tensor import extract_tensor_metadata
fake_mode = detect_fake_mode(primals + filtered_grad_outs)
assert fake_mode is not None, "fake_mode should be enabled for HOPs"
state = _CacheKeyState(fake_mode.shape_env)
tangent_metadata: list[object] = []
@ -607,6 +608,7 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(operands)
assert fake_mode is not None and fake_mode.shape_env is not None
insert_deferred_runtime_asserts(
graph,
fake_mode.shape_env,

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from contextlib import contextmanager, ExitStack, nullcontext
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Optional, overload, TypeVar, Union
@ -266,11 +266,12 @@ def _set_compilation_env():
# The invariant here is that we always trace the branch with fake tensor
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
fake_mode = detect_fake_mode(inputs)
tracing_mode = "real"
if fake_mode is None:
fake_mode = nullcontext()
tracing_mode = "fake"
fake_mode_det = detect_fake_mode(inputs)
fake_mode: AbstractContextManager = nullcontext()
tracing_mode = "fake"
if fake_mode_det is not None:
fake_mode = fake_mode_det
tracing_mode = "real"
# Note: we need to turn off proxy tensor mode to avoid tracing infra
# code that happens in make_fx e.g. we now call as_strided when wrapping tensor
@ -282,9 +283,12 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
pre_dispatch=pre_dispatch,
_error_on_data_dependent_ops=False,
)(*inputs)
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None:
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: # type: ignore[attr-defined]
insert_deferred_runtime_asserts(
gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True
gm,
fake_mode.shape_env, # type: ignore[attr-defined]
"hoo_maybe_fake_tracing",
export=True, # type: ignore[attr-defined]
)
return gm

View File

@ -1065,7 +1065,7 @@ class GuardedCache(Generic[T]):
Helper to get the shape env from the tracing context.
"""
ctx = torch._guards.TracingContext.try_get()
if not ctx:
if not ctx or not ctx.fake_mode:
return None
return ctx.fake_mode.shape_env

View File

@ -1942,7 +1942,7 @@ def fw_compiler_freezing(
idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
]
static_input_idxs = []
static_input_idxs: list[Any] = []
# constant params will be real tensors, not fake
tracing_context = torch._guards.TracingContext.try_get()
unwrapped_args_offsets = [0]
@ -2461,6 +2461,7 @@ def compile_fx(
if node.op == "get_attr" and "val" not in node.meta:
target = attrgetter(node.target)(gm)
if isinstance(target, torch.Tensor):
assert fake_mode is not None
node.meta["val"] = fake_mode.from_tensor(
target, static_shapes=True
)

View File

@ -1429,7 +1429,9 @@ def register_replacement(
)
sym_args: list[torch.SymInt] = []
with torch._dynamo.utils.detect_fake_mode(args):
fake_mode = torch._dynamo.utils.detect_fake_mode(args)
assert fake_mode is not None
with fake_mode:
for i, grad in enumerate(requires_grad):
if isinstance(args[i], torch.Tensor):
if grad and is_integer_dtype(args[i].dtype):

View File

@ -203,6 +203,7 @@ def standalone_compile(
# Reuse fake_mode from the TracingContext.
# NB: The TracingContext only exists if we're currently in a torch.compile backend.
context = torch._guards.TracingContext.get()
assert context.fake_mode is not None
fake_mode = context.fake_mode
elif dynamic_shapes == "from_graph":
fake_mode = FakeTensorMode(shape_env=ShapeEnv())

View File

@ -2817,10 +2817,9 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[N
return contextlib.nullcontext()
# In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
shape_env = tracing_context.fake_mode.shape_env
if not shape_env:
if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
return contextlib.nullcontext()
shape_env = tracing_context.fake_mode.shape_env
return shape_env.suppress_guards()
@ -3343,12 +3342,13 @@ def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
for i, e in enumerate(row):
widths[i] = max(widths[i], len(str(e)))
lines = []
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
# Need nested {} for string formatting; ignore SET_LINTER here
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # noqa: set_linter
# widths whitespace horizontal separators
total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
lines.append("-" * total_width)
for row in elements:
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) # noqa: set_linter
return "\n".join(lines)

View File

@ -5,7 +5,7 @@ import operator
import typing
import warnings
from collections.abc import Sequence
from contextlib import nullcontext
from contextlib import AbstractContextManager, nullcontext
from enum import Enum
from functools import reduce
from typing import (
@ -2090,7 +2090,9 @@ def alert_not_deterministic(caller: str):
class CUDARngStateHelper:
@staticmethod
def get_torch_state_as_tuple(fake_mode=nullcontext()):
def get_torch_state_as_tuple(
fake_mode: AbstractContextManager[Any] = nullcontext(),
):
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")

View File

@ -573,8 +573,8 @@ def _decompose_and_get_gm_with_new_signature_constants(
delattr(ep.graph_module, name)
# TODO(zhxhchen17) Return the new graph_signature directly.
fake_mode = detect_fake_mode(fake_args)
fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment]
fake_mode_det = detect_fake_mode(fake_args)
fake_mode_ctx = contextlib.nullcontext() if fake_mode_det is None else fake_mode_det # type: ignore[assignment]
custom_triton_ops_decomposition_ctx = (
contextlib.nullcontext
if decompose_custom_triton_ops
@ -582,7 +582,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
)
with (
_ignore_backend_decomps(),
fake_mode,
fake_mode_ctx,
_override_composite_implicit_decomp(cia_to_decomp),
custom_triton_ops_decomposition_ctx(),
):

View File

@ -7872,7 +7872,9 @@ class PropagateUnbackedSymInts(torch.fx.Interpreter):
from torch._guards import detect_fake_mode
result = super().run_node(n)
rebind_unbacked(detect_fake_mode().shape_env, n, result)
fake_mode = detect_fake_mode()
assert fake_mode is not None
rebind_unbacked(fake_mode.shape_env, n, result)
return result

View File

@ -1028,7 +1028,7 @@ def bound_sympy(
# If there's a tracing context, augment available constrained ranges.
context = torch._guards.TracingContext.try_get()
if context and context.fake_mode.shape_env:
if context and context.fake_mode and context.fake_mode.shape_env:
if ranges:
ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}
else: