[Reland] Refactor stack_trace preservation for node meta preservation (#90803) (#92400)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803
Approved by: https://github.com/jerryzh168, https://github.com/albanD
ghstack-source-id: 5848cca08ef5d6f8868f4f79d8bc29711e9a52c2

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92400
Approved by: https://github.com/jerryzh168
This commit is contained in:
Sherlock Huang
2023-01-30 23:30:43 +00:00
committed by PyTorch MergeBot
parent 1fa68d40b8
commit 36fe31f537
6 changed files with 44 additions and 48 deletions

View File

@ -178,7 +178,7 @@ class TestFunctionalization(TestCase):
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
import torch.fx.traceback as fx_traceback
setup_stacktrace_preservation_hooks([loss.grad_fn])
with fx_traceback.override_stack_trace():
with fx_traceback.preserve_node_meta():
loss.backward()
return x.grad

View File

@ -649,7 +649,7 @@ def export(
if aten_graph:
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with torch.fx.traceback.override_stack_trace():
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
graph = make_fx(

View File

@ -896,7 +896,7 @@ def create_joint_forward_backward_functionalized(
backward_out = []
# Call the backwards pass
if grad_primals:
with fx_traceback.override_stack_trace():
with fx_traceback.preserve_node_meta():
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
@ -2447,7 +2447,7 @@ def aot_module_simplified(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
):
if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.override_stack_trace(), warnings.catch_warnings():
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."
)

View File

@ -153,7 +153,7 @@ class Interpreter:
@contextmanager
def _set_current_node(self, node):
with fx_traceback.append_stack_trace(node.stack_trace), fx_traceback.set_current_meta(node.meta):
with fx_traceback.set_current_meta(node.meta):
yield
@compatibility(is_backward_compatible=True)
@ -477,7 +477,7 @@ class Transformer(Interpreter):
Transform ``self.module`` and return the transformed
``GraphModule``.
"""
with fx_traceback.override_stack_trace():
with fx_traceback.preserve_node_meta():
result = super().run(enable_io_processing=False)
if result is not None:
def strip_proxy(a : Union[Argument, Proxy]) -> Any:

View File

@ -161,10 +161,23 @@ class TracerBase:
proxy = proxy_factory_fn(node)
# Optionally set stack trace on the created Node for debugging purposes
if fx_traceback.is_stack_trace_overridden():
proxy.node.meta = fx_traceback.get_current_meta()
stacks = fx_traceback.format_stack()
proxy.node.stack_trace = '\n'.join(reversed(stacks))
if fx_traceback.has_preserved_node_meta():
current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
# Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
# If other meta fields are needed, they can be added here
stack_trace = current_meta.get("stack_trace")
if stack_trace:
proxy.node.stack_trace = stack_trace
nn_module_stack = current_meta.get("nn_module_stack")
if nn_module_stack:
proxy.node.meta["nn_module_stack"] = nn_module_stack
source_fn = current_meta.get("source_fn")
if source_fn:
proxy.node.meta["source_fn"] = source_fn
elif self.record_stack_traces:
user_frame = self._find_user_frame()
if user_frame:

View File

@ -1,66 +1,49 @@
import traceback
from contextlib import contextmanager
from typing import Optional, List, Any, Dict
from typing import List, Any, Dict
from ._compatibility import compatibility
__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack',
'is_stack_trace_overridden', 'get_current_meta', 'set_current_meta']
__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
'set_stack_trace', 'format_stack',
'set_current_meta', 'get_current_meta']
current_stack: List[str] = []
current_meta: Dict[str, Any] = {}
is_overridden = False
should_preserve_node_meta = False
@compatibility(is_backward_compatible=False)
@contextmanager
def override_stack_trace():
global is_overridden
def preserve_node_meta():
global should_preserve_node_meta
saved_is_overridden = is_overridden
saved_should_preserve_node_meta = should_preserve_node_meta
try:
is_overridden = True
should_preserve_node_meta = True
yield
finally:
is_overridden = saved_is_overridden
should_preserve_node_meta = saved_should_preserve_node_meta
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack : List[str]):
global current_stack
global current_meta
if is_overridden and stack:
current_stack = stack
@compatibility(is_backward_compatible=False)
@contextmanager
def append_stack_trace(stack : Optional[str]):
"""
The content of stack here is an entire stacktraces as a string
"""
global current_stack
if is_overridden and stack:
try:
current_stack.append(stack)
yield
finally:
current_stack.pop()
else:
yield
if should_preserve_node_meta and stack:
current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
if is_overridden:
return current_stack.copy()
if should_preserve_node_meta:
return [current_meta.get("stack_trace", "")]
else:
# fallback to traceback.format_stack()
return traceback.format_list(traceback.extract_stack()[:-1])
@compatibility(is_backward_compatible=False)
def is_stack_trace_overridden() -> bool:
return is_overridden
def has_preserved_node_meta() -> bool:
return should_preserve_node_meta
@compatibility(is_backward_compatible=False)
@ -68,13 +51,13 @@ def is_stack_trace_overridden() -> bool:
def set_current_meta(meta : Dict[str, Any]):
global current_meta
old_meta = current_meta
if is_overridden and meta:
if should_preserve_node_meta and meta:
saved_meta = current_meta
try:
current_meta = meta
yield
finally:
current_meta = old_meta
current_meta = saved_meta
else:
yield