Compare commits


1 Commit

Author  SHA1        Message                Date
        d84a652bb0  Initial type coverage  2025-11-07 10:39:12 -08:00
2 changed files with 202 additions and 153 deletions

View File

@@ -1181,6 +1181,7 @@ class OutputGraph(OutputGraphCommon):
# sourceless, so let's return an unspecializedNNModule variable
# tracker.
def wrap_name(module_key: str) -> VariableTracker:
# pyrefly: ignore[bad-argument-type]
return variables.UnspecializedNNModuleVariable(target, **options)
elif isinstance(target, (torch.SymInt, torch.SymFloat)):
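Aside on the suppression style this commit relies on: a pyrefly ignore comment placed on its own line suppresses the named error on the statement that follows, and the bracketed error code keeps the suppression narrow. A minimal sketch mirroring the syntax above (`takes_int` is a hypothetical function, and the next-line scoping is assumed to match the pattern shown in this hunk):

def takes_int(x: int) -> int:
    return x

# pyrefly: ignore[bad-argument-type]
takes_int("not an int")     # this call site is silenced
takes_int("still flagged")  # a second bad call still reports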

View File

@@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
This module implements variable tracking for PyTorch nn.Module instances during Dynamo tracing.
@@ -29,7 +27,7 @@ import itertools
import re
import types
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING
from typing import Any, Optional, Sequence, TYPE_CHECKING
import torch.nn
@@ -78,7 +76,12 @@ if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
def initialize_lazy_module(
tx: "InstructionTranslator",
mod: torch.nn.Module,
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> None:
"""
Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.
@@ -88,11 +91,11 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
"""
if hasattr(mod, "_initialize_hook"):
def convert_to_fake(x):
def convert_to_fake(x: Any) -> Any:
if is_namedtuple(x):
return type(x)(*(convert_to_fake(elem) for elem in x))
elif isinstance(x, dict):
return {k: convert_to_fake(v) for k, v in x.items()}
return {k: convert_to_fake(v) for k, v in x.items()} # type: ignore[misc]
elif isinstance(x, (list, tuple, set)):
return type(x)(convert_to_fake(elem) for elem in x)
elif isinstance(x, torch.fx.Proxy):
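On the namedtuple branch of convert_to_fake above: the starred constructor is what keeps the recursion faithful, since type(x)(gen) would pass a single generator argument while type(x)(*elems) rebuilds the fields positionally. A standalone illustration (not from the commit):

from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)

rebuilt = type(p)(*(v * 10 for v in p))  # Point(x=10, y=20)
# type(p)(v * 10 for v in p) would raise a TypeError: the namedtuple
# constructor wants two fields, not one generator.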
@@ -104,7 +107,7 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
fake_args = [convert_to_fake(arg) for arg in proxy_args]
fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()}
try:
mod._infer_parameters(mod, fake_args, fake_kwargs)
mod._infer_parameters(mod, fake_args, fake_kwargs) # type: ignore[operator]
except AttributeError as e:
# Re-raise with the original error message from the AttributeError
raise_observed_exception(
@@ -117,7 +120,9 @@
@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
def record_nn_module_stack(
module_key: str, source: Any, tx: "InstructionTranslator", mod: torch.nn.Module
) -> Any:
fully_qualified_name = source.name()
# Remove redundant namings
fully_qualified_name = re.sub(
@@ -135,7 +140,7 @@ def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
del tx.nn_module_stack[module_key]
def guard_to_detect_forward_monkeypatching(source, mod):
def guard_to_detect_forward_monkeypatching(source: Any, mod: torch.nn.Module) -> None:
# Users sometimes patch the forward method of an nn module instance to
# perform optimizations like quantization. Though this is not good
# software practice, Python allows it, and Dynamo needs to detect
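For context, a minimal sketch (assumed, not part of the commit) of the instance-level patching this guard exists to catch:

import torch

class Net(torch.nn.Module):
    def forward(self, x):
        return x + 1

net = Net()
orig_forward = net.forward

def patched_forward(x):
    # user-side "optimization" layered over the original forward
    return orig_forward(x) * 2

net.forward = patched_forward  # instance-level patch; type(net).forward is untouched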
@@ -178,7 +183,7 @@ class NNModuleVariable(VariableTracker):
}
def __init__(
self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs
self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.module_type = module_type
@@ -187,32 +192,37 @@ class NNModuleVariable(VariableTracker):
assert self.source
self.nn_module_stack_source = self.source
def get_nn_module_stack_source(self):
def get_nn_module_stack_source(self) -> Any:
return self.nn_module_stack_source or self.source
def set_nn_module_stack_source(self, source):
def set_nn_module_stack_source(self, source: Any) -> None:
self.nn_module_stack_source = source
def python_type(self):
def python_type(self) -> type:
return self.module_type
def _wrap_submodule(
self, tx: "InstructionTranslator", source, submod, *key_extra, **options
):
self,
tx: "InstructionTranslator",
source: Any,
submod: Any,
*key_extra: Any,
**options: Any,
) -> None:
return
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
# implement list/iter/tuple/etc calls
base = tx.output.get_submodule(self.module_key)
result: list[VariableTracker] = []
if isinstance(base, torch.nn.ModuleDict):
result = []
for name, submod in base.items():
name_var = variables.ConstantVariable.create(name)
tx.output.register_attr_or_module(
submod,
self.module_key,
name,
source=NNModuleSource(GetItemSource(self.source, name)),
source=NNModuleSource(GetItemSource(self.source, name)), # type: ignore[arg-type]
)
result.append(name_var)
return result
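The ModuleDict branch above returns key names rather than submodule variables, matching eager semantics, where iterating an nn.ModuleDict yields its keys. A quick illustration (not from the commit):

import torch

md = torch.nn.ModuleDict({"a": torch.nn.ReLU(), "b": torch.nn.Tanh()})
print(list(md))           # ['a', 'b'] -- iteration yields keys
print(list(md.values()))  # [ReLU(), Tanh()]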
@@ -221,7 +231,6 @@ class NNModuleVariable(VariableTracker):
base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
), typestr(base)
assert self.source
result = []
for idx, submod in enumerate(base):
result.append(
tx.output.register_attr_or_module(
@@ -239,17 +248,17 @@
mod = tx.output.get_submodule(self.module_key)
result = hasattr(mod, name)
install_guard(
NNModuleSource(AttrSource(self.source, name)).make_guard(
NNModuleSource(AttrSource(self.source, name)).make_guard( # type: ignore[arg-type]
GuardBuilder.HASATTR
)
)
return variables.ConstantVariable.create(result)
def is_training(self, tx):
def is_training(self, tx: "InstructionTranslator") -> bool:
mod = tx.output.get_submodule(self.module_key)
return getattr(mod, "training", False)
def convert_to_unspecialized(self, tx):
def convert_to_unspecialized(self, tx: "InstructionTranslator") -> None:
"""Restart analysis treating this module as an UnspecializedNNModuleVariable"""
mod = tx.output.get_submodule(self.module_key)
GenerationTracker.tag(mod)
@@ -259,7 +268,7 @@ class NNModuleVariable(VariableTracker):
GenerationTracker.mark_class_dynamic(type(mod))
raise UnspecializeRestartAnalysis
def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool:
base = tx.output.get_submodule(self.module_key)
if object_has_getattribute(base):
@@ -282,7 +291,13 @@ class NNModuleVariable(VariableTracker):
base_dict = object.__getattribute__(base, "__dict__")
return key in base_dict
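The object.__getattribute__(base, "__dict__") call above bypasses any user-defined __getattribute__ on the module, returning the raw instance dict. A standalone illustration with a hypothetical class:

class Intercepting:
    def __getattribute__(self, name):
        raise RuntimeError("intercepted")

obj = Intercepting()
object.__setattr__(obj, "weight", 1)

# obj.__dict__ would raise RuntimeError; the bypass does not:
print(object.__getattribute__(obj, "__dict__"))  # {'weight': 1}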
def _custom_getattr_fallback(self, base, tx, name, obj_source):
def _custom_getattr_fallback(
self,
base: torch.nn.Module,
tx: "InstructionTranslator",
name: str,
obj_source: Any,
) -> Optional[VariableTracker]:
"""Check for a __getattr__ and handle it specially if it is implemented"""
if object_has_getattribute(base):
unimplemented_v2(
@@ -325,13 +340,13 @@ class NNModuleVariable(VariableTracker):
tx, [variables.ConstantVariable.create(name)], {}
)
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
source = self.source and AttrSource(self.source, name)
base = tx.output.get_submodule(self.module_key)
base_dict = object.__getattribute__(base, "__dict__")
object_member = True
all_class_attribute_names = set()
all_class_attribute_names: set[str] = set()
for x in inspect.getmro(base.__class__):
all_class_attribute_names.update(x.__dict__.keys())
@@ -348,6 +363,7 @@ class NNModuleVariable(VariableTracker):
if name == "__dict__":
return variables.GetAttrVariable(self, name, source=source)
subobj = None
if name in base_dict:
subobj = base_dict[name]
elif (
@@ -385,7 +401,7 @@ class NNModuleVariable(VariableTracker):
return variables.UserDefinedClassVariable(base.__class__, source=source)
if object_member:
out = VariableTracker.build(tx, subobj, NNModuleSource(source))
out = VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type]
if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
# nn_module_stack source is BC surface area. Ensure that
@@ -421,7 +437,7 @@ class NNModuleVariable(VariableTracker):
return variables.UserMethodVariable(subobj, self, source=source)
elif is_safe_constant(subobj) or istensor(subobj):
# Support possibly common cases of class members
return VariableTracker.build(tx, subobj, NNModuleSource(source))
return VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type]
else:
unimplemented_v2(
gb_type="Unsupported nn.Module attribute type",
@@ -439,10 +455,10 @@ class NNModuleVariable(VariableTracker):
def call_function(
self,
tx,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
mod = tx.output.get_submodule(self.module_key)
with record_nn_module_stack(
@@ -478,7 +494,7 @@ class NNModuleVariable(VariableTracker):
submod,
self.module_key,
child_name,
source=NNModuleSource(AttrSource(self.source, child_name)),
source=NNModuleSource(AttrSource(self.source, child_name)), # type: ignore[arg-type]
),
[arg],
{},
@@ -489,7 +505,7 @@ class NNModuleVariable(VariableTracker):
if is_lazy:
# The module type will change after it is called
if mod.cls_to_become is not None:
self.module_type = mod.cls_to_become
self.module_type = mod.cls_to_become # type: ignore[assignment]
# The pre-hook runs to initialize the module shapes, then deletes itself. After this,
# the module is more or less not lazy and can be treated as a normal module regardless of
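Picking up the lazy-module handling above: cls_to_become reflects how lazy modules mutate their own type on first use, which is why self.module_type is updated. A small illustration (assumed behavior of torch.nn.LazyLinear):

import torch

lazy = torch.nn.LazyLinear(out_features=4)
print(torch.nn.LazyLinear.cls_to_become)  # <class '...Linear'>

lazy(torch.randn(2, 8))  # first call infers in_features=8
print(type(lazy))        # the instance is now a plain torch.nn.Linear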
@@ -546,7 +562,7 @@ class NNModuleVariable(VariableTracker):
if istype(fn, types.MethodType):
fn = fn.__func__
fn_source = AttrSource(fn_source, "__func__")
args = [self] + args
args = [self] + list(args)
else:
assert istype(fn, types.FunctionType)
return tx.inline_user_function_return(
@@ -557,18 +573,18 @@ class NNModuleVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
constant=False,
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
constant: bool = False,
) -> VariableTracker:
from . import ConstantVariable, ListIteratorVariable, TupleVariable
key = self.module_key
module = tx.output.get_submodule(key)
def generic_call_method_helper(name):
def generic_call_method_helper(name: str) -> VariableTracker:
# Helper function to put a `call_method` node in FX graph,
# with nn.Module as the first arg.
mod_proxy = tx.output.create_proxy(
@@ -608,7 +624,7 @@ class NNModuleVariable(VariableTracker):
return generic_call_method_helper(name)
if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed(
inspect.getfile(module.__class__._check_input_dim)
inspect.getfile(module.__class__._check_input_dim) # type: ignore[union-attr]
):
return ConstantVariable.create(True)
@@ -623,16 +639,16 @@
tx,
f"``nn.Module`` {module}'s call method {name} requires a tuple as first argument",
)
mod_var = args[0].items[args[1].value]
mod_var = args[0].items[args[1].value] # type: ignore[attr-defined]
if isinstance(mod_var, UnspecializedNNModuleVariable):
return mod_var
key = mod_var.module_key
key = mod_var.module_key # type: ignore[attr-defined]
submod = tx.output.get_submodule(key)
return tx.output.register_attr_or_module(
submod,
key,
key,
source=NNModuleSource(GetItemSource(self.source, key)),
source=NNModuleSource(GetItemSource(self.source, key)), # type: ignore[arg-type]
)
if constant:
@@ -640,7 +656,7 @@ class NNModuleVariable(VariableTracker):
name = f"{module.__class__.__name__}_{name}_result"
return invoke_and_store_as_constant(tx, fn, name, args, kwargs)
def assert_all_args_kwargs_const():
def assert_all_args_kwargs_const() -> None:
if not all(
x.is_python_constant() for x in itertools.chain(args, kwargs.values())
):
@@ -652,7 +668,7 @@ class NNModuleVariable(VariableTracker):
hints=[],
)
def get_kwargs(*names):
def get_kwargs(*names: str) -> dict[str, Any]:
assert_all_args_kwargs_const()
fn = getattr(module, name)
bound_args = inspect.signature(fn).bind(
@@ -663,7 +679,7 @@ class NNModuleVariable(VariableTracker):
bound_args = bound_args.arguments
return {k: bound_args[k] for k in names}
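get_kwargs leans on the standard inspect bind pattern to recover named arguments from whatever mix of positional and keyword form the caller used. A standalone sketch of that pattern (hypothetical function; apply_defaults is part of the usual idiom and assumed here):

import inspect

def named_parameters(prefix: str = "", recurse: bool = True):
    ...

bound = inspect.signature(named_parameters).bind("net", recurse=False)
bound.apply_defaults()
print(bound.arguments)  # {'prefix': 'net', 'recurse': False}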
def wrap_values(items):
def wrap_values(items: Any) -> "variables.ListIteratorVariable":
result = []
for name, submod in items:
result.append(
@@ -674,9 +690,11 @@ class NNModuleVariable(VariableTracker):
source=NNModuleSource(gen_source(self.source, name)),
)
)
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
return ListIteratorVariable(
result, mutation_type=ValueMutationNew()
)
def named_embed(name, obj):
def named_embed(name: str, obj: Any) -> "variables.TupleVariable":
return TupleVariable(
[
ConstantVariable.create(name),
@@ -689,7 +707,7 @@ class NNModuleVariable(VariableTracker):
]
)
def gen_source(source, name):
def gen_source(source: Any, name: str) -> Any:
name_split = name.split(".")
if name_split[0] == "":
return source
@@ -699,7 +717,7 @@ class NNModuleVariable(VariableTracker):
return source
if name == "named_children":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules"))
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) # type: ignore[arg-type]
if args or kwargs:
raise_args_mismatch(
tx,
@@ -707,36 +725,42 @@
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
result = []
named_children: list[VariableTracker] = []
for name, submod in module.named_children():
result.append(named_embed(name, submod))
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
named_children.append(named_embed(name, submod))
return ListIteratorVariable(
named_children, mutation_type=ValueMutationNew()
)
elif name == "named_parameters":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters"))
result = []
tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) # type: ignore[arg-type]
named_parameters: list[VariableTracker] = []
for name, param in module.named_parameters(
**get_kwargs("prefix", "recurse")
):
result.append(named_embed(name, param))
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
named_parameters.append(named_embed(name, param))
return ListIteratorVariable(
named_parameters, mutation_type=ValueMutationNew()
)
elif name == "named_buffers":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers"))
result = []
tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) # type: ignore[arg-type]
named_buffers: list[VariableTracker] = []
for name, buffer in module.named_buffers(
**get_kwargs("prefix", "recurse", "remove_duplicate")
):
result.append(named_embed(name, buffer))
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
named_buffers.append(named_embed(name, buffer))
return ListIteratorVariable(named_buffers, mutation_type=ValueMutationNew())
elif name == "named_modules":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules"))
result = []
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) # type: ignore[arg-type]
named_modules_list: list[VariableTracker] = []
for name, submod in module.named_modules(
**get_kwargs("memo", "prefix", "remove_duplicate")
):
result.append(named_embed(name, submod))
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
named_modules_list.append(named_embed(name, submod))
return ListIteratorVariable(
named_modules_list, mutation_type=ValueMutationNew()
)
elif name == "children":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules"))
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) # type: ignore[arg-type]
if args or kwargs:
raise_args_mismatch(
tx,
@@ -746,13 +770,13 @@
)
return wrap_values(module.named_children())
elif name == "modules":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules"))
tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) # type: ignore[arg-type]
return wrap_values(module.named_modules())
elif name == "parameters":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters"))
tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) # type: ignore[arg-type]
return wrap_values(module.named_parameters(**get_kwargs("recurse")))
elif name == "buffers":
tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers"))
tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) # type: ignore[arg-type]
return wrap_values(module.named_buffers(**get_kwargs("recurse")))
elif name == "keys":
if args or kwargs:
@@ -763,7 +787,7 @@
f"{len(args)} args and {len(kwargs)} kwargs",
)
result = []
for name in module.keys():
for name in module.keys(): # type: ignore[operator]
result.append(ConstantVariable.create(name))
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
elif name == "values":
@@ -774,7 +798,7 @@
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
return wrap_values(module.items())
return wrap_values(module.items()) # type: ignore[operator]
elif name == "items":
if args or kwargs:
raise_args_mismatch(
@@ -783,10 +807,10 @@
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
result = []
for name, submod in module.items():
result.append(named_embed(name, submod))
return ListIteratorVariable(result, mutation_type=ValueMutationNew())
items_result: list[VariableTracker] = []
for name, submod in module.items(): # type: ignore[operator]
items_result.append(named_embed(name, submod))
return ListIteratorVariable(items_result, mutation_type=ValueMutationNew())
elif name == "__len__":
if args or kwargs:
raise_args_mismatch(
@@ -795,7 +819,7 @@
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
return ConstantVariable.create(len(module))
return ConstantVariable.create(len(module)) # type: ignore[arg-type]
elif name == "__iter__":
return ListIteratorVariable(
self.unpack_var_sequence(tx), mutation_type=ValueMutationNew()
@@ -825,25 +849,24 @@
torch.nn.Sequential.__getitem__,
)
if type(module).__getitem__ not in builtin_supported:
if not (
isinstance(args[0], variables.ConstantVariable)
and isinstance(args[0].as_python_constant(), (str, int))
):
unimplemented_v2(
gb_type="Invalid or non-const argument in nn.Module __getitem__",
context=f"call_method: {self} {name} {args} {kwargs}",
explanation="Dynamo does not support calling "
f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
hints=[
"Use constant arguments of type str or int for __getitem__"
],
)
if type(module).__getitem__ not in builtin_supported: # type: ignore[index]
if not (
isinstance(args[0], variables.ConstantVariable)
and isinstance(args[0].as_python_constant(), (str, int))
):
unimplemented_v2(
gb_type="Invalid or non-const argument in nn.Module __getitem__",
context=f"call_method: {self} {name} {args} {kwargs}",
explanation="Dynamo does not support calling "
f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.",
hints=[
"Use constant arguments of type str or int for __getitem__"
],
)
fn = getattr(module, name).__func__
assert isinstance(fn, types.FunctionType)
src = AttrSource(AttrSource(self.source, name), "__func__")
src = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type]
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=src),
[self] + list(args),
@@ -860,19 +883,21 @@
result = []
# Turn the slice into the list of integers
keys = list(range(len(module)))[args[0].as_python_constant()]
for idx, submod in enumerate(module[args[0].as_python_constant()]):
key = keys[idx]
src = NNModuleSource(GetItemSource(self.source, key))
result.append(
tx.output.register_attr_or_module(
submod,
key,
source=src,
)
keys = list(range(len(module)))[args[0].as_python_constant()] # type: ignore[arg-type]
module_slice_result: list[VariableTracker] = []
for idx, submod in enumerate(module[args[0].as_python_constant()]): # type: ignore[index]
key_int = keys[idx]
src_item = NNModuleSource(GetItemSource(self.source, key_int)) # type: ignore[arg-type]
module_slice_result.append(
tx.output.register_attr_or_module(
submod,
key_int,
source=src_item,
)
)
new_module = module[args[0].as_python_constant()]
new_module = module[args[0].as_python_constant()] # type: ignore[index]
new_module_variable = tx.output.register_attr_or_module(
new_module,
f"{self}.__getitem__(slice)",
@@ -888,10 +913,11 @@
from .tensor import SymNodeVariable
key_value = 0
if isinstance(args[0], SymNodeVariable):
key = args[0].evaluate_expr(tx.output)
key_value = args[0].evaluate_expr(tx.output)
elif args[0].is_python_constant():
key = args[0].as_python_constant()
key_value = args[0].as_python_constant()
else:
unimplemented_v2(
gb_type="Unsupported key type for nn.Module.__getitem__",
@@ -901,12 +927,12 @@
hints=[],
)
submod = module[key]
submod = module[key_value] # type: ignore[index]
return tx.output.register_attr_or_module(
submod,
self.module_key,
key,
source=NNModuleSource(GetItemSource(self.source, key)),
key_value,
source=NNModuleSource(GetItemSource(self.source, key_value)),
)
elif (
name == "_get_abs_string_index"
@@ -921,10 +947,10 @@
):
# Inline the function
fn = getattr(module, name).__func__
fn_source = AttrSource(AttrSource(self.source, name), "__func__")
fn_source = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type]
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=fn_source),
[self] + args,
[self] + list(args),
kwargs,
)
# A loose heuristic, but seems to be generally good before we drop into the
@ -939,7 +965,7 @@ class NNModuleVariable(VariableTracker):
):
return generic_call_method_helper(name)
else:
return super().call_method(tx, name, args, kwargs)
return super().call_method(tx, name, list(args), kwargs)
class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
@@ -958,7 +984,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
Giving one graph per module class.
"""
def __init__(self, value, **kwargs) -> None:
def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None:
if type(value) is torch.jit._script.RecursiveScriptModule:
raise Unsupported(
"ScriptModules aren't supported in UnspecializedNNModuleVariable"
@@ -983,19 +1009,19 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
# nn_module_stack_source appropriately to resemble mod.linear.
self.nn_module_stack_source = self.source
def _wrap_source(self, attr_source):
def _wrap_source(self, attr_source: Any) -> Any:
# the vt is already wrapped with UnspecializedNNModuleSource
return attr_source
def get_nn_module_stack_source(self):
def get_nn_module_stack_source(self) -> Any:
return self.nn_module_stack_source or self.source
def set_nn_module_stack_source(self, source):
def set_nn_module_stack_source(self, source: Any) -> None:
self.nn_module_stack_source = source
@staticmethod
@functools.cache
def _nn_module_method_ids():
def _nn_module_method_ids() -> set[int]:
# Allow __setattr__ to fall through to base class handler
supported = {
torch.nn.Module.__setattr__,
@@ -1008,7 +1034,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
if hasattr(x, "__code__") and x not in supported
}
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
try:
fn = inspect.getattr_static(self.value_type, "__iter__")
except AttributeError as e:
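On the inspect.getattr_static lookup above: it fetches the attribute without triggering descriptors or __getattr__, so a class that genuinely lacks __iter__ surfaces as an AttributeError here rather than being masked. Illustration (assumed behavior):

import inspect
import torch

print(inspect.getattr_static(torch.nn.ModuleList, "__iter__"))  # a plain function

inspect.getattr_static(torch.nn.Linear, "__iter__")  # raises AttributeError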
@@ -1035,15 +1061,15 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
mod = self.value
# see comment on lazy module handling in NNModuleVariable.call_function for context
if is_lazy_module(mod):
if mod.cls_to_become is not None:
self.value_type = mod.cls_to_become
initialize_lazy_module(tx, mod, args, kwargs)
if is_lazy_module(mod): # type: ignore[arg-type]
if mod.cls_to_become is not None: # type: ignore[attr-defined]
self.value_type = mod.cls_to_become # type: ignore[attr-defined,assignment]
initialize_lazy_module(tx, mod, args, kwargs) # type: ignore[arg-type]
if not isinstance(mod, torch.fx.GraphModule):
name = "__call__"
@@ -1055,24 +1081,36 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
# Check if we can short circuit nn.Module._call_impl to the forward
# method. NB - This is done to reduce the compile time of Dynamo.
if (
istype(mod.__call__, types.MethodType)
and istype(mod._call_impl, types.MethodType)
and mod.__call__.__func__ is unpatched_nn_module_call
and mod._call_impl.__func__ is unpatched_nn_module_call_impl
istype(mod.__call__, types.MethodType) # type: ignore[operator]
and istype(mod._call_impl, types.MethodType) # type: ignore[attr-defined]
and mod.__call__.__func__ is unpatched_nn_module_call # type: ignore[operator]
and mod._call_impl.__func__ is unpatched_nn_module_call_impl # type: ignore[attr-defined]
and "forward" not in mod.__dict__
):
forward_method = inspect.getattr_static(mod, "forward")
if isinstance(forward_method, types.FunctionType):
globals_vt = tx.nn_modules_globals_vt
if not (
self.var_getattr(tx, "_backward_hooks").realize().len()
or self.var_getattr(tx, "_backward_pre_hooks").realize().len()
or self.var_getattr(tx, "_forward_hooks").realize().len()
or self.var_getattr(tx, "_forward_pre_hooks").realize().len()
or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len()
or globals_vt.var_getattr(tx, "_global_backward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_hooks").len()
or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len()
self.var_getattr(tx, "_backward_hooks")
.realize()
.call_method(tx, "__len__", [], {})
.as_python_constant()
or self.var_getattr(tx, "_backward_pre_hooks")
.realize()
.call_method(tx, "__len__", [], {})
.as_python_constant()
or self.var_getattr(tx, "_forward_hooks")
.realize()
.call_method(tx, "__len__", [], {})
.as_python_constant()
or self.var_getattr(tx, "_forward_pre_hooks")
.realize()
.call_method(tx, "__len__", [], {})
.as_python_constant()
or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined]
or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined]
or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined]
or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined]
):
name = "forward"
fn = self.value_type.forward
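The long condition above is the no-hooks fast path: when no instance-level or global hooks are registered and forward is not shadowed in the instance dict, nn.Module.__call__ reduces to forward, so Dynamo can trace forward directly and skip _call_impl's hook dispatch. A sketch of the eager-side invariant being relied on (the private attribute names are taken from torch.nn.modules.module and are an assumption of this illustration, not code from the commit):

import torch
import torch.nn.modules.module as module_mod

mod = torch.nn.Linear(4, 4)
no_hooks = not (
    mod._forward_hooks
    or mod._forward_pre_hooks
    or mod._backward_hooks
    or mod._backward_pre_hooks
    or module_mod._global_forward_hooks
    or module_mod._global_forward_pre_hooks
    or module_mod._global_backward_hooks
    or module_mod._global_backward_pre_hooks
)
print(no_hooks)  # True for a freshly built module: __call__ behaves like forward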
@@ -1082,11 +1120,14 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
else:
source = None
guard_to_detect_forward_monkeypatching(self.source, mod)
guard_to_detect_forward_monkeypatching(self.source, mod) # type: ignore[arg-type]
ctx = (
record_nn_module_stack(
str(id(mod)), self.get_nn_module_stack_source(), tx, mod
str(id(mod)),
self.get_nn_module_stack_source(),
tx,
mod, # type: ignore[arg-type]
)
if self.source
else nullcontext()
@@ -1106,11 +1147,11 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: Sequence[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name in ["_call_impl", "_wrapped_call_impl"]:
fn = getattr(self.value_type, name)
if self.source:
@@ -1193,15 +1234,17 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
fn_vt = VariableTracker.build(tx, torch.nn.Module.__delattr__)
return fn_vt.call_function(tx, [self, args[0]], kwargs)
return super().call_method(tx, name, args, kwargs)
return super().call_method(tx, name, list(args), kwargs)
def getattr_helper(self, tx: "InstructionTranslator", field, name_vt):
def getattr_helper(
self, tx: "InstructionTranslator", field: str, name_vt: VariableTracker
) -> Optional[VariableTracker]:
dict_vt = self.var_getattr(tx, field)
if isinstance(dict_vt, variables.ConstDictVariable):
return dict_vt.maybe_getitem_const(name_vt)
return None
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
# Allow skipping of empty hook dict guards on inbuilt nn modules
if name in (
"_backward_hooks",
@@ -1242,7 +1285,9 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
tx.output.guard_on_key_order.add(hooks_dict_source)
def build_key_value(i, k, v):
def build_key_value(
i: int, k: Any, v: Any
) -> tuple[VariableTracker, VariableTracker]:
# Make key sourceless to avoid any guard on it
key = variables.ConstantVariable.create(k)
@@ -1262,7 +1307,9 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
)
return super().var_getattr(tx, name)
def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name):
def manually_trace_nn_module_getattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
"""
Dynamo tracing of nn.Module __getattr__ can be expensive if the model
has deep submodule hierarchy. Since the __getattr__ is stable, we can
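The "stable" property the docstring refers to is nn.Module.__getattr__'s fixed lookup order over the three internal dicts. A minimal sketch of that order (an assumption-based rendering of nn.Module behavior, not code from this commit):

import torch

def module_getattr(mod: torch.nn.Module, name: str):
    # Mirrors nn.Module.__getattr__: parameters, then buffers, then submodules.
    d = object.__getattribute__(mod, "__dict__")
    for bucket in ("_parameters", "_buffers", "_modules"):
        entries = d.get(bucket)
        if entries is not None and name in entries:
            return entries[name]
    raise AttributeError(name)

lin = torch.nn.Linear(2, 2)
print(module_getattr(lin, "weight").shape)  # torch.Size([2, 2])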
@@ -1281,6 +1328,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
tx,
msg=f"'{type(self.value).__name__}' object has no attribute '{name}'",
)
assert out is not None
return out
@@ -1289,7 +1337,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules.
"""
def _wrap_source(self, attr_source):
def _wrap_source(self, attr_source: Any) -> Any:
# vt is already wrapped with the UnspecializedBuiltinNNModuleSource
return attr_source
@@ -1306,7 +1354,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
compilation.
"""
def __init__(self, value, **kwargs) -> None:
def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None:
source = kwargs.get("source")
assert source is not None, (
"FSDPManagedNNModule depends on having an accurate source to control guarding."
@@ -1315,7 +1363,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
super().__init__(value=value, **kwargs)
self.source = source
def _wrap_source(self, attr_source):
def _wrap_source(self, attr_source: Any) -> Any:
if not isinstance(
attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource)
):