Refactor stack_trace preservation for node meta preservation (#90803)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803
Approved by: https://github.com/jerryzh168, https://github.com/albanD
This commit is contained in:
Sherlock Huang
2023-01-09 18:01:36 +00:00
committed by PyTorch MergeBot
parent 1e768c63c1
commit 0f1302eeae
6 changed files with 40 additions and 48 deletions

View File

@ -178,7 +178,7 @@ class TestFunctionalization(TestCase):
from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks
import torch.fx.traceback as fx_traceback
setup_stacktrace_preservation_hooks([loss.grad_fn])
with fx_traceback.override_stack_trace():
with fx_traceback.preserve_node_meta():
loss.backward()
return x.grad