diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 3038b1a1938a..15a8d44d08a5 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -453,7 +453,10 @@ def _compile( output: Optional[OutputGraph] = None # This is shared across restarts mutated_closure_cell_contents: Set[str] = set() + fail_type: Optional[str] = None fail_reason: Optional[str] = None + fail_user_frame_filename: Optional[str] = None + fail_user_frame_lineno: Optional[int] = None speculation_log = SpeculationLog() @preserve_global_state @@ -610,12 +613,20 @@ def _compile( UncapturedHigherOrderOpError, BisectValidationException, ) as e: + fail_type = str(type(e)) fail_reason = str(e) exception_handler(e, code, frame, export=export) + if e.innermost_user_frame_summary is not None: # type: ignore[union-attr] + fail_user_frame_filename = e.innermost_user_frame_summary.filename # type: ignore[union-attr] + fail_user_frame_lineno = e.innermost_user_frame_summary.lineno # type: ignore[union-attr] raise except Exception as e: + fail_type = str(type(e)) fail_reason = str(e) exception_handler(e, code, frame, export=export) + if e.innermost_user_frame_summary is not None: # type: ignore[attr-defined] + fail_user_frame_filename = e.innermost_user_frame_summary.filename # type: ignore[attr-defined] + fail_user_frame_lineno = e.innermost_user_frame_summary.lineno # type: ignore[attr-defined] raise InternalTorchDynamoError(str(e)).with_traceback( e.__traceback__ ) from None @@ -670,7 +681,10 @@ def _compile( graph_input_count, entire_frame_compile_time, backend_compile_time, + fail_type, fail_reason, + fail_user_frame_filename, + fail_user_frame_lineno, non_compliant_ops, compliant_custom_ops, ) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index d8830b0a4cda..c47cce8fa04d 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -214,8 +214,11 @@ class KeyErrorMsg: def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None: import traceback + exc.innermost_user_frame_summary = None # type: ignore[attr-defined] + real_stack = get_real_stack(exc) - if real_stack is not None: + if real_stack is not None and len(real_stack) > 0: + exc.innermost_user_frame_summary = real_stack[-1] # type: ignore[attr-defined] msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}" if config.replay_record_enabled and hasattr(exc, "record_filename"): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0721e4296e33..6836da142527 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -592,7 +592,10 @@ class CompilationMetrics: graph_input_count: Optional[int] entire_frame_compile_time_s: Optional[float] backend_compile_time_s: Optional[float] + fail_type: Optional[str] fail_reason: Optional[str] + fail_user_frame_filename: Optional[str] + fail_user_frame_lineno: Optional[int] non_compliant_ops: Set[str] compliant_custom_ops: Set[str]