Compare commits

...

2 Commits

Commit 2549554a10 (2025-10-15 19:51:24 -07:00)
    Update on "Remove obsolete is_export checks"

    cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

    [ghstack-poisoned]

Commit d01a0daca3 (2025-10-15 19:46:35 -07:00)
    Remove obsolete is_export checks

    [ghstack-poisoned]
5 changed files with 9 additions and 49 deletions

Changed file 1 of 5

@@ -107,7 +107,7 @@ class GenerationTracker:
         cls.generation_values = ExactWeakKeyDictionary()


-def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool:
+def is_dynamic_nn_module(obj: Any) -> bool:
     """Check for nn.Modules() created dynamically or mutated"""
     if isinstance(obj, torch.nn.Module) and (
         "forward" in obj.__dict__ or isinstance(obj, (dict, MutableMapping))
@@ -117,11 +117,7 @@ def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool:
         return True
     if hasattr(obj, "torchdynamo_force_dynamic"):
         return obj.torchdynamo_force_dynamic
-    if (
-        isinstance(obj, torch.nn.Module)
-        and config.inline_inbuilt_nn_modules
-        and (not is_export or config.install_free_tensors)
-    ):
+    if isinstance(obj, torch.nn.Module) and config.inline_inbuilt_nn_modules:
         return True
     if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks():
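
For context, a rough illustration (not part of the patch, assuming only a stock torch install) of the kinds of modules the remaining checks flag as dynamic: an instance whose forward was overridden on the instance itself, or one carrying the explicit torchdynamo_force_dynamic flag. Class and variable names below are made up.

import torch

class Plain(torch.nn.Module):
    def forward(self, x):
        return x + 1

m = Plain()
m.forward = lambda x: x * 2            # instance-level override lands in m.__dict__
assert "forward" in m.__dict__          # matches the '"forward" in obj.__dict__' check

m2 = Plain()
m2.torchdynamo_force_dynamic = True     # explicit opt-in, honored by the hasattr() check
assert m2.torchdynamo_force_dynamic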

Changed file 2 of 5

@@ -1102,7 +1102,7 @@ class OutputGraph(OutputGraphCommon):
         *names: Any,
         **options: Any,
     ) -> VariableTracker:
-        if is_dynamic_nn_module(target, self.export):
+        if is_dynamic_nn_module(target):
             # Instead of returning UnspecializedNNModuleVariable, call
             # VariableTracker.build so that it is tracked for mutation.
             return VariableTracker.build(self.current_tx, target, **options)
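
The comment in this hunk concerns modules that mutate themselves at runtime. As a hedged, user-level sketch (names invented, behavior assumes a working torch.compile setup), this is the kind of program that relies on the mutation being tracked rather than the module being specialized away:

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0              # plain Python attribute, mutated in forward()

    def forward(self, x):
        self.calls += 1             # side effect on the module instance
        return x * self.calls

m = Counter()
compiled = torch.compile(m)
print(compiled(torch.ones(2)))      # tensor([1., 1.])
print(compiled(torch.ones(2)))      # tensor([2., 2.]) -- the mutation was applied
print(m.calls)                      # 2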

Changed file 3 of 5

@@ -1888,7 +1888,7 @@ class VariableBuilder:
                 # don't allow STORE_ATTR mutation with custom __setattr__
                 return result
             return self.tx.output.side_effects.track_object_existing(value, result)
-        elif mutation_guard.is_dynamic_nn_module(value, self.tx.export):
+        elif mutation_guard.is_dynamic_nn_module(value):
             # created dynamically, don't specialize on it
             # Note [Tracing a torch.compiled function]
@@ -1936,18 +1936,14 @@ class VariableBuilder:
             and not value.__module__.startswith("torch.nn.modules.container")
         ) or getattr(value.__class__, "_dynamo_marked_static", False):
             new_source = self.source
-            if config.inline_inbuilt_nn_modules and (
-                not self.tx.output.export or config.install_free_tensors
-            ):
+            if config.inline_inbuilt_nn_modules:
                 # Export corner case - look at test_repros.py test_inlining_cornercase
                 new_source = UnspecializedBuiltinNNModuleSource(self.source)
             result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
             install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
         else:
             new_source = self.source
-            if config.inline_inbuilt_nn_modules and (
-                not self.tx.output.export or config.install_free_tensors
-            ):
+            if config.inline_inbuilt_nn_modules:
                 # Export corner case - look at test_repros.py test_inlining_cornercase
                 new_source = UnspecializedNNModuleSource(self.source)
             result = UnspecializedNNModuleVariable(value, source=new_source)
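
As a loose illustration of the builtin-vs-user split made just above via value.__module__ (assuming nothing beyond stock torch): built-in layers live under torch.nn.modules, while user subclasses report their own defining module.

import torch

class MyBlock(torch.nn.Module):
    def forward(self, x):
        return x

print(torch.nn.Linear(2, 2).__module__)   # 'torch.nn.modules.linear'     -> builtin branch
print(torch.nn.ModuleList().__module__)   # 'torch.nn.modules.container'  -> excluded above
print(MyBlock().__module__)               # the user file's module name   -> user-defined branch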

Changed file 4 of 5

@@ -770,38 +770,9 @@ class NNModuleVariable(VariableTracker):
             assert self.source
             if isinstance(args[0], SliceVariable):
-                # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is
-                # enabled for export.
-                if tx.output.export:
-                    # Build a TupleVariable of NNModules
-                    result = []
-
-                    # Turn the slice into the list of integers
-                    keys = list(range(len(module)))[args[0].as_python_constant()]
-
-                    for idx, submod in enumerate(module[args[0].as_python_constant()]):
-                        key = keys[idx]
-                        src = NNModuleSource(GetItemSource(self.source, key))
-                        result.append(
-                            tx.output.register_attr_or_module(
-                                submod,
-                                key,
-                                source=src,
-                            )
-                        )
-
-                    new_module = module[args[0].as_python_constant()]
-                    new_module_variable = tx.output.register_attr_or_module(
-                        new_module,
-                        f"{self}.__getitem__(slice)",
-                        source=NNModuleSource(
-                            GetItemSource(self.source, args[0].as_python_constant())
-                        ),
-                    )
-                    return new_module_variable
-                else:
-                    # slice on nn module results in a creation of new module instance, so we need to make it sourceless.
-                    # Convert to unspecialized so that UnspecializedNNModule variable can take care of it.
-                    self.convert_to_unspecialized(tx)
+                # slice on nn module results in a creation of new module instance, so we need to make it sourceless.
+                # Convert to unspecialized so that UnspecializedNNModule variable can take care of it.
+                self.convert_to_unspecialized(tx)

             from .tensor import SymNodeVariable
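
The comment kept in this hunk turns on a plain Python fact worth spelling out: slicing an nn.ModuleList (or nn.Sequential) builds a brand-new container object, so there is no pre-existing attribute path to attach a source or guards to. A small sketch, assuming only torch:

import torch

layers = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(4))
head = layers[1:3]                       # __getitem__ with a slice

assert isinstance(head, torch.nn.ModuleList)
assert head is not layers                # a new instance is created by the slice
assert list(head)[0] is list(layers)[1]  # ...but it holds the same submodules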

Changed file 5 of 5

@@ -1603,9 +1603,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             )
             and source
             and isinstance(self, variables.UnspecializedNNModuleVariable)
-            # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export
-            # usecase for now.
-            and (not tx.output.export or torch._dynamo.config.install_free_tensors)
         ):
             # Recalculate source for params/buffers
             if name in ("_buffers", "_parameters"):
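
For reference, a minimal illustration (independent of the patch, plain torch) of why the names _buffers and _parameters are singled out here: nn.Module keeps its learnable and persistent state in exactly these two instance dicts, and attribute access on params/buffers resolves through them.

import torch

mod = torch.nn.Linear(2, 3)
assert set(mod._parameters) == {"weight", "bias"}   # learnable tensors live here

mod.register_buffer("running_stat", torch.zeros(3))
assert "running_stat" in mod._buffers               # non-learnable state lives here

# nn.Module.__getattr__ falls back to these dicts, so mod.weight and
# mod.running_stat resolve through _parameters / _buffers respectively.
assert mod.weight is mod._parameters["weight"]
assert mod.running_stat is mod._buffers["running_stat"]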