mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[bucketing] Rewrite all_gather, reduce_scatter passes via tracing merge_fn (#158663)
Rewriting bucketing of all_gather and reduce_scatter with defining of "merge graph" via torch function. `all_gather_merge_fn_to_trace` `reduce_scatter_merge_fn_to_trace` (Instead of creating nodes and doing FakeTensor prop manually) This allows to experiment with merge function. Used foreach_copy_ in merging function for all_gather - added lowering for inductor for `foreach_copy_` Adding topological sort after bucketing passes (comment in post_grad.py): ``` # Fx collectives bucketing passes require topological sort for the cases: # when bucketed collectives have users before the last collective in the bucket # AND when inputs of bucketed collective have ancestors after the first collective in the bucket. # # In this case we can not manually pick the place for bucketed collective insertion. # But we are guaranteed by the bucketing (independent collectives in the bucket), # that it is possible to reorder nodes to satisfy all ordering requirements. # # --- before bucketing --- # in0 = ... # wait_ag0 = ag(in0) # user0(wait_ag0) # ... # pre_in1 = ... # in1 = transform(pre_in1) # wait_ag1 = ag(in1) # user1(wait_ag1) # # --- after bucketing --- # # in0 = ... # user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective. # # pre_in1 = ... # in1 = transform(pre_in1) # ag_bucket(in0+in1) # wait_bucket # wait_ag0 = wait_bucket[0] # wait_ag1 = wait_bucket[1] # user1(wait_ag1) ```` Correctness of the passes verified by loss curve for llama3 8b for simple_fsdp and for autoparallel: <img width="1364" height="495" alt="Screenshot 2025-07-22 at 14 27 28" src="https://github.com/user-attachments/assets/67b2cabb-3206-450b-b529-e23c24292fc6" /> <img width="1355" height="509" alt="Screenshot 2025-07-22 at 14 27 56" src="https://github.com/user-attachments/assets/4d0e6b25-2eb1-47b2-8d68-dcec185239c4" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/158663 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
bc5dbbbb78
commit
8aebf01287