Rewrite the bucketing of all_gather and reduce_scatter so that the "merge graph" is defined by tracing a torch function:
`all_gather_merge_fn_to_trace`
`reduce_scatter_merge_fn_to_trace`
(instead of creating nodes and running FakeTensor propagation manually).
This makes it easier to experiment with the merge function.
The merge function for all_gather uses `foreach_copy_`; an Inductor lowering for `foreach_copy_` was added.
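As a rough illustration of what a merge function has to achieve, here is a minimal sketch in plain Python. It is not the traced function from this PR: lists stand in for tensors, `all_gather` and `bucketed_all_gather` are hypothetical helpers, and the list comprehension that flattens each rank's shards plays the role that `foreach_copy_` into a flat buffer would play. The point is only that one collective over the merged buffer, followed by a split, yields the same results as one collective per input.

```python
def all_gather(shards_per_rank):
    """Toy all_gather: concatenate each rank's shard in rank order."""
    return [x for shard in shards_per_rank for x in shard]


def bucketed_all_gather(inputs_per_rank):
    """inputs_per_rank[r][i] is rank r's shard of the i-th bucketed input."""
    world_size = len(inputs_per_rank)
    sizes = [len(shard) for shard in inputs_per_rank[0]]
    bucket_len = sum(sizes)

    # Merge step: every rank copies its shards into one flat buffer.
    flat_per_rank = [
        [x for shard in rank_inputs for x in shard]
        for rank_inputs in inputs_per_rank
    ]

    # A single collective for the whole bucket.
    gathered = all_gather(flat_per_rank)

    # Split step: recover the per-input results from the gathered buffer.
    outputs = []
    offset = 0
    for size in sizes:
        out = []
        for rank in range(world_size):
            start = rank * bucket_len + offset
            out.extend(gathered[start:start + size])
        outputs.append(out)
        offset += size
    return outputs


inputs_per_rank = [
    [[1, 2], [10]],  # rank 0: shards of input 0 and input 1
    [[3, 4], [20]],  # rank 1: shards of input 0 and input 1
]
bucketed = bucketed_all_gather(inputs_per_rank)
unbucketed = [
    all_gather([rank_inputs[i] for rank_inputs in inputs_per_rank])
    for i in range(2)
]
```

Here `bucketed` and `unbucketed` hold the same values, which is the invariant any experimental merge function needs to preserve.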
Added a topological sort after the bucketing passes (see the comment in post_grad.py):
```
# Fx collectives bucketing passes require topological sort for the cases:
# when bucketed collectives have users before the last collective in the bucket
# AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
#
# In this case we can not manually pick the place for bucketed collective insertion.
# But we are guaranteed by the bucketing (independent collectives in the bucket),
# that it is possible to reorder nodes to satisfy all ordering requirements.
#
# --- before bucketing ---
# in0 = ...
# wait_ag0 = ag(in0)
# user0(wait_ag0)
# ...
# pre_in1 = ...
# in1 = transform(pre_in1)
# wait_ag1 = ag(in1)
# user1(wait_ag1)
#
# --- after bucketing ---
#
# in0 = ...
# user0(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
#
# pre_in1 = ...
# in1 = transform(pre_in1)
# ag_bucket(in0+in1)
# wait_bucket
# wait_ag0 = wait_bucket[0]
# wait_ag1 = wait_bucket[1]
# user1(wait_ag1)
```
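The reordering described in the comment can be sketched on a toy graph. This is a minimal illustration, not the sort used in post_grad.py: `stable_topological_sort` below is a hypothetical helper over plain node names, and `deps` is a hand-written dependency map for the "after bucketing" example, where `user0` appears before `wait_ag0` is defined.

```python
import heapq


def stable_topological_sort(nodes, deps):
    """Reorder `nodes` so that every node follows its dependencies.

    nodes: node names in their current (possibly invalid) order.
    deps:  maps a node name to the names of the nodes it consumes.
    Ties are broken by original position, so a valid order is left unchanged.
    """
    position = {name: i for i, name in enumerate(nodes)}
    pending = {name: len(set(deps.get(name, ()))) for name in nodes}
    users = {name: [] for name in nodes}
    for name in nodes:
        for dep in set(deps.get(name, ())):
            users[dep].append(name)

    # Min-heap of original positions: the earliest ready node goes first.
    ready = [position[n] for n in nodes if pending[n] == 0]
    heapq.heapify(ready)
    ordered = []
    while ready:
        name = nodes[heapq.heappop(ready)]
        ordered.append(name)
        for user in users[name]:
            pending[user] -= 1
            if pending[user] == 0:
                heapq.heappush(ready, position[user])
    if len(ordered) != len(nodes):
        raise ValueError("cycle detected: collectives in the bucket were not independent")
    return ordered


# Node order right after bucketing: user0 precedes the definition of wait_ag0.
after_bucketing = [
    "in0", "user0", "pre_in1", "in1", "ag_bucket",
    "wait_bucket", "wait_ag0", "wait_ag1", "user1",
]
deps = {
    "user0": ["wait_ag0"],
    "in1": ["pre_in1"],
    "ag_bucket": ["in0", "in1"],
    "wait_bucket": ["ag_bucket"],
    "wait_ag0": ["wait_bucket"],
    "wait_ag1": ["wait_bucket"],
    "user1": ["wait_ag1"],
}
ordered = stable_topological_sort(after_bucketing, deps)
```

In the result, `user0` moves after `wait_ag0` while all other nodes keep their relative order. A cycle would mean that the bucketed collectives depended on each other, which the bucketing is meant to rule out.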
Correctness of the passes was verified with loss curves for llama3 8b, for simple_fsdp and for autoparallel:
<img width="1364" height="495" alt="Screenshot 2025-07-22 at 14 27 28" src="https://github.com/user-attachments/assets/67b2cabb-3206-450b-b529-e23c24292fc6" />
<img width="1355" height="509" alt="Screenshot 2025-07-22 at 14 27 56" src="https://github.com/user-attachments/assets/4d0e6b25-2eb1-47b2-8d68-dcec185239c4" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158663
Approved by: https://github.com/wconstab