This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`.
### UX
1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic.
Example
```
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
```
2) You have to instruct the compiler to use the annotations with `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting annotation is enough. But for now to control the blast radius, we need to explicitly do this. One such example is
```
# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
return aot_autograd(
fw_compiler=compile_fx_annotated_nodes_with_inductor,
bw_compiler=compile_fx_annotated_nodes_with_inductor,
)
```
3) Fixable in short-term - You have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable, just need to make CI happy.
### Implementation
1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations, and then create subgraphs in the main graph.
2) Call `torch._inductor.standalone_compile` on these subgraphs, and jam the returned callable into the FX graph at the place of call_module
Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner`
Forward graph
```
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
inner = torch__inductor_standalone_compile_inner(sin, primals_2)
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
getitem: "f32[10]" = inner[0]; inner = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
return (sin_1, primals_1, primals_2, sin, getitem)
```
Backward graph
```
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1); primals_1 = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
cos: "f32[10]" = torch.ops.aten.cos.default(add); add = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
# No stacktrace found for following nodes
inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2); mul_1 = sin = primals_2 = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
getitem: "f32[10]" = inner[0]
getitem_1: "f32[10]" = inner[1]; inner = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1); getitem_1 = cos_1 = None
return (mul_4, getitem)
```
### Some issue raised in the HOP meeting
1) CSE will not differentiate different meta custom nodes and do wrong thing.
2) SAC - The recomputed forward will be smaller than the forward. Will we compile a smaller region than?
3) What happens if you have a op in the middle which does not disturb the topology, is it still 1 subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
a) compile flex
b) streams
c) nn.Module info to organize MoE components for pipelining
d) PP stages
e) Rename graph nodes for more debugging
f) No nested regional compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
After the change, the error stacktrace is attached with user code stack and is suppressed into 1 (without the scrolling up mssage). For example:
```python
class Test(torch.nn.Module):
def forward(self, c, x):
def cond_fn(c, x):
return c > 0 and x.size(0) < 20
def body_fn(c, x):
return c - 1, x.sin()
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```
Now gives the following error message:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
self._run_test(
~~~~~~~~~~~~~~^
model=WhileLoopModels.SizeMismatchTensorExpansion(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<2 lines>...
dynamic=dynamic,
^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
result = model(*inputs_with_counters)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
return torch.compile(
~~~~~~~~~~~~~~
_while_loop_op_wrapper, backend=backend, fullgraph=True
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
result = self._torchdynamo_orig_backend(
frame, cache_entry, self.hooks, frame_state, skip=1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
result = self._inner_convert(
frame, cache_entry, hooks, frame_state, skip=skip + 1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
result = _compile(
frame.f_code,
...<16 lines>...
convert_frame_box=self._box,
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
return function(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
transformations(instructions, code_options)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
tracer.run()
~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
super().run()
~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
raise exc.with_traceback(sys.exc_info()[2]) from None
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
) = speculate_subgraph(
~~~~~~~~~~~~~~~~~~^
tx,
^^^
...<33 lines>...
supports_aliasing=self.supports_aliasing,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
raise ex
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
unimplemented_v2(
~~~~~~~~~~~~~~~~^
gb_type="Data-dependent branching",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<5 lines>...
],
^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
return while_loop_op(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
return cond_fn(*carried, *additional)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
return c > 0 and x.size(0) < 20
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
To execute this test, run the following from the base repo dir:
python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296
Approved by: https://github.com/zou3519
Summary:
We transfer stack trace in post_grad passes.
We shouldn't add "stack_trace" to _COPY_META_FIELDS because _COPY_META_FIELDS is used in proxy.py where stack_trace is explicitly set.
Since the stack_trace is being used by more and more debugging tools, we should also start testing it more rigorously. This PR start by adding a first test for testing that stack trace is preserved through post_grad_passes.
Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r test_pattern_matcher_transfer_meta
buck run mode/dev-nosan
fbcode//caffe2/test/inductor:auto_functionalize -- --rcaffe2/test/inductor:auto_functionalize_old
```
Rollback Plan:
Differential Revision: D78669729
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158752
Approved by: https://github.com/jingsh
The reason why we did this before is because that's how our older
autograd.Function x Dynamo interaction work, but we've since adopted
newer designs that don't actually need the autograd.Function.ctx proxied
into the graph.
We still need a fx.Proxy for the autograd.Function.ctx object, so
whenever we do I create one via discard_graph_changes.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152621
Approved by: https://github.com/oulgen
The "PendingUnbackedSymbolNotFound" error is when an unbacked symbol is created within a piece of code, but this symbol never appears in any of the outputs. I believe the original intention is to help catch incorrectly written meta kernels, where users might've unintentionally created an unbacked symbol but never used it anywhere, but in our case this is intentional. An example is the following test case:
```python
def test_pending_unbacked(self):
class M(torch.nn.Module):
@mark_compile_region
def gn(self, x):
u = x[0].item()
return x * u
def forward(self, x):
for _ in range(4):
x = self.gn(x)
return x
torch._dynamo.config.capture_scalar_outputs = True
torch.compile(M())(torch.randn(8))
```
This fails with the error:
```
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {zuf1} not in returned outputs (FakeTensor(..., size=(8,)),) .
```
In this case, creating the unbacked symbol is intentional, so we can bypass this using `fake_mode.shape_env.ignore_fresh_unbakced_symbols()`.
Differential Revision: [D71298926](https://our.internmc.facebook.com/intern/diff/D71298926)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149297
Approved by: https://github.com/zou3519
ghstack dependencies: #149296