Turns out codegen'ing a nested step graph break is significantly more complicated than first thought. The optimized function should actually do:
- call graph/load values/do side effects etc.
- call into the leaf's resume function, but skipped (this essentially step graph break function for just the leaf function)
- call into all the other resume functions, traced.
This PR also adds `torch._dynamo.step_unsupported()`, which can be used for internal testing purposes to better test step graph break handling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162737
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #160601
Today convert_frame is implemented like the following:
```
def _compile():
tracer_output = None
def transform():
nonlocal tracer_output
...
def _compile_inner():
transform(...)
compile_inner(...)
```
The code is using unconventional nonlocal variable as the return value. This is not ideal for 2 reasons:
1. Reasoning about the code, especially together with error handling code becomes harder.
2. more importantly, this makes it harder to extract out common code pieces into a shared library because everything must depend on a central global state.
In this diff we remove the usage of nonlocal return and just use the conventional function return to output the compilation data.
Differential Revision: [D80461258](https://our.internmc.facebook.com/intern/diff/D80461258/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160899
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #160814, #160815, #160855
After the change, the error stacktrace is attached with user code stack and is suppressed into 1 (without the scrolling up mssage). For example:
```python
class Test(torch.nn.Module):
def forward(self, c, x):
def cond_fn(c, x):
return c > 0 and x.size(0) < 20
def body_fn(c, x):
return c - 1, x.sin()
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```
Now gives the following error message:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
self._run_test(
~~~~~~~~~~~~~~^
model=WhileLoopModels.SizeMismatchTensorExpansion(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<2 lines>...
dynamic=dynamic,
^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
result = model(*inputs_with_counters)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
return torch.compile(
~~~~~~~~~~~~~~
_while_loop_op_wrapper, backend=backend, fullgraph=True
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
result = self._torchdynamo_orig_backend(
frame, cache_entry, self.hooks, frame_state, skip=1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
result = self._inner_convert(
frame, cache_entry, hooks, frame_state, skip=skip + 1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
result = _compile(
frame.f_code,
...<16 lines>...
convert_frame_box=self._box,
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
return function(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
transformations(instructions, code_options)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
tracer.run()
~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
super().run()
~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
raise exc.with_traceback(sys.exc_info()[2]) from None
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
) = speculate_subgraph(
~~~~~~~~~~~~~~~~~~^
tx,
^^^
...<33 lines>...
supports_aliasing=self.supports_aliasing,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
raise ex
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
unimplemented_v2(
~~~~~~~~~~~~~~~~^
gb_type="Data-dependent branching",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<5 lines>...
],
^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
return while_loop_op(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
return cond_fn(*carried, *additional)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
return c > 0 and x.size(0) < 20
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
To execute this test, run the following from the base repo dir:
python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296
Approved by: https://github.com/zou3519
I feel uneasy about touching `__warningregistry__` since it is undocumented and private surface. The only public API hook that doesn't increment warnings version seems to be https://docs.python.org/3/library/warnings.html#warnings.showwarning.
So we could wack a mole all the warnings muters in compile to just not display warnings, and we wouldn't invalidate warnings cache. This PR adds it for torch/_dynamo, and I didn't find any warnings versioning mutation from torch/_inductor.
There is a behavior change if someone calls a compiled graph with simplefilter("error"):
```python
# e.g. test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing
with warnings.catch_warnings():
warnings.simplefilter("error") # turns all warnings into errors
compiled_fn() # will throw if any of the muted warnings fire
```
FIXES https://github.com/pytorch/pytorch/issues/128427
A note for the future: The warnings module doesn't offer a thread safe way of using it. Even regular filters have this problem, directly editing `__warningregistry__` would be very bad, and this PR would mute all threads. Someone will need to build a thread safe warnings interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158520
Approved by: https://github.com/anijain2305, https://github.com/zou3519
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
Implementation details:
- The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`.
- InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #156762, #155166
This PR includes the GBID weblink whenever a user encounters a graph break. I also had to include the JSON file in setup.py, so it can be part of the files that are packaged in during CI. It also fixes the issue of the hardcoded error messages stripping away one of the '/' in 'https'.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156033
Approved by: https://github.com/williamwen42
This PR includes the GBID weblink whenever a user encounters a graph break. I also had to include the JSON file in setup.py, so it can be part of the files that are packaged in during CI. It also fixes the issue of the hardcoded error messages stripping away one of the '/' in 'https'.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156033
Approved by: https://github.com/williamwen42
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
Implementation details:
- The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`.
- InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #155166
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
Implementation details:
- The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`.
- InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #155166
Summary:
Today when guard serialization fails, dynamo will raise an internal error like:
```
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: CLOSURE_MATCH guard cannot be serialized.
```
Adding a dedicated PackageError type to surface the error more clearly.
Test Plan: CI
Differential Revision: D75452124
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154430
Approved by: https://github.com/jamesjwu, https://github.com/jansel
Also show the line of code relevant to a dynamo-compiled frame, instead of just the first line (this was broken for data-dependent jump graph breaks and for 3.11+).
Also collapses resume frames together (use config.verbose to see full stack trace - for developers).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148401
Approved by: https://github.com/zou3519, https://github.com/jansel