Compare commits

...

1 Commits

Author SHA1 Message Date
f1f9683409 [export] Preserve nn_module_stack for aliased nn modules 2025-09-30 18:02:43 -07:00
2 changed files with 15 additions and 8 deletions

View File

@ -14777,13 +14777,8 @@ graph():
for nn_module_stack in nn_module_stacks
]
if is_inline_and_install_strict_test(self._testMethodName):
# when inlined and install have same ID so reference same layer
self.assertEqual(filtered_nn_module_stack[0], "sub_net.0")
self.assertEqual(filtered_nn_module_stack[1], "sub_net.0")
else:
self.assertEqual(filtered_nn_module_stack[0], "sub_net.0")
self.assertEqual(filtered_nn_module_stack[1], "sub_net.2")
self.assertEqual(filtered_nn_module_stack[0], "sub_net.0")
self.assertEqual(filtered_nn_module_stack[1], "sub_net.2")
def test_slice_nn_module_stack(self):
class N(torch.nn.Module):
@ -14818,7 +14813,7 @@ graph():
]
if is_inline_and_install_strict_test(self._testMethodName):
self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
self.assertEqual(filtered_nn_module_stack[1], "mod_list_1.2")
self.assertEqual(filtered_nn_module_stack[1], "mod_list_2.4")
# This is fine since both of these will be deprecated soon.
elif is_strict_v2_test(self._testMethodName) and IS_FBCODE:
self.assertEqual(

View File

@ -442,6 +442,18 @@ class VariableBuilder:
dup_guard = make_dupe_guard(self.source, side_effect_result.source)
if dup_guard:
self.install_guards(dup_guard)
if isinstance(value, torch.nn.Module) and isinstance(
side_effect_result, UnspecializedNNModuleVariable
):
# This means that two nn module instances with different sources
# have the same id. NN modules are somewhat special objects,
# because we have to track their nn_module_stack for ease of
# use. But if we don't do anything, we will just return the
# older variable tracker with the older nn_module_stack. So,
# lets return the old variable tracker but update its
# nn_module_stack
side_effect_result.set_nn_module_stack_source(self.source)
return side_effect_result
cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)