[hop][exc] make UncapturedHigherOrderOpError print user code and avoid re-raise (#159296)

After the change, the error stacktrace is attached with user code stack and  is suppressed into 1 (without the scrolling up mssage). For example:
```python
    class Test(torch.nn.Module):
        def forward(self, c, x):
            def cond_fn(c, x):
                return c > 0 and x.size(0) < 20

            def body_fn(c, x):
                return c - 1, x.sin()

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```

Now gives the following error message:
```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
    self._run_test(
    ~~~~~~~~~~~~~~^
        model=WhileLoopModels.SizeMismatchTensorExpansion(),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        dynamic=dynamic,
        ^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
    result = model(*inputs_with_counters)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
    return torch.compile(
           ~~~~~~~~~~~~~~
        _while_loop_op_wrapper, backend=backend, fullgraph=True
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
    ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
    result = self._torchdynamo_orig_backend(
        frame, cache_entry, self.hooks, frame_state, skip=1
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
    result = self._inner_convert(
        frame, cache_entry, hooks, frame_state, skip=skip + 1
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
    result = _compile(
        frame.f_code,
    ...<16 lines>...
        convert_frame_box=self._box,
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
    return function(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
    raise exc.with_traceback(sys.exc_info()[2]) from None
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
    ) = speculate_subgraph(
        ~~~~~~~~~~~~~~~~~~^
        tx,
        ^^^
    ...<33 lines>...
        supports_aliasing=self.supports_aliasing,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
    raise ex
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
    return tracer.inline_call_()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
    self.run()
    ~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
    return tracer.inline_call_()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
    self.run()
    ~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
    unimplemented_v2(
    ~~~~~~~~~~~~~~~~^
        gb_type="Data-dependent branching",
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<5 lines>...
        ],
        ^^
    )
    ^
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
    raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
   File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
    return while_loop_op(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
    return cond_fn(*carried, *additional)
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
    return c > 0 and x.size(0) < 20

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

To execute this test, run the following from the base repo dir:
    python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296
Approved by: https://github.com/zou3519
This commit is contained in:
Yidi Wu
2025-08-11 11:50:43 -07:00
committed by PyTorch MergeBot
parent 5a40c57844
commit fc25c68f20
3 changed files with 30 additions and 30 deletions

View File

@ -264,7 +264,14 @@ class UnsafeScriptObjectError(TorchDynamoException):
class UncapturedHigherOrderOpError(TorchDynamoException):
pass
def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None:
super().__init__(msg)
self.msg = msg
self.real_stack = (
real_stack
if real_stack is not None
else torch._guards.TracingContext.extract_stack()
)
class IncorrectUsage(Exception):