[dynamo][annotate] Remove the need of external ctx mgr of preserve_node_meta (#165188)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165188
Approved by: https://github.com/yushangdi
ghstack dependencies: #164776
This commit is contained in:
Animesh Jain
2025-10-10 20:27:04 -07:00
committed by PyTorch MergeBot
parent 1e4c7dffa3
commit f0325d0787
3 changed files with 8 additions and 25 deletions

View File

@ -18,17 +18,6 @@ def checkpoint_wrapper(fn):
class AnnotateTests(torch._dynamo.test_case.TestCase):
# TODO - should not need this because we should turn this on in Dynamo but
# for some reasons, test fail.
def setUp(self):
super().setUp()
self.cm = torch.fx.traceback.preserve_node_meta()
self.cm.__enter__()
def tearDown(self):
super().tearDown()
self.cm.__exit__(None, None, None)
def get_custom_metadata(self, gm):
def helper(gm):
custom_metadata = []

View File

@ -45,17 +45,6 @@ def aot_eager_regional_inductor():
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class RegionalInductorTests(torch._inductor.test_case.TestCase):
# TODO - should not need this because we should turn this on in Dynamo but
# for some reasons, test fail.
def setUp(self):
super().setUp()
self.cm = torch.fx.traceback.preserve_node_meta()
self.cm.__enter__()
def tearDown(self):
super().tearDown()
self.cm.__exit__(None, None, None)
def test_simple(self):
def fn(x, y):
sin = torch.sin(x)

View File

@ -23,6 +23,7 @@ restoring state changes.
import inspect
import sys
import warnings
from contextlib import ExitStack
from typing import TYPE_CHECKING, Union
import torch._C
@ -1278,9 +1279,13 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
)
def enter(self, tx, *args):
cm = torch.fx.traceback.annotate(self.target_values)
cm.__enter__()
self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
# Run the annotation ctx manager in eager. Also ensure that
# preserve_node_meta context manager is setup. This is important to pass
# on the metadata to the create_proxy nodes.
stack = ExitStack()
stack.enter_context(torch.fx.traceback.annotate(self.target_values))
stack.enter_context(torch.fx.traceback.preserve_node_meta())
self.set_cleanup_hook(tx, lambda: stack.close())
return variables.ConstantVariable.create(None)
def module_name(self):