11 Commits

Author SHA1 Message Date
21697feff2 [hop] run local_map with interpreter to preserve fx_traceback annotations (#165336)
We have an issue when using fx_traceback.annotate and HOPs that trace joint graphs. HOPs have bodies that have already been traced by Dynamo, and after Animesh's PR, does have the annotations. But when we lower that Dynamo HOP body to aten in either pre-dispatch or post-dispatch, we need to propagate the annotations to the aten nodes.

AOTAutograd does this indirectly by piggybacking off the `PropagateUnbackedSymInts` fx.Interpreter. I'm not sure if all HOPs should be using it to trace their joints or not. This PR adds an interpreter to local_map's implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165336
Approved by: https://github.com/yushangdi
2025-10-16 02:53:17 +00:00
a61d0de9f9 [hop] support local_map filtered gradients (#164437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164437
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431, #164433
2025-10-10 02:34:27 +00:00
3ad88924ad [hop] support local_map None placements (#164433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164433
Approved by: https://github.com/ezyang
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602, #164431
2025-10-10 02:34:27 +00:00
3241b9c15f [hop] support local_map None gradients (#164431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164431
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164296, #164321, #164419, #164420, #164340, #163602
2025-10-10 02:34:27 +00:00
25d4d5107e [dynamo] trace local_map with local shapes for AP (#163602)
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0. And the description of the previous PR: https://github.com/pytorch/pytorch/pull/164340.

The previous PR adds the support on the HOP side for eager execution and AOTAutograd. Dynamo is still passing the HOP a subgraph with wrong shapes. This PR fixes that. This is similar to the HOP implementation, however we additionally need to manually keep the TensorVariable metadata in sync.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163602
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420, #164340
2025-10-10 02:34:27 +00:00
e4fe811be8 [hop] trace local_map with local shapes in fake key (#164340)
Context is in https://www.internalfb.com/excalidraw/EX519691 and https://docs.google.com/document/d/1qnuXLZk_GYt_PksHTwkn7L2ELRDnYlIRPkHAlXTyuhw/edit?tab=t.0.

So for Autoparallel initial trace, we want to trace the graph with global shapes initially. But, for the local_map region, we are forced to trace with the expected local tensors. To the tracers, this looks weird, because it's a plain tensor input (representing DTensor's full tensor .to_local()) that we need to "redistribute".

After hacking a miserable version that had cross-key dependencies, @ydwu4 proposed this simpler approach to override the fake key. This means the shape conversion will be invisible to all dispatch keys above fake, this covers all current tracing mechanisms. This manifests as the joint graph for the HOP body being traced with local shapes:
```python
# HOP forward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
    def forward(self, primals_0: "f32[10, 80]"):
        # No stacktrace found for following nodes
        view: "f32[800]" = torch.ops.aten.view.default(primals_0, [-1]);  primals_0 = None
        add: "f32[800]" = torch.ops.aten.add.Tensor(view, 10);  view = None
        view_1: "f32[10, 80]" = torch.ops.aten.view.default(add, [10, 80]);  add = None
        return (view_1,)

# HOP backward, note local shapes (10, 80)
class GraphModule(torch.nn.Module):
    def forward(self, tangents_0: "f32[10, 80]"):
        # No stacktrace found for following nodes
        clone: "f32[10, 80]" = torch.ops.aten.clone.default(tangents_0);  tangents_0 = None
        return (clone,)
```

while the rest of the graph is still traced with global shapes:
```python
# Parent graph joint, note global shapes (80, 80)
class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[80, 80]"; tangents_1: "f32[80, 80]";

        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         # File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:597 in forward, code: return fn(x)
        call_local_map = torch._higher_order_ops.local_map.call_local_map(primals_1);  primals_1 = None
        getitem: "f32[80, 80]" = call_local_map[0];  call_local_map = None
        call_local_map_1 = torch._higher_order_ops.local_map.call_local_map(tangents_1);  tangents_1 = None
        getitem_1: "f32[80, 80]" = call_local_map_1[0];  call_local_map_1 = None
        return pytree.tree_unflatten([getitem, getitem_1], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164340
Approved by: https://github.com/ydwu4
ghstack dependencies: #164296, #164321, #164419, #164420
2025-10-10 02:34:27 +00:00
ae139b73e0 [dynamo] Better error message for local_map subgraph mismatches number of inputs/outputs with placement info (#164321)
Reviewed GPT5 summary:

**Summary / Goal**
Improve error reporting when local_map subgraph input/output counts mismatch placement info.

**Details**
- Adds descriptive runtime error messages.

**Motivation**
Helps debug local_map misalignments.

```python
AssertionError: Expecting 2 inputs to local_map function based on placements, but found 1. If the count matches for eager, Dynamo may have flattened inputs to the function or found additional tensors used via closures. Please adjust the input placements to match what the traced graph sees:
class GraphModule(torch.nn.Module):
    def forward(self, l_args_0_: "f32[8, 8, 16]"):
         # File: /home/xmfan/core/a/pytorch/test/higher_order_ops/test_local_map.py:523 in mismatch_input, code: return x + scalar, scalar
        child: "f32[8, 8, 16]" = l_args_0_ + 10;  l_args_0_ = None
        return (child,)
        .
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164321
Approved by: https://github.com/ezyang, https://github.com/mlazos
ghstack dependencies: #164296
2025-10-10 02:34:27 +00:00
124dd364e9 [hop] support local_map + SAC (#163322)
Some ops like local_map hop's deferred mode are not desugared by make_fx, this means that when we apply SAC tags, we will need to define dispatch rules for the SAC torch dispatch modes as pointed out here: https://github.com/pytorch/pytorch/issues/162246#issuecomment-3259176721. This PR adds those rules.

Additionally it fixes a pre-existing issue where we weren't coercing tangent layout (that AOTAutograd typically does) when partitioning the HOP joint.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163322
Approved by: https://github.com/ezyang
2025-09-24 04:57:40 +00:00
821458d97a [dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
2025-09-17 09:32:38 +00:00
e7c3f802ff Revert "[dynamo][hop] Introduce Local Map HOP (#161458)"
This reverts commit 505458db803e1ffabac08a2fc150b566d3ea3a57.

Reverted https://github.com/pytorch/pytorch/pull/161458 on behalf of https://github.com/jeffdaily due to broke rocm tests ([comment](https://github.com/pytorch/pytorch/pull/161458#issuecomment-3299230458))
2025-09-16 15:14:36 +00:00
505458db80 [dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
2025-09-16 00:37:40 +00:00