mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: X-link: https://github.com/pytorch/executorch/pull/12986 As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py` Running ``` mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 1227 | 2208 | 55.57% | 207 | 362 | 57.18% | | This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% | | Delta | +990 | +9 | +44.43% | +155 | 0 | +42.82% | cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Test Plan: Imported from GitHub, without a `Test Plan:` line. Rollback Plan: Reviewed By: JacobSzwejbka, yangw-dev Differential Revision: D79199389 Pulled By: Lucaskabela Pull Request resolved: https://github.com/pytorch/pytorch/pull/159491 Approved by: https://github.com/anijain2305, https://github.com/yangw-dev
This commit is contained in:
committed by
PyTorch MergeBot
parent
1293405c8d
commit
2b1ae29960
@ -1848,7 +1848,7 @@ def export(
|
||||
ignore_fresh_unbacked = null_context()
|
||||
assert ambient_fake_mode is not None
|
||||
if shape_env := ambient_fake_mode.shape_env:
|
||||
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()
|
||||
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() # type: ignore[assignment]
|
||||
|
||||
with (
|
||||
ambient_fake_mode,
|
||||
@ -1900,7 +1900,9 @@ def export(
|
||||
fakify_with_ambient, graph_inputs
|
||||
)
|
||||
graph_captured_result = torch.func.functional_call(
|
||||
graph, fake_params_buffers, fake_graph_inputs
|
||||
graph,
|
||||
fake_params_buffers, # type: ignore[arg-type]
|
||||
fake_graph_inputs, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return graph_captured_result
|
||||
|
@ -3685,7 +3685,7 @@ def strip_local_scope(s: str) -> str:
|
||||
def get_guard_fail_reason_helper(
|
||||
guard_manager: GuardFn,
|
||||
f_locals: dict[str, object],
|
||||
compile_id: CompileId,
|
||||
compile_id: Optional[CompileId],
|
||||
) -> str:
|
||||
"""
|
||||
Return the reason why `guard_manager` failed.
|
||||
|
@ -809,6 +809,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
@property
|
||||
def shape_env(self):
|
||||
assert self.tracing_context.fake_mode is not None
|
||||
return self.tracing_context.fake_mode.shape_env
|
||||
|
||||
@property
|
||||
@ -1691,6 +1692,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
)
|
||||
self.call_cleanup_hooks()
|
||||
old_fake_mode = self.tracing_context.fake_mode
|
||||
assert old_fake_mode is not None
|
||||
if not self.export:
|
||||
import torch._functorch.config as _config
|
||||
|
||||
@ -1738,6 +1740,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
)
|
||||
|
||||
counters["stats"]["unique_graphs"] += 1
|
||||
assert old_fake_mode.shape_env is not None
|
||||
if specializations := old_fake_mode.shape_env.specializations:
|
||||
specialization_guards = []
|
||||
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
|
||||
|
@ -1,5 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
"""
|
||||
This module provides Source classes that track the origins of values in PyTorch Dynamo.
|
||||
Sources represent where values come from (e.g. local variables, globals, attributes) and
|
||||
@ -22,9 +20,9 @@ the code needed to recreate values.
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._guards import ChainedSource, GuardSource, Source
|
||||
from torch._guards import ChainedSource, Guard, GuardSource, Source
|
||||
|
||||
from . import utils
|
||||
from .bytecode_transformation import create_call_function, create_instruction
|
||||
@ -96,7 +94,7 @@ _GUARD_SOURCE_FSDP_MODULE = {
|
||||
}
|
||||
|
||||
|
||||
def is_constant_source(source):
|
||||
def is_constant_source(source: Source) -> bool:
|
||||
if isinstance(source, ConstantSource):
|
||||
return True
|
||||
try:
|
||||
@ -124,16 +122,16 @@ class LocalSource(Source):
|
||||
# or `co_freevars`.
|
||||
is_derefed_cell_contents: bool = False
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
if self.is_derefed_cell_contents:
|
||||
codegen.load_deref(self.local_name)
|
||||
else:
|
||||
codegen.append_output(codegen.create_load(self.local_name))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.LOCAL
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"L[{repr(self.local_name)}]"
|
||||
|
||||
|
||||
@ -141,13 +139,13 @@ class LocalSource(Source):
|
||||
class SyntheticLocalSource(Source):
|
||||
local_name: str
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.append_output(codegen.create_load(self.local_name))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.SYNTHETIC_LOCAL
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
|
||||
|
||||
|
||||
@ -155,15 +153,15 @@ class SyntheticLocalSource(Source):
|
||||
class RandomValueSource(Source):
|
||||
random_call_index: int
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.RANDOM_VALUE
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
|
||||
codegen.append_output(codegen.create_load_const(self.random_call_index))
|
||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"random_value_{self.random_call_index}"
|
||||
|
||||
|
||||
@ -171,13 +169,13 @@ class RandomValueSource(Source):
|
||||
class GlobalSource(Source):
|
||||
global_name: str
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"G[{repr(self.global_name)}]"
|
||||
|
||||
|
||||
@ -185,7 +183,7 @@ class GlobalSource(Source):
|
||||
class GlobalWeakRefSource(Source):
|
||||
global_name: str
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_global(self.global_name, add=True)
|
||||
@ -193,23 +191,23 @@ class GlobalWeakRefSource(Source):
|
||||
)
|
||||
codegen.extend_output(create_call_function(0, False))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"G[{repr(self.global_name)}]()"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WeakRefCallSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(lambda: codegen(self.base))
|
||||
codegen.extend_output(create_call_function(0, False))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}()"
|
||||
|
||||
|
||||
@ -222,7 +220,7 @@ class CallFunctionNoArgsSource(WeakRefCallSource):
|
||||
class AttrSource(ChainedSource):
|
||||
member: str
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base, "Can't construct an AttrSource without a valid base source"
|
||||
if "." in self.member:
|
||||
member_parts = self.member.split(".")
|
||||
@ -231,14 +229,14 @@ class AttrSource(ChainedSource):
|
||||
)
|
||||
object.__setattr__(self, "member", member_parts[-1])
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
if not self.member.isidentifier():
|
||||
return f"getattr({self.base.name()}, {self.member!r})"
|
||||
return f"{self.base.name()}.{self.member}"
|
||||
@ -248,7 +246,7 @@ class AttrSource(ChainedSource):
|
||||
class GenericAttrSource(ChainedSource):
|
||||
member: str
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base, "Can't construct an AttrSource without a valid base source"
|
||||
if "." in self.member:
|
||||
member_parts = self.member.split(".")
|
||||
@ -257,14 +255,14 @@ class GenericAttrSource(ChainedSource):
|
||||
)
|
||||
object.__setattr__(self, "member", member_parts[-1])
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
|
||||
|
||||
|
||||
@ -277,7 +275,7 @@ class LocalCellSource(Source):
|
||||
|
||||
local_name: str
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
|
||||
# Dynamo's bytecode transformation differentiates them slightly, so we
|
||||
# always emit `LOAD_CLOSURE` here.
|
||||
@ -295,20 +293,20 @@ class LocalCellSource(Source):
|
||||
class GradSource(ChainedSource):
|
||||
member: str = "grad"
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}.{self.member}"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ParamBufferSource(AttrSource):
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
|
||||
|
||||
|
||||
@ -331,16 +329,16 @@ class UnspecializedParamBufferSource(AttrSource):
|
||||
class EphemeralSource(Source):
|
||||
desc: Optional[str] = None
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.EPHEMERAL
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
|
||||
|
||||
def make_guard(self, fn):
|
||||
def make_guard(self, fn: Callable[..., Any]) -> Guard:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_ephemeral(self):
|
||||
def is_ephemeral(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@ -349,13 +347,15 @@ class TensorProperty(enum.Enum):
|
||||
STRIDE = 1
|
||||
STORAGE_OFFSET = 2
|
||||
|
||||
def method_name(self):
|
||||
def method_name(self) -> str:
|
||||
if self is TensorProperty.SIZE:
|
||||
return "size"
|
||||
elif self is TensorProperty.STRIDE:
|
||||
return "stride"
|
||||
elif self is TensorProperty.STORAGE_OFFSET:
|
||||
return "storage_offset"
|
||||
else:
|
||||
raise AssertionError(f"unhandled {self}")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -363,14 +363,14 @@ class TensorPropertySource(ChainedSource):
|
||||
prop: TensorProperty
|
||||
idx: Optional[int] = None # None for STORAGE_OFFSET
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
if self.prop is TensorProperty.STORAGE_OFFSET:
|
||||
assert self.idx is None
|
||||
else:
|
||||
assert self.idx is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
utils.__name__, f"call_{self.prop.method_name()}"
|
||||
@ -384,10 +384,10 @@ class TensorPropertySource(ChainedSource):
|
||||
create_call_function(2 if self.idx is not None else 1, False)
|
||||
)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
if self.prop is TensorProperty.SIZE:
|
||||
return f"{self.base.name()}.size()[{self.idx}]"
|
||||
elif self.prop is TensorProperty.STRIDE:
|
||||
@ -403,88 +403,88 @@ class TensorPropertySource(ChainedSource):
|
||||
class IndexedSource(ChainedSource):
|
||||
idx: int
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"({self.idx}, {self.base.name()})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NegateSource(ChainedSource):
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
# NB: use method call so that function stripping regexes work
|
||||
return f"{self.base.name()}.__neg__()"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ConvertIntSource(ChainedSource):
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FlattenScriptObjectSource(ChainedSource):
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}.__obj_flatten__()"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ScriptObjectQualifiedNameSource(ChainedSource):
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}._type().qualified_name()"
|
||||
|
||||
|
||||
class AttrProxySource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}.get_base()"
|
||||
|
||||
|
||||
@ -495,7 +495,7 @@ class DefaultsSource(ChainedSource):
|
||||
field: str = dataclasses.field(init=False, repr=False, compare=False)
|
||||
_name: str = dataclasses.field(init=False, repr=False, compare=False)
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base, (
|
||||
"Base must be a valid source in order to properly track and guard this Defaults to its origin."
|
||||
)
|
||||
@ -512,16 +512,16 @@ class DefaultsSource(ChainedSource):
|
||||
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.field))
|
||||
codegen.append_output(codegen.create_load_const(self.idx_key))
|
||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
|
||||
@ -530,14 +530,14 @@ class GetItemSource(ChainedSource):
|
||||
index: Any
|
||||
index_is_slice: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
if isinstance(self.index, slice):
|
||||
# store the hashable version of the slice so the whole GetItemSource is hashable
|
||||
super().__setattr__("index", self.index.__reduce__())
|
||||
super().__setattr__("index_is_slice", True)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
if self.index_is_slice:
|
||||
codegen.append_output(codegen.create_load_const(self.unpack_slice()))
|
||||
@ -545,15 +545,15 @@ class GetItemSource(ChainedSource):
|
||||
codegen.append_output(codegen.create_load_const(self.index))
|
||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def unpack_slice(self):
|
||||
def unpack_slice(self) -> slice:
|
||||
assert self.index_is_slice
|
||||
slice_class, slice_args = self.index
|
||||
return slice_class(*slice_args)
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
# Index can be of following types
|
||||
# 1) index is a slice - example 1:4
|
||||
# 2) index is a constant - example string, integer
|
||||
@ -568,10 +568,10 @@ class GetItemSource(ChainedSource):
|
||||
class ConstDictKeySource(ChainedSource):
|
||||
index: Any
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
|
||||
)
|
||||
@ -579,11 +579,11 @@ class ConstDictKeySource(ChainedSource):
|
||||
codegen.append_output(codegen.create_load_const(self.index))
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
# The list creation will be CSE'd by PyExprCSEPass
|
||||
return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
|
||||
|
||||
def is_dict_key(self):
|
||||
def is_dict_key(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@ -591,15 +591,15 @@ class ConstDictKeySource(ChainedSource):
|
||||
class NonSerializableSetGetItemSource(ChainedSource):
|
||||
index: int
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
from .variables import ConstantVariable
|
||||
|
||||
assert ConstantVariable.is_literal(self.index)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "set_getitem")
|
||||
)
|
||||
@ -607,11 +607,11 @@ class NonSerializableSetGetItemSource(ChainedSource):
|
||||
codegen.append_output(codegen.create_load_const(self.index))
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
# set ordering might not be stable
|
||||
return f"list({self.base.name()})[{self.index!r}]"
|
||||
|
||||
def is_dict_key(self):
|
||||
def is_dict_key(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@ -623,17 +623,17 @@ class DictGetItemSource(ChainedSource):
|
||||
# 2) constant - like string, integer
|
||||
index: Any
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
from .variables import ConstantVariable
|
||||
|
||||
assert isinstance(
|
||||
self.index, ConstDictKeySource
|
||||
) or ConstantVariable.is_literal(self.index)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# Load dict
|
||||
codegen(self.base)
|
||||
|
||||
@ -644,7 +644,7 @@ class DictGetItemSource(ChainedSource):
|
||||
codegen.append_output(codegen.create_load_const(self.index))
|
||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
if isinstance(self.index, ConstDictKeySource):
|
||||
return f"{self.base.name()}[{self.index.name()}]"
|
||||
else:
|
||||
@ -660,17 +660,17 @@ class DictSubclassGetItemSource(ChainedSource):
|
||||
# 2) constant - like string, integer
|
||||
index: Any
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
from .variables import ConstantVariable
|
||||
|
||||
assert isinstance(
|
||||
self.index, ConstDictKeySource
|
||||
) or ConstantVariable.is_literal(self.index)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# reconstruct dict.__getitem__(dct, key)
|
||||
|
||||
# Load dict.__getitem__
|
||||
@ -689,7 +689,7 @@ class DictSubclassGetItemSource(ChainedSource):
|
||||
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
if isinstance(self.index, ConstDictKeySource):
|
||||
return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
|
||||
else:
|
||||
@ -702,7 +702,7 @@ class ListGetItemSource(GetItemSource):
|
||||
Same as GetItemSource with reconstruct and name overridden to be list specific.
|
||||
"""
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# Reconstruct list.__getitem__(lst, index) to avoid any side effects
|
||||
# from possibly overridden __getitem__.
|
||||
|
||||
@ -724,7 +724,7 @@ class ListGetItemSource(GetItemSource):
|
||||
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
# Index can be of following types
|
||||
# 1) index is a slice - example 1:4
|
||||
# 2) index is a constant - example string, integer
|
||||
@ -739,7 +739,7 @@ class ListGetItemSource(GetItemSource):
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TupleIteratorGetItemSource(GetItemSource):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
|
||||
)
|
||||
@ -747,91 +747,91 @@ class TupleIteratorGetItemSource(GetItemSource):
|
||||
codegen.append_output(codegen.create_load_const(self.index))
|
||||
codegen.extend_output(create_call_function(2, False))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DataclassFieldsSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
|
||||
)
|
||||
codegen(self.base)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"___dataclass_fields({self.base.name()})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TypeSource(ChainedSource):
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
|
||||
codegen(self.base)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"type({self.base.name()})"
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OptimizerSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.base.name()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NNModuleSource(ChainedSource):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.base.name()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UnspecializedNNModuleSource(NNModuleSource):
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FSDPNNModuleSource(NNModuleSource):
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class GlobalStateSource(Source):
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return ""
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
|
||||
@ -840,16 +840,16 @@ class TorchSource(Source):
|
||||
"""Points to the actual `torch` module - used instead of GlobalSource
|
||||
in case the user has overridden `torch` in their local namespace"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
from .guards import GuardBuilder, install_guard
|
||||
|
||||
install_guard(self.make_guard(GuardBuilder.ID_MATCH))
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return "__import__('torch')"
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_const(0), # level
|
||||
@ -858,7 +858,7 @@ class TorchSource(Source):
|
||||
]
|
||||
)
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
|
||||
@ -866,15 +866,15 @@ class TorchSource(Source):
|
||||
class TorchFunctionModeStackSource(Source):
|
||||
ind: int
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return f"___get_torch_function_mode_stack_at({self._get_index()})"
|
||||
|
||||
def _get_index(self):
|
||||
def _get_index(self) -> int:
|
||||
from .variables.torch_function import TorchFunctionModeStackVariable
|
||||
|
||||
return TorchFunctionModeStackVariable.get_mode_index(self.ind)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
utils.__name__, "get_torch_function_mode_stack_at"
|
||||
@ -883,7 +883,7 @@ class TorchFunctionModeStackSource(Source):
|
||||
codegen.extend_output([codegen.create_load_const(self._get_index())])
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.GLOBAL
|
||||
|
||||
|
||||
@ -891,16 +891,16 @@ class TorchFunctionModeStackSource(Source):
|
||||
class ConstantSource(Source):
|
||||
source_name: str
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.CONSTANT
|
||||
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self.source_name
|
||||
|
||||
def make_guard(self, fn):
|
||||
def make_guard(self, fn: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -909,10 +909,10 @@ class NumpyTensorSource(ChainedSource):
|
||||
def name(self) -> str:
|
||||
return f"___from_numpy({self.base.name()})"
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
|
||||
codegen(self.base)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
@ -923,7 +923,7 @@ class SubclassAttrListSource(ChainedSource):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}.__tensor_flatten__()[0]"
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
|
||||
@ -934,7 +934,7 @@ class FloatTensorSource(ChainedSource):
|
||||
def name(self) -> str:
|
||||
return f"___as_tensor({self.base.name()})"
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
|
||||
@ -943,7 +943,7 @@ class CallMethodItemSource(ChainedSource):
|
||||
def name(self) -> str:
|
||||
return f"{self.base.name()}.item()"
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return self.base.guard_source()
|
||||
|
||||
|
||||
@ -952,23 +952,25 @@ class CallMethodItemSource(ChainedSource):
|
||||
# guard contents from the ambient ShapeEnv
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ShapeEnvSource(Source):
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return ""
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.SHAPE_ENV
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BackwardStateSource(Source):
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return ""
|
||||
|
||||
def guard_source(self):
|
||||
def guard_source(self) -> GuardSource:
|
||||
return GuardSource.BACKWARD_STATE
|
||||
|
||||
|
||||
def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]:
|
||||
def get_local_source_name(
|
||||
source: Source, *, only_allow_input: bool = False
|
||||
) -> Optional[str]:
|
||||
if isinstance(source, ChainedSource):
|
||||
return get_local_source_name(source.base, only_allow_input=only_allow_input)
|
||||
if not isinstance(source, LocalSource):
|
||||
@ -978,7 +980,7 @@ def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional
|
||||
return source.local_name
|
||||
|
||||
|
||||
def is_from_local_source(source: Source, *, only_allow_input=False):
|
||||
def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool:
|
||||
return get_local_source_name(source, only_allow_input=only_allow_input) is not None
|
||||
|
||||
|
||||
@ -994,7 +996,7 @@ def get_global_source_name(source: Source) -> Optional[str]:
|
||||
return source.global_name
|
||||
|
||||
|
||||
def is_from_nonlocal_source(source: Source):
|
||||
def is_from_nonlocal_source(source: Source) -> bool:
|
||||
if isinstance(source, ChainedSource):
|
||||
return is_from_nonlocal_source(source.base)
|
||||
return (
|
||||
@ -1004,14 +1006,14 @@ def is_from_nonlocal_source(source: Source):
|
||||
)
|
||||
|
||||
|
||||
def is_from_source(source: Source, target: Source):
|
||||
def is_from_source(source: Source, target: Source) -> bool:
|
||||
if isinstance(source, ChainedSource):
|
||||
return is_from_source(source.base, target)
|
||||
return source == target
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_from_unspecialized_nn_module_source(source: Source):
|
||||
def is_from_unspecialized_nn_module_source(source: Source) -> bool:
|
||||
if isinstance(source, UnspecializedNNModuleSource):
|
||||
return True
|
||||
if isinstance(source, ChainedSource):
|
||||
@ -1020,7 +1022,7 @@ def is_from_unspecialized_nn_module_source(source: Source):
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_from_unspecialized_builtin_nn_module_source(source: Source):
|
||||
def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool:
|
||||
if isinstance(source, UnspecializedBuiltinNNModuleSource):
|
||||
return True
|
||||
if isinstance(source, ChainedSource):
|
||||
@ -1029,7 +1031,7 @@ def is_from_unspecialized_builtin_nn_module_source(source: Source):
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_from_unspecialized_param_buffer_source(source: Source):
|
||||
def is_from_unspecialized_param_buffer_source(source: Source) -> bool:
|
||||
if isinstance(source, UnspecializedParamBufferSource):
|
||||
return True
|
||||
if isinstance(source, ChainedSource):
|
||||
@ -1038,7 +1040,7 @@ def is_from_unspecialized_param_buffer_source(source: Source):
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_from_flatten_script_object_source(source: Source):
|
||||
def is_from_flatten_script_object_source(source: Source) -> bool:
|
||||
if isinstance(source, FlattenScriptObjectSource):
|
||||
return True
|
||||
elif isinstance(source, ChainedSource):
|
||||
@ -1047,7 +1049,7 @@ def is_from_flatten_script_object_source(source: Source):
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_from_optimizer_source(source: Source):
|
||||
def is_from_optimizer_source(source: Source) -> bool:
|
||||
if isinstance(source, OptimizerSource):
|
||||
return True
|
||||
if isinstance(source, ChainedSource):
|
||||
@ -1058,7 +1060,7 @@ def is_from_optimizer_source(source: Source):
|
||||
# TODO: can probably write a generic "test this on everything in the chain"
|
||||
# helper
|
||||
@functools.lru_cache
|
||||
def is_from_defaults(source: Source):
|
||||
def is_from_defaults(source: Source) -> bool:
|
||||
if isinstance(source, DefaultsSource):
|
||||
return True
|
||||
|
||||
|
@ -2644,7 +2644,9 @@ def set_example_value(node, example_value):
|
||||
# this to accurately reflect what the state of the value was at the time
|
||||
# the program was traced).
|
||||
node.meta["example_value"] = example_value
|
||||
shape_env = TracingContext.get().fake_mode.shape_env
|
||||
fake_mode = TracingContext.get().fake_mode
|
||||
assert fake_mode is not None
|
||||
shape_env = fake_mode.shape_env
|
||||
if (
|
||||
symbol_to_path
|
||||
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
|
||||
@ -4765,7 +4767,7 @@ def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
|
||||
|
||||
# Returns a set of code objects present traced in the current TracingContext, or None
|
||||
# if there is no current TracingContext.
|
||||
def get_traced_code() -> list[CodeType]:
|
||||
def get_traced_code() -> Optional[list[CodeType]]:
|
||||
from torch._guards import TracingContext
|
||||
|
||||
return TracingContext.get_traced_code()
|
||||
|
@ -365,6 +365,7 @@ def make_fake_inputs(
|
||||
# a toplevel TracingContext with a fake mode, so we do not want to
|
||||
# create another fake mode.
|
||||
fake_mode = context.fake_mode
|
||||
assert fake_mode is not None
|
||||
else:
|
||||
if isinstance(nn_module.forward, functools.partial):
|
||||
# functools handles nesting by itself, no need to recurse
|
||||
@ -852,7 +853,7 @@ def _fakify_script_objects(
|
||||
mod: torch.nn.Module,
|
||||
args: Sequence[Any],
|
||||
kwargs: dict[Any, Any],
|
||||
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
||||
fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode],
|
||||
):
|
||||
# This context manager is used to fakify script objects into FakeScriptObject.
|
||||
# Inputs:
|
||||
|
@ -1129,7 +1129,7 @@ def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
|
||||
|
||||
def _detect_fake_mode_from_gm(
|
||||
gm: torch.fx.GraphModule,
|
||||
) -> torch._subclasses.fake_tensor.FakeTensorMode:
|
||||
) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]:
|
||||
"""
|
||||
For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs.
|
||||
Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes.
|
||||
|
@ -12,7 +12,7 @@ It does so by:
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from contextlib import contextmanager, ExitStack, nullcontext
|
||||
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
@ -441,9 +441,10 @@ def create_functionalized_rng_ops_wrapper(
|
||||
# It goes from (primals, tangents) to (seed, offset, primals, tangents)
|
||||
# At runtime, we pass on the current seed and offset. This is hidden from
|
||||
# the user.
|
||||
fake_mode = detect_fake_mode()
|
||||
if fake_mode is None:
|
||||
fake_mode = nullcontext()
|
||||
fake_mode_det = detect_fake_mode()
|
||||
fake_mode: AbstractContextManager[Any] = nullcontext()
|
||||
if fake_mode_det is not None:
|
||||
fake_mode = fake_mode_det
|
||||
|
||||
def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
|
||||
out = PhiloxStateTracker.get_state_as_tensor()
|
||||
@ -1343,7 +1344,9 @@ def create_functional_call(
|
||||
"ignore", "Anomaly Detection has been enabled."
|
||||
)
|
||||
with torch.autograd.detect_anomaly(check_nan=False):
|
||||
detect_fake_mode().epoch += 1
|
||||
fake_mode = detect_fake_mode()
|
||||
assert fake_mode is not None
|
||||
fake_mode.epoch += 1
|
||||
out = PropagateUnbackedSymInts(mod).run(
|
||||
*args[params_len:], **kwargs
|
||||
)
|
||||
|
@ -306,11 +306,12 @@ def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices):
|
||||
tracing_context = torch._guards.TracingContext.try_get()
|
||||
|
||||
if tracing_context is not None:
|
||||
assert tracing_context.fake_mode is not None
|
||||
shape_env = tracing_context.fake_mode.shape_env
|
||||
|
||||
# Check whether we can actually get the dynamo sources from within AOTAutograd.
|
||||
if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None:
|
||||
maybe_suppress_guards = shape_env.suppress_guards
|
||||
maybe_suppress_guards = shape_env.suppress_guards # type: ignore[assignment]
|
||||
|
||||
# Check whether there are any symbolic values being used.
|
||||
# We do this for 2 reasons:
|
||||
|
@ -495,6 +495,7 @@ class FunctionalizedRngRuntimeWrapper(InductorWrapper):
|
||||
if config.functionalize_rng_ops:
|
||||
# Update example inputs for the fw_compiler
|
||||
fake_mode = detect_fake_mode()
|
||||
assert fake_mode is not None
|
||||
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
|
||||
flat_args.extend([seed, offset])
|
||||
# We are not clearing flat_args here because
|
||||
|
218
torch/_guards.py
218
torch/_guards.py
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
@ -37,10 +36,15 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Iterator
|
||||
from types import CodeType
|
||||
|
||||
import sympy
|
||||
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
|
||||
|
||||
"""
|
||||
torch._guards is the definitional source of truth for general purpose guard structures.
|
||||
@ -83,7 +87,7 @@ class CompileId:
|
||||
# TODO: consider also tracking the recompilation count
|
||||
# See Note: Updating CompileId
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# NOTE: Keep this in sync with both from_string and the tlparse repo
|
||||
if self.compiled_autograd_id is not None:
|
||||
assert (self.frame_id is None) == (self.frame_compile_id is None)
|
||||
@ -97,7 +101,7 @@ class CompileId:
|
||||
return f"{self.frame_id}/{self.frame_compile_id}"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, compile_id: Optional[str]):
|
||||
def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]:
|
||||
"""
|
||||
Factory method that creates a CompileId from its string representation.
|
||||
Keep this in sync with the __str__ method.
|
||||
@ -125,7 +129,7 @@ class TraceId(NamedTuple):
|
||||
# up by one
|
||||
attempt: int
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# Keep this in sync with tlparse repo
|
||||
if self.attempt == 0:
|
||||
return str(self.compile_id)
|
||||
@ -185,7 +189,7 @@ class GuardSource(enum.Enum):
|
||||
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
||||
)
|
||||
|
||||
def is_local(self):
|
||||
def is_local(self) -> bool:
|
||||
return self in (
|
||||
GuardSource.LOCAL,
|
||||
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
|
||||
@ -218,7 +222,7 @@ class SLoc:
|
||||
framework_loc: Optional[Union[traceback.FrameSummary, str]]
|
||||
maybe_user_loc: Optional[str]
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
floc = (
|
||||
self.framework_loc
|
||||
if isinstance(self.framework_loc, str)
|
||||
@ -257,7 +261,7 @@ class Guard:
|
||||
# it is meaningless. Example create_fns that are like this include
|
||||
# GRAD_MODE and SHAPE_ENV.
|
||||
originating_source: Source
|
||||
create_fn: Callable[[GuardBuilderBase, Guard], None]
|
||||
create_fn: Callable[[GuardBuilderBase, Guard], Any]
|
||||
|
||||
# Export only. These values are written to at time of guard check_fn creation.
|
||||
guard_types: Optional[list[str]] = None
|
||||
@ -269,12 +273,12 @@ class Guard:
|
||||
user_stack: Optional[traceback.StackSummary] = None
|
||||
_hash: Optional[int] = None
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
if self._hash is None:
|
||||
self._hash = hash((self.name, self.source, id(self.create_fn)))
|
||||
return self._hash
|
||||
|
||||
def sort_key(self):
|
||||
def sort_key(self) -> tuple[bool, int, int, str, int]:
|
||||
# Put the duplicate input guards at the end. The duplicate guards have
|
||||
# two sources while guard.name only considers one source.
|
||||
|
||||
@ -290,10 +294,10 @@ class Guard:
|
||||
self.inner_create_fn().__code__.co_firstlineno,
|
||||
)
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self, other: Guard) -> bool:
|
||||
return self.sort_key() < other.sort_key()
|
||||
|
||||
def inner_create_fn(self):
|
||||
def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
|
||||
if isinstance(self.create_fn, functools.partial):
|
||||
return self.create_fn.func
|
||||
else:
|
||||
@ -308,7 +312,7 @@ class Guard:
|
||||
return self.originating_source.guard_source()
|
||||
|
||||
@staticmethod
|
||||
def weakref_to_str(obj_weakref):
|
||||
def weakref_to_str(obj_weakref: object) -> str:
|
||||
"""
|
||||
This is a workaround of a Python weakref bug.
|
||||
|
||||
@ -332,7 +336,7 @@ class Guard:
|
||||
else:
|
||||
return str(obj_weakref)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
s = f"""
|
||||
{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
|
||||
{{
|
||||
@ -344,7 +348,7 @@ class Guard:
|
||||
"""
|
||||
return s
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
output = f"Name: {repr(self.name)}\n"
|
||||
source = self.source.name.lower() if self.source else ""
|
||||
output += f" Source: {source}\n"
|
||||
@ -355,7 +359,7 @@ class Guard:
|
||||
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
|
||||
return output
|
||||
|
||||
def create(self, builder: GuardBuilderBase):
|
||||
def create(self, builder: GuardBuilderBase) -> Any:
|
||||
try:
|
||||
return self.create_fn(builder, self)
|
||||
except Exception:
|
||||
@ -364,16 +368,22 @@ class Guard:
|
||||
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
|
||||
raise
|
||||
|
||||
def is_specialized_nn_module(self):
|
||||
def is_specialized_nn_module(self) -> bool:
|
||||
return self.source.is_specialized_nn_module()
|
||||
|
||||
def is_fsdp_module(self):
|
||||
def is_fsdp_module(self) -> bool:
|
||||
return self.source.is_fsdp_module()
|
||||
|
||||
def is_local(self):
|
||||
def is_local(self) -> bool:
|
||||
return self.source.is_local()
|
||||
|
||||
def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
|
||||
def set_export_info(
|
||||
self,
|
||||
guard_type: str,
|
||||
guarded_class: Optional[type],
|
||||
code_list: list[str],
|
||||
obj_weakref: object,
|
||||
) -> None:
|
||||
if not self.guard_types:
|
||||
self.guard_types = []
|
||||
|
||||
@ -428,7 +438,7 @@ class DuplicateInputs(GuardEnvExpr):
|
||||
input_source_a: Source
|
||||
input_source_b: Source
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
assert self.input_source_a != self.input_source_b
|
||||
|
||||
|
||||
@ -470,7 +480,7 @@ class Checkpointable(Generic[T]):
|
||||
def copy_graphstate(self) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def restore_graphstate(self, state: T): ...
|
||||
def restore_graphstate(self, state: T) -> None: ...
|
||||
|
||||
|
||||
class GuardsCheckpointState:
|
||||
@ -480,10 +490,10 @@ class GuardsCheckpointState:
|
||||
|
||||
dynamo_guards: set[Guard] = set()
|
||||
|
||||
def __init__(self, dynamo_guards):
|
||||
def __init__(self, dynamo_guards: set[Guard]) -> None:
|
||||
self.dynamo_guards = dynamo_guards
|
||||
|
||||
def diff(self, other):
|
||||
def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]:
|
||||
"""
|
||||
Produces a delta against another GuardsCheckpointState.
|
||||
|
||||
@ -495,17 +505,19 @@ class GuardsCheckpointState:
|
||||
return None
|
||||
return r
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, GuardsCheckpointState):
|
||||
return False
|
||||
return self.diff(other) is None
|
||||
|
||||
|
||||
class ModuleContextCheckpointState:
|
||||
nn_modules: dict[str, torch.nn.Module] = {}
|
||||
|
||||
def __init__(self, nn_modules):
|
||||
def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
|
||||
self.nn_modules = nn_modules
|
||||
|
||||
def diff(self, other):
|
||||
def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]:
|
||||
"""
|
||||
Produces a delta against another ModuleContextCheckpointState.
|
||||
|
||||
@ -517,7 +529,9 @@ class ModuleContextCheckpointState:
|
||||
return None
|
||||
return r
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, ModuleContextCheckpointState):
|
||||
return False
|
||||
return self.diff(other) is None
|
||||
|
||||
|
||||
@ -525,21 +539,21 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
||||
def __init__(self) -> None:
|
||||
self.nn_modules: dict[str, Any] = {}
|
||||
|
||||
def copy_graphstate(self):
|
||||
def copy_graphstate(self) -> ModuleContextCheckpointState:
|
||||
return ModuleContextCheckpointState(dict(self.nn_modules))
|
||||
|
||||
def restore_graphstate(self, state):
|
||||
def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
|
||||
assert isinstance(state, ModuleContextCheckpointState)
|
||||
self.nn_modules = state.nn_modules
|
||||
|
||||
|
||||
class GlobalContextCheckpointState:
|
||||
global_state: dict[str, tuple[Callable, ...]] = {}
|
||||
global_state: dict[str, tuple[Callable, Any]] = {}
|
||||
|
||||
def __init__(self, global_states):
|
||||
def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
|
||||
self.global_state = global_states
|
||||
|
||||
def diff(self, other):
|
||||
def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]:
|
||||
"""
|
||||
Produces a delta against another GlobalContextCheckpointState.
|
||||
|
||||
@ -551,7 +565,9 @@ class GlobalContextCheckpointState:
|
||||
return None
|
||||
return r
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, GlobalContextCheckpointState):
|
||||
return False
|
||||
return self.diff(other) is None
|
||||
|
||||
|
||||
@ -571,12 +587,12 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.global_state: dict[str, tuple[Callable, ...]] = {}
|
||||
self.global_state: dict[str, tuple[Callable, Any]] = {}
|
||||
|
||||
def copy_graphstate(self):
|
||||
return GlobalContextCheckpointState(dict(self.global_state))
|
||||
def copy_graphstate(self) -> GlobalContextCheckpointState:
|
||||
return GlobalContextCheckpointState(self.global_state)
|
||||
|
||||
def restore_graphstate(self, state):
|
||||
def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
|
||||
assert isinstance(state, GlobalContextCheckpointState)
|
||||
self.global_state = state.global_state
|
||||
assert (
|
||||
@ -590,26 +606,28 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
|
||||
# Like a Set[Guard] but will record the user stack on all guards at the
|
||||
# time they were installed at their destination
|
||||
class GuardsSet:
|
||||
def __init__(self, inner=None):
|
||||
def __init__(self, inner: Optional[set[Guard]] = None) -> None:
|
||||
if inner is None:
|
||||
inner = set()
|
||||
self.inner = inner
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[Guard]:
|
||||
return iter(self.inner)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.inner)
|
||||
|
||||
# Subtraction along with bool is typically used to determine the delta of
|
||||
# added guards between checkpoints for higher order ops
|
||||
def __sub__(self, other):
|
||||
def __sub__(self, other: GuardsSet) -> GuardsSet:
|
||||
return GuardsSet(self.inner - other.inner)
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.inner)
|
||||
|
||||
def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
|
||||
def add(
|
||||
self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
|
||||
) -> None:
|
||||
if guard in self.inner:
|
||||
return
|
||||
if collect_debug_stack:
|
||||
@ -619,12 +637,12 @@ class GuardsSet:
|
||||
guard.user_stack = TracingContext.extract_stack()
|
||||
self.inner.add(guard)
|
||||
|
||||
def update(self, *others: set[Guard]):
|
||||
def update(self, *others: set[Guard]) -> None:
|
||||
for o in others:
|
||||
for g in o:
|
||||
self.add(g, skip=1)
|
||||
|
||||
def remove_guards_with_source(self, source):
|
||||
def remove_guards_with_source(self, source: Source) -> None:
|
||||
"""Delete all guards that contains a given source"""
|
||||
from ._dynamo.source import is_from_source
|
||||
|
||||
@ -646,10 +664,10 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
|
||||
self.dynamo_guards: GuardsSet = GuardsSet()
|
||||
self.aotautograd_guards: list[GuardEnvExpr] = []
|
||||
|
||||
def copy_graphstate(self):
|
||||
def copy_graphstate(self) -> GuardsCheckpointState:
|
||||
return GuardsCheckpointState(set(self.dynamo_guards.inner))
|
||||
|
||||
def restore_graphstate(self, state):
|
||||
def restore_graphstate(self, state: GuardsCheckpointState) -> None:
|
||||
# NB: "steals" the passed in state
|
||||
assert isinstance(state, GuardsCheckpointState)
|
||||
self.dynamo_guards = GuardsSet(state.dynamo_guards)
|
||||
@ -657,22 +675,22 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
|
||||
|
||||
class HopSubgraphCache:
|
||||
@abstractmethod
|
||||
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): ...
|
||||
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
def add_autograd_key_entry(self, identifier: str, key: Callable): ...
|
||||
def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_autograd_key_entry(self, identifier: str): ...
|
||||
def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ...
|
||||
|
||||
@abstractmethod
|
||||
def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ...
|
||||
def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_proxy_dispatch_entry(self, identifier: str): ...
|
||||
def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ...
|
||||
|
||||
@abstractmethod
|
||||
def add_lazy_bwd_entry(
|
||||
@ -680,12 +698,12 @@ class HopSubgraphCache:
|
||||
identifier: str,
|
||||
tangent_metadata: tuple[object],
|
||||
gmod: torch.fx.GraphModule,
|
||||
): ...
|
||||
) -> int: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_lazy_bwd_entry(
|
||||
self, identifier: str, tangent_metadata: tuple[object]
|
||||
) -> int: ...
|
||||
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ...
|
||||
|
||||
|
||||
class InvokeSubgraphCache(HopSubgraphCache):
|
||||
@ -697,22 +715,22 @@ class InvokeSubgraphCache(HopSubgraphCache):
|
||||
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
|
||||
] = defaultdict(dict)
|
||||
|
||||
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str):
|
||||
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
|
||||
self.dynamo_installed_submodules[fn_id].append(identifier)
|
||||
|
||||
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
|
||||
return self.dynamo_installed_submodules.get(fn_id, [])
|
||||
|
||||
def add_autograd_key_entry(self, identifier: str, key: Callable):
|
||||
def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
|
||||
self.autograd_cache[identifier] = key
|
||||
|
||||
def get_autograd_key_entry(self, identifier: str):
|
||||
def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]:
|
||||
return self.autograd_cache.get(identifier, None)
|
||||
|
||||
def add_proxy_dispatch_entry(self, identifier: str, key: Callable):
|
||||
def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
|
||||
self.proxy_dispatch_cache[identifier] = key
|
||||
|
||||
def get_proxy_dispatch_entry(self, identifier: str):
|
||||
def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]:
|
||||
return self.proxy_dispatch_cache.get(identifier, None)
|
||||
|
||||
def add_lazy_bwd_entry(
|
||||
@ -720,13 +738,15 @@ class InvokeSubgraphCache(HopSubgraphCache):
|
||||
identifier: str,
|
||||
tangent_metadata: tuple[object],
|
||||
gmod: torch.fx.GraphModule,
|
||||
):
|
||||
) -> int:
|
||||
# Save the number of existing graph modules in the dictionary to get the suffix
|
||||
num_gmods = len(self.lazy_bwd_cache[identifier])
|
||||
self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
|
||||
return num_gmods
|
||||
|
||||
def get_lazy_bwd_entry(self, identifier: str, tangent_metadata: tuple[object]):
|
||||
def get_lazy_bwd_entry(
|
||||
self, identifier: str, tangent_metadata: tuple[object]
|
||||
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]:
|
||||
if identifier not in self.lazy_bwd_cache:
|
||||
return (None, None)
|
||||
|
||||
@ -779,7 +799,7 @@ class CompileContext:
|
||||
def try_get() -> Optional[CompileContext]:
|
||||
return getattr(_TLS, "compile_context", None)
|
||||
|
||||
def __init__(self, compile_id):
|
||||
def __init__(self, compile_id: Optional[CompileId]) -> None:
|
||||
assert compile_id is None or isinstance(compile_id, CompileId)
|
||||
self.compile_id: Optional[CompileId] = compile_id
|
||||
self.attempt = 0
|
||||
@ -787,14 +807,14 @@ class CompileContext:
|
||||
self.shape_env_guards: list[str] = []
|
||||
|
||||
@staticmethod
|
||||
def current_compile_id():
|
||||
def current_compile_id() -> Optional[CompileId]:
|
||||
self = CompileContext.try_get()
|
||||
if self is None:
|
||||
return None
|
||||
return self.compile_id
|
||||
|
||||
@staticmethod
|
||||
def current_trace_id():
|
||||
def current_trace_id() -> Optional[TraceId]:
|
||||
self = CompileContext.try_get()
|
||||
if self is None:
|
||||
return None
|
||||
@ -823,28 +843,28 @@ class TracingContext:
|
||||
"TracingContext.get() must be called within an ongoing trace."
|
||||
)
|
||||
|
||||
def __init__(self, fake_mode):
|
||||
def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None:
|
||||
self.guards_context = GuardsContext()
|
||||
self.module_context = ModuleContext()
|
||||
self.global_context = GlobalContext()
|
||||
self.previously_inlined_functions = dict()
|
||||
self.previously_cleaned_instructions = dict()
|
||||
self.fake_mode = fake_mode
|
||||
self.frame_summary_stack = []
|
||||
self.previously_inlined_functions: dict[Any, Any] = dict()
|
||||
self.previously_cleaned_instructions: dict[Any, Any] = dict()
|
||||
self.fake_mode: Optional[FakeTensorMode] = fake_mode
|
||||
self.frame_summary_stack: list[traceback.FrameSummary] = []
|
||||
# This is morally part of frame_summary_stack, but it is kept separate
|
||||
# for clarity. As we process a frame, this variable gets updated
|
||||
# to keep track of what line we are in the function. We make a
|
||||
# function call, this gets cleared and the frame location is pushed
|
||||
# to frame_summary_stack (prepping this variable for the inner frame's
|
||||
# progress)
|
||||
self.loc_in_frame = None
|
||||
self.loc_in_frame: Optional[tuple[str, int, str]] = None
|
||||
# this is only set after aot_autograd
|
||||
self.fw_metadata = None
|
||||
self.fw_metadata: Optional[ViewAndMutationMeta] = None
|
||||
# this is only set after aot_autograd
|
||||
self.aot_graph_name = None
|
||||
self.params_flat = None
|
||||
self.params_flat_unwrap_subclasses = None
|
||||
self.params_unwrapped_to_flat_index = None
|
||||
self.aot_graph_name: Optional[list[str]] = None
|
||||
self.params_flat: Optional[list[Any]] = None
|
||||
self.params_flat_unwrap_subclasses: Optional[list[Any]] = None
|
||||
self.params_unwrapped_to_flat_index: Optional[list[Any]] = None
|
||||
# this is for extended return calling convention from backend
|
||||
# compiler to aot_autograd
|
||||
# Per output, what the compiler specified stride of the output is,
|
||||
@ -872,7 +892,7 @@ class TracingContext:
|
||||
# list of code objects for inlined functions
|
||||
self.traced_code: list[CodeType] = []
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
# Look at the note in output_graph.py in function `save_global_state`
|
||||
# for the context on clearing global context.
|
||||
self.global_context.global_state = {}
|
||||
@ -881,7 +901,7 @@ class TracingContext:
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def patch(**kwargs):
|
||||
def patch(**kwargs: Any) -> Generator[None, None, None]:
|
||||
prior = {}
|
||||
ctx = TracingContext.get()
|
||||
|
||||
@ -897,7 +917,7 @@ class TracingContext:
|
||||
setattr(ctx, key, val)
|
||||
|
||||
@staticmethod
|
||||
def extract_stack():
|
||||
def extract_stack() -> traceback.StackSummary:
|
||||
self = TracingContext.try_get()
|
||||
if self is None:
|
||||
return traceback.StackSummary()
|
||||
@ -906,7 +926,7 @@ class TracingContext:
|
||||
stack = stack + [self._populate_loc_in_frame_summary()]
|
||||
return traceback.StackSummary.from_list(stack)
|
||||
|
||||
def _populate_loc_in_frame_summary(self):
|
||||
def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
|
||||
assert self.loc_in_frame is not None
|
||||
filename, lineno, frame_name = self.loc_in_frame
|
||||
return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
|
||||
@ -915,7 +935,7 @@ class TracingContext:
|
||||
# associated with the current frame state
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def clear_frame():
|
||||
def clear_frame() -> Generator[None, None, None]:
|
||||
tc = TracingContext.get()
|
||||
with (
|
||||
unittest.mock.patch.object(tc, "frame_summary_stack", []),
|
||||
@ -947,7 +967,9 @@ class TracingContext:
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def current_frame(frame_summary):
|
||||
def current_frame(
|
||||
frame_summary: Optional[traceback.FrameSummary],
|
||||
) -> Generator[None, None, None]:
|
||||
# frame_summary can be None to solely take advantage of real_stack
|
||||
# attachment to thrown exceptions
|
||||
tc = TracingContext.get()
|
||||
@ -968,7 +990,9 @@ class TracingContext:
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def report_output_strides():
|
||||
def report_output_strides() -> Generator[
|
||||
Optional[list[Optional[tuple[int, ...]]]], None, None
|
||||
]:
|
||||
tc = TracingContext.try_get()
|
||||
if tc is None:
|
||||
yield None
|
||||
@ -981,13 +1005,13 @@ class TracingContext:
|
||||
tc.output_strides = old_output_strides
|
||||
|
||||
@staticmethod
|
||||
def set_current_loc(filename, lineno, frame_name):
|
||||
def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
|
||||
# Save the current location in the frame. Lazily generate the
|
||||
# framesummary.
|
||||
TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
|
||||
|
||||
@staticmethod
|
||||
def get_traced_code():
|
||||
def get_traced_code() -> Optional[list[CodeType]]:
|
||||
tc = TracingContext.try_get()
|
||||
if tc is None:
|
||||
return None
|
||||
@ -995,7 +1019,9 @@ class TracingContext:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def compile_context(context: Optional[CompileContext]):
|
||||
def compile_context(
|
||||
context: Optional[CompileContext],
|
||||
) -> Generator[Optional[CompileContext], None, None]:
|
||||
old_context = getattr(_TLS, "compile_context", None)
|
||||
_TLS.compile_context = context
|
||||
try:
|
||||
@ -1005,7 +1031,9 @@ def compile_context(context: Optional[CompileContext]):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tracing(context: Optional[TracingContext]):
|
||||
def tracing(
|
||||
context: Optional[TracingContext],
|
||||
) -> Generator[Optional[TracingContext], None, None]:
|
||||
"""
|
||||
This function installs the passed in tracing context as a dynamic scoped
|
||||
global variable.
|
||||
@ -1035,13 +1063,13 @@ def tracing(context: Optional[TracingContext]):
|
||||
# TODO(voz): Consider a toplevel torch/_source.py
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Source:
|
||||
def is_dict_key(self):
|
||||
def is_dict_key(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_ephemeral(self):
|
||||
def is_ephemeral(self) -> bool:
|
||||
return False
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: PyCodegen) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
@ -1050,7 +1078,7 @@ class Source:
|
||||
def name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def make_guard(self, fn) -> Guard:
|
||||
def make_guard(self, fn: Callable[..., Any]) -> Guard:
|
||||
if self.guard_source() is GuardSource.CONSTANT:
|
||||
raise NotImplementedError
|
||||
return Guard(self, fn)
|
||||
@ -1058,7 +1086,7 @@ class Source:
|
||||
def is_specialized_nn_module(self) -> bool:
|
||||
return self.guard_source().is_specialized_nn_module()
|
||||
|
||||
def subguards_allowed(self):
|
||||
def subguards_allowed(self) -> bool:
|
||||
"""True if you can guard on attributes of this"""
|
||||
return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
|
||||
|
||||
@ -1068,11 +1096,11 @@ class Source:
|
||||
class ChainedSource(Source):
|
||||
base: Source
|
||||
|
||||
def is_dict_key(self):
|
||||
def is_dict_key(self) -> bool:
|
||||
# Recurse until you either hit a ConstDictKey or a Source
|
||||
return self.base.is_dict_key()
|
||||
|
||||
def is_ephemeral(self):
|
||||
def is_ephemeral(self) -> bool:
|
||||
return self.base.is_ephemeral()
|
||||
|
||||
def get_base(self) -> Source:
|
||||
@ -1082,7 +1110,7 @@ class ChainedSource(Source):
|
||||
return current
|
||||
|
||||
|
||||
def detect_fake_mode(inputs: Any = None):
|
||||
def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
|
||||
"""
|
||||
Attempts to "detect" what the current fake mode is. If there is one ambiently
|
||||
available from TracingContext, we preferentially use that. Otherwise, we
|
||||
@ -1126,7 +1154,7 @@ def detect_fake_mode(inputs: Any = None):
|
||||
return None
|
||||
|
||||
|
||||
def active_fake_mode():
|
||||
def active_fake_mode() -> Optional[FakeTensorMode]:
|
||||
"""
|
||||
Inspects the dispatch mode stack for an active fake mode and returns it.
|
||||
Returns None if no fake mode is active.
|
||||
|
@ -465,6 +465,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
||||
from torch._subclasses.fake_tensor import extract_tensor_metadata
|
||||
|
||||
fake_mode = detect_fake_mode(primals + filtered_grad_outs)
|
||||
assert fake_mode is not None, "fake_mode should be enabled for HOPs"
|
||||
state = _CacheKeyState(fake_mode.shape_env)
|
||||
|
||||
tangent_metadata: list[object] = []
|
||||
@ -607,6 +608,7 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode(operands)
|
||||
assert fake_mode is not None and fake_mode.shape_env is not None
|
||||
insert_deferred_runtime_asserts(
|
||||
graph,
|
||||
fake_mode.shape_env,
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import functools
|
||||
from contextlib import contextmanager, ExitStack, nullcontext
|
||||
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
||||
|
||||
@ -266,11 +266,12 @@ def _set_compilation_env():
|
||||
|
||||
# The invariant here is that we always trace the branch with fake tensor
|
||||
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
||||
fake_mode = detect_fake_mode(inputs)
|
||||
tracing_mode = "real"
|
||||
if fake_mode is None:
|
||||
fake_mode = nullcontext()
|
||||
tracing_mode = "fake"
|
||||
fake_mode_det = detect_fake_mode(inputs)
|
||||
fake_mode: AbstractContextManager = nullcontext()
|
||||
tracing_mode = "fake"
|
||||
if fake_mode_det is not None:
|
||||
fake_mode = fake_mode_det
|
||||
tracing_mode = "real"
|
||||
|
||||
# Note: we need to turn off proxy tensor mode to avoid tracing infra
|
||||
# code that happens in make_fx e.g. we now call as_strided when wrapping tensor
|
||||
@ -282,9 +283,12 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
||||
pre_dispatch=pre_dispatch,
|
||||
_error_on_data_dependent_ops=False,
|
||||
)(*inputs)
|
||||
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None:
|
||||
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: # type: ignore[attr-defined]
|
||||
insert_deferred_runtime_asserts(
|
||||
gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True
|
||||
gm,
|
||||
fake_mode.shape_env, # type: ignore[attr-defined]
|
||||
"hoo_maybe_fake_tracing",
|
||||
export=True, # type: ignore[attr-defined]
|
||||
)
|
||||
return gm
|
||||
|
||||
|
@ -1065,7 +1065,7 @@ class GuardedCache(Generic[T]):
|
||||
Helper to get the shape env from the tracing context.
|
||||
"""
|
||||
ctx = torch._guards.TracingContext.try_get()
|
||||
if not ctx:
|
||||
if not ctx or not ctx.fake_mode:
|
||||
return None
|
||||
return ctx.fake_mode.shape_env
|
||||
|
||||
|
@ -1942,7 +1942,7 @@ def fw_compiler_freezing(
|
||||
idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
|
||||
]
|
||||
|
||||
static_input_idxs = []
|
||||
static_input_idxs: list[Any] = []
|
||||
# constant params will be real tensors, not fake
|
||||
tracing_context = torch._guards.TracingContext.try_get()
|
||||
unwrapped_args_offsets = [0]
|
||||
@ -2461,6 +2461,7 @@ def compile_fx(
|
||||
if node.op == "get_attr" and "val" not in node.meta:
|
||||
target = attrgetter(node.target)(gm)
|
||||
if isinstance(target, torch.Tensor):
|
||||
assert fake_mode is not None
|
||||
node.meta["val"] = fake_mode.from_tensor(
|
||||
target, static_shapes=True
|
||||
)
|
||||
|
@ -1429,7 +1429,9 @@ def register_replacement(
|
||||
)
|
||||
|
||||
sym_args: list[torch.SymInt] = []
|
||||
with torch._dynamo.utils.detect_fake_mode(args):
|
||||
fake_mode = torch._dynamo.utils.detect_fake_mode(args)
|
||||
assert fake_mode is not None
|
||||
with fake_mode:
|
||||
for i, grad in enumerate(requires_grad):
|
||||
if isinstance(args[i], torch.Tensor):
|
||||
if grad and is_integer_dtype(args[i].dtype):
|
||||
|
@ -203,6 +203,7 @@ def standalone_compile(
|
||||
# Reuse fake_mode from the TracingContext.
|
||||
# NB: The TracingContext only exists if we're currently in a torch.compile backend.
|
||||
context = torch._guards.TracingContext.get()
|
||||
assert context.fake_mode is not None
|
||||
fake_mode = context.fake_mode
|
||||
elif dynamic_shapes == "from_graph":
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
|
@ -2817,10 +2817,9 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[N
|
||||
return contextlib.nullcontext()
|
||||
|
||||
# In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
|
||||
shape_env = tracing_context.fake_mode.shape_env
|
||||
if not shape_env:
|
||||
if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
shape_env = tracing_context.fake_mode.shape_env
|
||||
return shape_env.suppress_guards()
|
||||
|
||||
|
||||
@ -3343,12 +3342,13 @@ def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
|
||||
for i, e in enumerate(row):
|
||||
widths[i] = max(widths[i], len(str(e)))
|
||||
lines = []
|
||||
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
|
||||
# Need nested {} for string formatting; ignore SET_LINTER here
|
||||
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # noqa: set_linter
|
||||
# widths whitespace horizontal separators
|
||||
total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
|
||||
lines.append("-" * total_width)
|
||||
for row in elements:
|
||||
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
|
||||
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) # noqa: set_linter
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ import operator
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import (
|
||||
@ -2090,7 +2090,9 @@ def alert_not_deterministic(caller: str):
|
||||
|
||||
class CUDARngStateHelper:
|
||||
@staticmethod
|
||||
def get_torch_state_as_tuple(fake_mode=nullcontext()):
|
||||
def get_torch_state_as_tuple(
|
||||
fake_mode: AbstractContextManager[Any] = nullcontext(),
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA not available")
|
||||
|
||||
|
@ -573,8 +573,8 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
delattr(ep.graph_module, name)
|
||||
|
||||
# TODO(zhxhchen17) Return the new graph_signature directly.
|
||||
fake_mode = detect_fake_mode(fake_args)
|
||||
fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment]
|
||||
fake_mode_det = detect_fake_mode(fake_args)
|
||||
fake_mode_ctx = contextlib.nullcontext() if fake_mode_det is None else fake_mode_det # type: ignore[assignment]
|
||||
custom_triton_ops_decomposition_ctx = (
|
||||
contextlib.nullcontext
|
||||
if decompose_custom_triton_ops
|
||||
@ -582,7 +582,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
)
|
||||
with (
|
||||
_ignore_backend_decomps(),
|
||||
fake_mode,
|
||||
fake_mode_ctx,
|
||||
_override_composite_implicit_decomp(cia_to_decomp),
|
||||
custom_triton_ops_decomposition_ctx(),
|
||||
):
|
||||
|
@ -7872,7 +7872,9 @@ class PropagateUnbackedSymInts(torch.fx.Interpreter):
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
result = super().run_node(n)
|
||||
rebind_unbacked(detect_fake_mode().shape_env, n, result)
|
||||
fake_mode = detect_fake_mode()
|
||||
assert fake_mode is not None
|
||||
rebind_unbacked(fake_mode.shape_env, n, result)
|
||||
return result
|
||||
|
||||
|
||||
|
@ -1028,7 +1028,7 @@ def bound_sympy(
|
||||
|
||||
# If there's a tracing context, augment available constrained ranges.
|
||||
context = torch._guards.TracingContext.try_get()
|
||||
if context and context.fake_mode.shape_env:
|
||||
if context and context.fake_mode and context.fake_mode.shape_env:
|
||||
if ranges:
|
||||
ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}
|
||||
else:
|
||||
|
Reference in New Issue
Block a user