We have an issue when using fx_traceback.annotate and HOPs that trace joint graphs. HOPs have bodies that have already been traced by Dynamo, and after Animesh's PR, does have the annotations. But when we lower that Dynamo HOP body to aten in either pre-dispatch or post-dispatch, we need to propagate the annotations to the aten nodes.
AOTAutograd does this indirectly by piggybacking off the `PropagateUnbackedSymInts` fx.Interpreter. I'm not sure if all HOPs should be using it to trace their joints or not. This PR adds an interpreter to local_map's implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165336
Approved by: https://github.com/yushangdi
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0.
So for Autoparallel initial trace, we want to trace the graph with global shapes initially. But, for the local_map region, we are forced to trace with the expected local tensors. To the tracers, this looks weird, because it's a plain tensor input (representing DTensor's full tensor .to_local()) that we need to "redistribute".
After hacking a miserable version that had cross-key dependencies, @ydwu4 proposed this simpler approach to override the fake key. This means the shape conversion will be invisible to all dispatch keys above fake, this covers all current tracing mechanisms. This manifests as the joint graph for the HOP body being traced with local shapes:
```python
# HOP forward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
def forward(self, primals_0: "f32[10, 80]"):
# No stacktrace found for following nodes
view: "f32[800]" = torch.ops.aten.view.default(primals_0, [-1]); primals_0 = None
add: "f32[800]" = torch.ops.aten.add.Tensor(view, 10); view = None
view_1: "f32[10, 80]" = torch.ops.aten.view.default(add, [10, 80]); add = None
return (view_1,)
# HOP backward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
def forward(self, tangents_0: "f32[10, 80]"):
# No stacktrace found for following nodes
clone: "f32[10, 80]" = torch.ops.aten.clone.default(tangents_0); tangents_0 = None
return (clone,)
```
while the rest of the graph is still traced with global shapes:
```python
# Parent graph joint, note global shapes (80, 80)
class inner_f(torch.nn.Module):
def forward(self, primals, tangents):
primals_1: "f32[80, 80]"; tangents_1: "f32[80, 80]";
primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
# File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:597 in forward, code: return fn(x)
call_local_map = torch._higher_order_ops.local_map.call_local_map(primals_1); primals_1 = None
getitem: "f32[80, 80]" = call_local_map[0]; call_local_map = None
call_local_map_1 = torch._higher_order_ops.local_map.call_local_map(tangents_1); tangents_1 = None
getitem_1: "f32[80, 80]" = call_local_map_1[0]; call_local_map_1 = None
return pytree.tree_unflatten([getitem, getitem_1], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164340
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420
Reviewed GPT5 summary:
**Summary / Goal**
Improve error reporting when local_map subgraph input/output counts mismatch placement info.
**Details**
- Adds descriptive runtime error messages.
**Motivation**
Helps debug local_map misalignments.
```python
AssertionError: Expecting 2 inputs to local_map function based on placements, but found 1. If the count matches for eager, Dynamo may have flattened inputs to the function or found additional tensors used via closures. Please adjust the input placements to match what the traced graph sees:
class GraphModule(torch.nn.Module):
def forward(self, l_args_0_: "f32[8, 8, 16]"):
# File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:523 in mismatch_input, code: return x + scalar, scalar
child: "f32[8, 8, 16]" = l_args_0_ + 10; l_args_0_ = None
return (child,)
.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164321
Approved by: https://github.com/ezyang, https://github.com/mlazos
ghstack dependencies: #164296