mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][annotate] Remove the need of external ctx mgr of preserve_node_meta (#165188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165188 Approved by: https://github.com/yushangdi ghstack dependencies: #164776
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e4c7dffa3
commit
f0325d0787
@ -18,17 +18,6 @@ def checkpoint_wrapper(fn):
|
||||
|
||||
|
||||
class AnnotateTests(torch._dynamo.test_case.TestCase):
|
||||
# TODO - should not need this because we should turn this on in Dynamo but
|
||||
# for some reasons, test fail.
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cm = torch.fx.traceback.preserve_node_meta()
|
||||
self.cm.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.cm.__exit__(None, None, None)
|
||||
|
||||
def get_custom_metadata(self, gm):
|
||||
def helper(gm):
|
||||
custom_metadata = []
|
||||
|
@ -45,17 +45,6 @@ def aot_eager_regional_inductor():
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
|
||||
class RegionalInductorTests(torch._inductor.test_case.TestCase):
|
||||
# TODO - should not need this because we should turn this on in Dynamo but
|
||||
# for some reasons, test fail.
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cm = torch.fx.traceback.preserve_node_meta()
|
||||
self.cm.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.cm.__exit__(None, None, None)
|
||||
|
||||
def test_simple(self):
|
||||
def fn(x, y):
|
||||
sin = torch.sin(x)
|
||||
|
@ -23,6 +23,7 @@ restoring state changes.
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import ExitStack
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch._C
|
||||
@ -1278,9 +1279,13 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable):
|
||||
)
|
||||
|
||||
def enter(self, tx, *args):
|
||||
cm = torch.fx.traceback.annotate(self.target_values)
|
||||
cm.__enter__()
|
||||
self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
|
||||
# Run the annotation ctx manager in eager. Also ensure that
|
||||
# preserve_node_meta context manager is setup. This is important to pass
|
||||
# on the metadata to the create_proxy nodes.
|
||||
stack = ExitStack()
|
||||
stack.enter_context(torch.fx.traceback.annotate(self.target_values))
|
||||
stack.enter_context(torch.fx.traceback.preserve_node_meta())
|
||||
self.set_cleanup_hook(tx, lambda: stack.close())
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
def module_name(self):
|
||||
|
Reference in New Issue
Block a user