[dynamo] Mark a vt unspecialized nn module variable source earlier (#154780)

I am working on adding skip-guard helper functions that let users reduce guard overhead. This PR is a refactor to enable that.
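In effect, the VariableBuilder now wraps the module's source into the matching UnspecializedNNModuleSource / UnspecializedBuiltinNNModuleSource at wrap time and installs the TYPE_MATCH guard against that wrapped source, so the later _wrap_source hooks become no-ops. A condensed before/after sketch of the non-builtin branch, taken from the VariableBuilder hunk below (surrounding code elided):

    # Before: build the variable with the raw source; attribute sources are
    # re-wrapped lazily in UnspecializedNNModuleVariable._wrap_source.
    result = UnspecializedNNModuleVariable(value, source=self.source)

    # After: wrap the source up front and guard on the wrapped source.
    new_source = self.source
    if config.inline_inbuilt_nn_modules:
        # Export corner case - look at test_repros.py test_inlining_cornercase
        new_source = UnspecializedNNModuleSource(self.source)
    result = UnspecializedNNModuleVariable(value, source=new_source)
    install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))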

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154780
Approved by: https://github.com/StrongerXi, https://github.com/jansel
Animesh Jain
2025-06-02 17:03:16 -07:00
committed by PyTorch MergeBot
parent ea7b233015
commit cc96febb97
8 changed files with 39 additions and 19 deletions

@@ -138,7 +138,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,18
+hf_BigBird,pass,24


@@ -122,7 +122,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,18
+hf_BigBird,pass,24


@@ -1299,6 +1299,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
         self.assertTrue(torch._dynamo.testing.same(r, m(i)))
         self.assertEqual(cnt.op_count, 6)
 
+    @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
     def test_self_mutating1(self):
         m1 = torch.nn.Linear(10, 10)
         m2 = SelfMutatingModule(m1)

@@ -7454,14 +7454,13 @@ def forward(self, l_inp_, l_tmp_):
         self.assertExpectedInline(
             backend.graphs[0].code.strip(),
             """\
-def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
+def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
     l_a_ = L_a_
     l_b_ = L_b_
-    l_self_num = L_self_num
     tensor = torch.tensor([True])
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
-    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
+    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
     getitem = cond[0]; cond = None
     return (getitem,)""",  # noqa: B950
         )

@@ -2402,6 +2402,10 @@ def is_int_specialization_case(value, source):
             source.guard_source().is_unspecialized_builtin_nn_module()
             and not config.allow_unspec_int_on_nn_module
         )
+        or (
+            source.guard_source().is_unspecialized_nn_module()
+            and not config.allow_unspec_int_on_nn_module
+        )
         or is_from_defaults(source)
         # TODO: Delete this condition when rollout is done. NB: this
         # condition never evaluates True in open source
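The added branch means a plain int attribute on a user-defined (non-builtin) nn module is now also treated as a specialization case unless allow_unspec_int_on_nn_module is enabled, which is why the test earlier in this diff gains the @patch.object decorator. A minimal, hypothetical repro (module name and shapes are illustrative, not from this PR):

    import torch
    import torch._dynamo  # for config access
    from unittest.mock import patch

    class CounterModule(torch.nn.Module):  # hypothetical user-defined module
        def __init__(self):
            super().__init__()
            self.calls = 0  # plain int attribute on an unspecialized nn module

        def forward(self, x):
            self.calls += 1
            return x * self.calls

    # With allow_unspec_int_on_nn_module at its default (False), self.calls now
    # hits is_int_specialization_case even for this non-builtin module, so each
    # new counter value is baked in as a constant and triggers a recompile.
    # Patching the flag to True, as the updated test does, traces it dynamically.
    with patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True):
        compiled = torch.compile(CounterModule(), backend="eager")
        print(compiled(torch.ones(3)))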

@@ -115,6 +115,8 @@ from ..source import (
     Source,
     SubclassAttrListSource,
     TupleIteratorGetItemSource,
+    UnspecializedBuiltinNNModuleSource,
+    UnspecializedNNModuleSource,
 )
 from ..utils import (
     _extract_tensor_dict,
@@ -434,7 +436,10 @@ class VariableBuilder:
             return cached_vt
 
         vt = self._wrap(value)
-        vt.source = self.source
+
+        if vt.source is None:
+            vt.source = self.source
+
         if (
             self._can_lift_attrs_to_inputs(vt)
             and value not in self.tx.output.side_effects
@@ -1714,7 +1719,6 @@ class VariableBuilder:
             value = value.get_base()
             self.source = AttrProxySource(self.source)
 
         self.install_guards(GuardBuilder.TYPE_MATCH)
-
         if torch._dynamo.config.inline_inbuilt_nn_modules:
             freezing = is_parameter_freezing()
@@ -1749,12 +1753,23 @@ class VariableBuilder:
                 # this will get cleaned up once compile ends
                 self.tx.output.nn_modules[self.name] = value
 
-            if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
-                value.__class__, "_dynamo_marked_static", False
-            ):
-                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
+            if (
+                value.__module__.startswith(("torch.nn.modules", "torch.ao."))
+                and not value.__module__.startswith("torch.nn.modules.container")
+            ) or getattr(value.__class__, "_dynamo_marked_static", False):
+                new_source = self.source
+                if config.inline_inbuilt_nn_modules:
+                    # Export corner case - look at test_repros.py test_inlining_cornercase
+                    new_source = UnspecializedBuiltinNNModuleSource(self.source)
+                result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
+                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
             else:
-                result = UnspecializedNNModuleVariable(value, source=self.source)
+                new_source = self.source
+                if config.inline_inbuilt_nn_modules:
+                    # Export corner case - look at test_repros.py test_inlining_cornercase
+                    new_source = UnspecializedNNModuleSource(self.source)
+                result = UnspecializedNNModuleVariable(value, source=new_source)
+                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
 
             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
@@ -2127,6 +2142,10 @@ class VariableBuilder:
             )
             proxy.node.meta["grapharg"] = grapharg
 
+        # TODO - Why do we need to set the source of the np ndarray vt back to
+        # original source. Many tests fails.
+        numpy_ndarray_variable.source = self.source
+
         return numpy_ndarray_variable
 
     def wrap_symint(

@@ -2658,8 +2658,8 @@ class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable):
 
 class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable):
     def proxy_submod(self, tx, arg):
-        assert isinstance(arg.source, DictGetItemSource)
-        submod_name = tx.output.install_subgraph(arg.source.index, arg.value)
+        assert isinstance(arg.source.base, DictGetItemSource)
+        submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value)
         p_submod = make_attr(tx, submod_name)
         set_example_value(p_submod.node, arg.value)
         return p_submod

@@ -48,7 +48,6 @@ from ..source import (
     FSDPNNModuleSource,
     GetItemSource,
     NNModuleSource,
-    UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
 )
 from ..utils import (
@@ -891,8 +890,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
             self.nn_module_stack_source = self.source
 
     def _wrap_source(self, attr_source):
-        if not isinstance(attr_source, UnspecializedNNModuleSource):
-            return UnspecializedNNModuleSource(attr_source)
+        # the vt is already wrapped with UnspecializedNNModuleSource
         return attr_source
 
     def get_nn_module_stack_source(self):
@@ -1193,8 +1191,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
     """
 
     def _wrap_source(self, attr_source):
-        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
-            return UnspecializedBuiltinNNModuleSource(attr_source)
+        # vt is already wrapped with the UnspecializedBuiltinNNModuleSource
         return attr_source