Correctly propagate exception to parent tx (#146502)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146502
Approved by: https://github.com/anijain2305, https://github.com/williamwen42, https://github.com/zou3519
ghstack dependencies: #146504, #146499
This commit is contained in:
Guilherme Leobas
2025-03-11 13:33:09 +00:00
committed by PyTorch MergeBot
parent fb53e9e514
commit daff65d671
6 changed files with 431 additions and 135 deletions

View File

@ -362,7 +362,7 @@ def raise_observed_exception(
# CPython here raises an exception. Since there is no python code, we have to manually setup the exception
# stack and raise the exception.
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
tx.exn_vt_stack.append(exception_vt)
tx.exn_vt_stack.set_current_exception(exception_vt)
raise observed_exception_map[exc_type]
@ -391,7 +391,7 @@ def handle_observed_exception(tx: Any) -> None:
#
# Fortunately this translates to a simple pop from the exn_vt_stack
tx.exn_vt_stack.pop()
tx.exn_vt_stack.clear_current_exception()
# These exceptions are ok to fallback to eager/graph_break.