mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803 Approved by: https://github.com/jerryzh168, https://github.com/albanD ghstack-source-id: 5848cca08ef5d6f8868f4f79d8bc29711e9a52c2 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/92400 Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
1fa68d40b8
commit
36fe31f537
@ -178,7 +178,7 @@ class TestFunctionalization(TestCase):
|
||||
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
|
||||
import torch.fx.traceback as fx_traceback
|
||||
setup_stacktrace_preservation_hooks([loss.grad_fn])
|
||||
with fx_traceback.override_stack_trace():
|
||||
with fx_traceback.preserve_node_meta():
|
||||
loss.backward()
|
||||
return x.grad
|
||||
|
||||
|
@ -649,7 +649,7 @@ def export(
|
||||
if aten_graph:
|
||||
# Running graph with interpreter is needed for propagating the stack_trace
|
||||
def graph_with_interpreter(*args):
|
||||
with torch.fx.traceback.override_stack_trace():
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
return torch.fx.Interpreter(graph).run(*args)
|
||||
|
||||
graph = make_fx(
|
||||
|
@ -896,7 +896,7 @@ def create_joint_forward_backward_functionalized(
|
||||
backward_out = []
|
||||
# Call the backwards pass
|
||||
if grad_primals:
|
||||
with fx_traceback.override_stack_trace():
|
||||
with fx_traceback.preserve_node_meta():
|
||||
backward_out = torch.autograd.grad(
|
||||
needed_outs,
|
||||
grad_primals,
|
||||
@ -2447,7 +2447,7 @@ def aot_module_simplified(
|
||||
mod, pytree.tree_unflatten(args[:params_len], params_spec)
|
||||
):
|
||||
if isinstance(mod, torch.fx.GraphModule):
|
||||
with fx_traceback.override_stack_trace(), warnings.catch_warnings():
|
||||
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", "Anomaly Detection has been enabled."
|
||||
)
|
||||
|
@ -153,7 +153,7 @@ class Interpreter:
|
||||
|
||||
@contextmanager
|
||||
def _set_current_node(self, node):
|
||||
with fx_traceback.append_stack_trace(node.stack_trace), fx_traceback.set_current_meta(node.meta):
|
||||
with fx_traceback.set_current_meta(node.meta):
|
||||
yield
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@ -477,7 +477,7 @@ class Transformer(Interpreter):
|
||||
Transform ``self.module`` and return the transformed
|
||||
``GraphModule``.
|
||||
"""
|
||||
with fx_traceback.override_stack_trace():
|
||||
with fx_traceback.preserve_node_meta():
|
||||
result = super().run(enable_io_processing=False)
|
||||
if result is not None:
|
||||
def strip_proxy(a : Union[Argument, Proxy]) -> Any:
|
||||
|
@ -161,10 +161,23 @@ class TracerBase:
|
||||
proxy = proxy_factory_fn(node)
|
||||
|
||||
# Optionally set stack trace on the created Node for debugging purposes
|
||||
if fx_traceback.is_stack_trace_overridden():
|
||||
proxy.node.meta = fx_traceback.get_current_meta()
|
||||
stacks = fx_traceback.format_stack()
|
||||
proxy.node.stack_trace = '\n'.join(reversed(stacks))
|
||||
if fx_traceback.has_preserved_node_meta():
|
||||
current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
|
||||
|
||||
# Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
|
||||
# If other meta fields are needed, they can be added here
|
||||
stack_trace = current_meta.get("stack_trace")
|
||||
if stack_trace:
|
||||
proxy.node.stack_trace = stack_trace
|
||||
|
||||
nn_module_stack = current_meta.get("nn_module_stack")
|
||||
if nn_module_stack:
|
||||
proxy.node.meta["nn_module_stack"] = nn_module_stack
|
||||
|
||||
source_fn = current_meta.get("source_fn")
|
||||
if source_fn:
|
||||
proxy.node.meta["source_fn"] = source_fn
|
||||
|
||||
elif self.record_stack_traces:
|
||||
user_frame = self._find_user_frame()
|
||||
if user_frame:
|
||||
|
@ -1,66 +1,49 @@
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, List, Any, Dict
|
||||
from typing import List, Any, Dict
|
||||
from ._compatibility import compatibility
|
||||
|
||||
__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack',
|
||||
'is_stack_trace_overridden', 'get_current_meta', 'set_current_meta']
|
||||
__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
|
||||
'set_stack_trace', 'format_stack',
|
||||
'set_current_meta', 'get_current_meta']
|
||||
|
||||
|
||||
current_stack: List[str] = []
|
||||
current_meta: Dict[str, Any] = {}
|
||||
is_overridden = False
|
||||
should_preserve_node_meta = False
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@contextmanager
|
||||
def override_stack_trace():
|
||||
global is_overridden
|
||||
def preserve_node_meta():
|
||||
global should_preserve_node_meta
|
||||
|
||||
saved_is_overridden = is_overridden
|
||||
saved_should_preserve_node_meta = should_preserve_node_meta
|
||||
try:
|
||||
is_overridden = True
|
||||
should_preserve_node_meta = True
|
||||
yield
|
||||
finally:
|
||||
is_overridden = saved_is_overridden
|
||||
should_preserve_node_meta = saved_should_preserve_node_meta
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def set_stack_trace(stack : List[str]):
|
||||
global current_stack
|
||||
global current_meta
|
||||
|
||||
if is_overridden and stack:
|
||||
current_stack = stack
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@contextmanager
|
||||
def append_stack_trace(stack : Optional[str]):
|
||||
"""
|
||||
The content of stack here is an entire stacktraces as a string
|
||||
"""
|
||||
global current_stack
|
||||
|
||||
if is_overridden and stack:
|
||||
try:
|
||||
current_stack.append(stack)
|
||||
yield
|
||||
finally:
|
||||
current_stack.pop()
|
||||
else:
|
||||
yield
|
||||
if should_preserve_node_meta and stack:
|
||||
current_meta["stack_trace"] = "".join(stack)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def format_stack() -> List[str]:
|
||||
if is_overridden:
|
||||
return current_stack.copy()
|
||||
if should_preserve_node_meta:
|
||||
return [current_meta.get("stack_trace", "")]
|
||||
else:
|
||||
# fallback to traceback.format_stack()
|
||||
return traceback.format_list(traceback.extract_stack()[:-1])
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def is_stack_trace_overridden() -> bool:
|
||||
return is_overridden
|
||||
def has_preserved_node_meta() -> bool:
|
||||
return should_preserve_node_meta
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@ -68,13 +51,13 @@ def is_stack_trace_overridden() -> bool:
|
||||
def set_current_meta(meta : Dict[str, Any]):
|
||||
global current_meta
|
||||
|
||||
old_meta = current_meta
|
||||
if is_overridden and meta:
|
||||
if should_preserve_node_meta and meta:
|
||||
saved_meta = current_meta
|
||||
try:
|
||||
current_meta = meta
|
||||
yield
|
||||
finally:
|
||||
current_meta = old_meta
|
||||
current_meta = saved_meta
|
||||
else:
|
||||
yield
|
||||
|
||||
|
Reference in New Issue
Block a user