Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[dynamo] Mark a vt unspecialized nn module variable source earlier (#154780)
I am working on providing skip-guard helper functions that will let users reduce guard overhead; this refactor lays the groundwork for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154780
Approved by: https://github.com/StrongerXi, https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: ea7b233015
Commit: cc96febb97
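In outline, the refactor moves the wrapping of a module's Source from consumers (the _wrap_source hooks in nn_module.py) to the point where the VariableTracker is first built. A minimal, self-contained sketch of that pattern, using toy stand-ins rather than the real torch._dynamo classes:

# Toy sketch of the pattern in this PR -- names are illustrative stand-ins,
# not the real torch._dynamo API.

class Source:
    def name(self) -> str:
        raise NotImplementedError

class LocalSource(Source):
    def __init__(self, var: str):
        self.var = var

    def name(self) -> str:
        return self.var

class UnspecializedNNModuleSource(Source):
    # Wrapper that tags everything reached through it as coming from an
    # unspecialized nn.Module.
    def __init__(self, base: Source):
        self.base = base

    def name(self) -> str:
        return self.base.name()

class VariableTracker:
    def __init__(self, value, source: Source):
        self.value = value
        self.source = source

installed_guards = []

def install_guard(source: Source, kind: str = "TYPE_MATCH"):
    installed_guards.append((kind, source.name()))

def build_nn_module_variable(value, source: Source) -> VariableTracker:
    # Wrap once, up front: the guard and every attribute source later
    # derived from vt.source automatically carry the marker, so consumers
    # no longer need a _wrap_source() re-wrapping step.
    wrapped = UnspecializedNNModuleSource(source)
    install_guard(wrapped)
    return VariableTracker(value, source=wrapped)

vt = build_nn_module_variable(object(), LocalSource("L['mod']"))
assert isinstance(vt.source, UnspecializedNNModuleSource)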
@@ -138,7 +138,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,18
+hf_BigBird,pass,24
@@ -122,7 +122,7 @@ hf_Bert_large,pass,0
-hf_BigBird,pass,18
+hf_BigBird,pass,24
@@ -1299,6 +1299,7 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
         self.assertTrue(torch._dynamo.testing.same(r, m(i)))
         self.assertEqual(cnt.op_count, 6)
 
+    @patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
     def test_self_mutating1(self):
         m1 = torch.nn.Linear(10, 10)
         m2 = SelfMutatingModule(m1)
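The only functional change here is the new @patch.object decorator pinning torch._dynamo.config.allow_unspec_int_on_nn_module to True. Read together with the is_int_specialization_case hunk below, the likely reason is that int attributes on unspecialized nn modules now specialize by default, which would perturb the op count this test asserts; the patch preserves the dynamic-int behavior.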
@@ -7454,14 +7454,13 @@ def forward(self, l_inp_, l_tmp_):
         self.assertExpectedInline(
             backend.graphs[0].code.strip(),
             """\
-def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
+def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
     l_a_ = L_a_
     l_b_ = L_b_
-    l_self_num = L_self_num
     tensor = torch.tensor([True])
     cond_true_0 = self.cond_true_0
     cond_false_0 = self.cond_false_0
-    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
+    cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
     getitem = cond[0]; cond = None
     return (getitem,)""", # noqa: B950
         )
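The expected graph loses L_self_num: the module's int attribute is no longer lifted as a SymInt input and threaded through the cond operands. That matches the specialization change below: with allow_unspec_int_on_nn_module at its default of False, an int on an unspecialized nn module is now baked into the graph as a constant rather than kept symbolic.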
@@ -2402,6 +2402,10 @@ def is_int_specialization_case(value, source):
             source.guard_source().is_unspecialized_builtin_nn_module()
             and not config.allow_unspec_int_on_nn_module
         )
+        or (
+            source.guard_source().is_unspecialized_nn_module()
+            and not config.allow_unspec_int_on_nn_module
+        )
         or is_from_defaults(source)
         # TODO: Delete this condition when rollout is done. NB: this
         # condition never evaluates True in open source
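The extra branch extends int specialization to non-builtin unspecialized nn modules. A hypothetical repro of what the flag controls (the Counter module and the eager backend are illustrative assumptions, not taken from the PR):

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.num = 1  # plain Python int attribute on the module

    def forward(self, x):
        return x + self.num

m = Counter()
compiled = torch.compile(m, backend="eager")

# Default (allow_unspec_int_on_nn_module=False): self.num is specialized
# into the graph as a constant with a value guard, so changing it should
# trigger a recompile.
compiled(torch.ones(4))
m.num = 2
compiled(torch.ones(4))

# With the flag flipped (as the patched test above does), the int can
# instead be traced as a dynamic SymInt input, i.e. the L_self_num
# argument seen in the expected-graph hunk earlier.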
@@ -115,6 +115,8 @@ from ..source import (
     Source,
     SubclassAttrListSource,
     TupleIteratorGetItemSource,
+    UnspecializedBuiltinNNModuleSource,
+    UnspecializedNNModuleSource,
 )
 from ..utils import (
     _extract_tensor_dict,
@@ -434,7 +436,10 @@ class VariableBuilder:
             return cached_vt
 
         vt = self._wrap(value)
+
+        if vt.source is None:
+            vt.source = self.source
 
         if (
             self._can_lift_attrs_to_inputs(vt)
             and value not in self.tx.output.side_effects
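Setting vt.source only when it is unset is the guard against clobbering: after this PR, _wrap can return a variable tracker whose source has already been replaced by a wrapped source (see the nn-module hunk below), and the builder must not overwrite that wrapper with the raw self.source.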
@@ -1714,7 +1719,6 @@ class VariableBuilder:
             value = value.get_base()
             self.source = AttrProxySource(self.source)
 
-        self.install_guards(GuardBuilder.TYPE_MATCH)
         if torch._dynamo.config.inline_inbuilt_nn_modules:
             freezing = is_parameter_freezing()
@@ -1749,12 +1753,23 @@ class VariableBuilder:
             # this will get cleaned up once compile ends
             self.tx.output.nn_modules[self.name] = value
 
-            if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
-                value.__class__, "_dynamo_marked_static", False
-            ):
-                result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
+            if (
+                value.__module__.startswith(("torch.nn.modules", "torch.ao."))
+                and not value.__module__.startswith("torch.nn.modules.container")
+            ) or getattr(value.__class__, "_dynamo_marked_static", False):
+                new_source = self.source
+                if config.inline_inbuilt_nn_modules:
+                    # Export corner case - look at test_repros.py test_inlining_cornercase
+                    new_source = UnspecializedBuiltinNNModuleSource(self.source)
+                result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
+                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
             else:
-                result = UnspecializedNNModuleVariable(value, source=self.source)
+                new_source = self.source
+                if config.inline_inbuilt_nn_modules:
+                    # Export corner case - look at test_repros.py test_inlining_cornercase
+                    new_source = UnspecializedNNModuleSource(self.source)
+                result = UnspecializedNNModuleVariable(value, source=new_source)
+                install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
 
             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
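Both branches now follow the same shape: pick the source (wrapping it when inline_inbuilt_nn_modules is on), build the variable from that source, and install the TYPE_MATCH guard against the same source. This replaces the unconditional self.install_guards(GuardBuilder.TYPE_MATCH) deleted in the previous hunk, and it is what the title means by marking the source earlier: the wrapper is applied at construction time instead of by each consumer.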
@@ -2127,6 +2142,10 @@ class VariableBuilder:
         )
         proxy.node.meta["grapharg"] = grapharg
 
+        # TODO - Why do we need to set the source of the np ndarray vt back to
+        # original source. Many tests fails.
+        numpy_ndarray_variable.source = self.source
+
         return numpy_ndarray_variable
 
     def wrap_symint(
@@ -2658,8 +2658,8 @@ class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable):
 
 class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable):
     def proxy_submod(self, tx, arg):
-        assert isinstance(arg.source, DictGetItemSource)
-        submod_name = tx.output.install_subgraph(arg.source.index, arg.value)
+        assert isinstance(arg.source.base, DictGetItemSource)
+        submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value)
         p_submod = make_attr(tx, submod_name)
         set_example_value(p_submod.node, arg.value)
         return p_submod
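This looks like a knock-on effect of the earlier wrapping: the submodule's DictGetItemSource is now nested one level deeper, under the new wrapper source, so the code reaches through arg.source.base to find it.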
@@ -48,7 +48,6 @@ from ..source import (
     FSDPNNModuleSource,
     GetItemSource,
     NNModuleSource,
-    UnspecializedBuiltinNNModuleSource,
     UnspecializedNNModuleSource,
 )
 from ..utils import (
@@ -891,8 +890,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
             self.nn_module_stack_source = self.source
 
     def _wrap_source(self, attr_source):
-        if not isinstance(attr_source, UnspecializedNNModuleSource):
-            return UnspecializedNNModuleSource(attr_source)
+        # the vt is already wrapped with UnspecializedNNModuleSource
         return attr_source
 
     def get_nn_module_stack_source(self):
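With the builder wrapping the module's source at construction time, _wrap_source no longer has anything to do: the source already carries the unspecialized-nn-module marker by the time attribute sources are derived, so the old isinstance check and re-wrap are redundant, exactly as the new comment records. The UnspecializedBuiltinNNModuleVariable hunk below makes the same simplification.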
@@ -1193,8 +1191,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
     """
 
     def _wrap_source(self, attr_source):
-        if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
-            return UnspecializedBuiltinNNModuleSource(attr_source)
+        # vt is already wrapped with the UnspecializedBuiltinNNModuleSource
         return attr_source