mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Revert back changes to UnspecializedBuiltinNNModuleVariable (#130991)
xref - https://fb.workplace.com/groups/1075192433118967/permalink/1466525440652329/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/130991 Approved by: https://github.com/williamwen42, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
9f392f8294
commit
a085acd7d6
@ -177,11 +177,7 @@ from .misc import (
|
||||
TorchVersionVariable,
|
||||
TypingVariable,
|
||||
)
|
||||
from .nn_module import (
|
||||
FSDPManagedNNModuleVariable,
|
||||
UnspecializedBuiltinNNModuleVariable,
|
||||
UnspecializedNNModuleVariable,
|
||||
)
|
||||
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
|
||||
from .optimizer import OptimizerVariable
|
||||
from .script_object import TorchScriptObjectVariable
|
||||
|
||||
@ -1278,10 +1274,7 @@ class VariableBuilder:
|
||||
# this will get cleaned up once compile ends
|
||||
self.tx.output.nn_modules[self.name] = value
|
||||
|
||||
if value.__module__.startswith(("torch.nn.", "torch.ao.")):
|
||||
result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
|
||||
else:
|
||||
result = UnspecializedNNModuleVariable(value, source=self.source)
|
||||
result = UnspecializedNNModuleVariable(value, source=self.source)
|
||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||
return result
|
||||
|
Reference in New Issue
Block a user