Update tests to check for more robust pattern (#163107)

Landing this instead of https://github.com/pytorch/pytorch/pull/162994.

Here is how i think the whole dynamo + frame construction logic work:
1) There is no way to create a frame object in python land as this is created in runtime from cpython. So that's why aot_compile creates FrameInfo this way. (kind of like simulating the runtime) i guess you could write your own very simple eval_frame.c where you can interject the frame construction but we probably don't want that.
2) When there is no wrapper (the old export or aot_compile), we first assign sources by iterating over f_locals which contain both local args and closure variables (this is implementation details of cpython frame construction). So thats why closure variables end up getting LocalSource names as can be shown in this test case (f6ea41ead2/test/export/test_export.py (L1369)). Note that L["self"] here means we are referring to local object self. Important thing to keep in mind here is this self is not actually model self, but the outer self.
3) When we switch to wrapper case, we end up trying to inline the original inner module. When doing so, we need to track all local and closures for this inner module as can be seen here (f6ea41ead2/torch/_dynamo/variables/functions.py (L463)) Here we are not looking into inner frame's f_locals but just directly look at closures. I guess this is because we are one more frame up so there is no access to frame f_locals at this point. And it is probably not good idea to change dynamo's logic here. As a result, i get following error message that is different from old export:
"While exporting, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: ["L['self']._export_root.forward.__func__.__closure__[1].cell_contents.bank", "L['self']._export_root.forward.__func__.__closure__[1].cell_contents.bank_dict", "L['self']._export_root.forward.__func__.__closure__[0].cell_contents"]"

My initial attempt of solving this was taking inner closures and put them to f_locals for the frame i am constructing which turned out too compilcated because we needed to muck around bytecode instructions as well. So i am thinking we should just update the test to reflect new names and follow up with better post-processing step to have better names.

Differential Revision: [D82582029](https://our.internmc.facebook.com/intern/diff/D82582029)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163107
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-09-23 10:36:17 -07:00
committed by PyTorch MergeBot
parent fc84743707
commit e671dcc969
2 changed files with 26 additions and 12 deletions

View File

@ -1370,7 +1370,7 @@ graph():
self.mod.forward = hacked_up_forward.__get__(self.mod, Foo)
def __call__(self, x, y):
ep = torch.export.export(self.mod, (x, y), strict=True).module()
ep = export(self.mod, (x, y), strict=True).module()
out = ep(x, y)
return out
@ -1379,13 +1379,31 @@ graph():
foo = Foo()
ref = ReferenceControl(foo)
with self.assertWarnsRegex(
UserWarning,
"While exporting, we found certain side effects happened in the model.forward. "
"Here are the list of potential sources you can double check: "
"\[\"L\['global_list'\]\", \"L\['self'\].bank\", \"L\['self'\].bank_dict\"",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
# TODO (tmanlaibaatar) this kinda sucks but today there is no good way to get
# good source name. We should have an util that post processes dynamo source names
# to be more readable.
if is_strict_v2_test(self._testMethodName):
with self.assertWarnsRegex(
UserWarning,
r"(L\['self']\._export_root\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank"
r"|L\['self']\._export_root\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict"
r"|L\['self']\._export_root\.forward\.__func__\.__closure__\[0\]\.cell_contents)",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
elif is_inline_and_install_strict_test(self._testMethodName):
with self.assertWarnsRegex(
UserWarning,
r"(L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank"
r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict"
r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[0\]\.cell_contents)",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
else:
with self.assertWarnsRegex(
UserWarning,
r"(L\['global_list'\]|L\['self'\]\.bank|L\['self'\]\.bank_dict)",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
def test_mask_nonzero_static(self):
class TestModule(torch.nn.Module):

View File

@ -89,10 +89,6 @@ unittest.expectedFailure(
unittest.expectedFailure(
InlineAndInstallStrictExportTestExport.test_retrace_pre_autograd_inline_and_install_strict # noqa: F821
)
# this is because detect leak test has export root
unittest.expectedFailure(
InlineAndInstallStrictExportTestExport.test_detect_leak_strict_inline_and_install_strict # noqa: F821
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests