Commit Graph

3803 Commits

Author SHA1 Message Date
19bf67be32 multimem reduce (#164517)
Modified `multimem_one_shot_all_reduce_out` function to accept a `root` argument, making it a `multimem_reduce` op.

The original `multimem_one_shot_all_reduce` op becomes a caller of the `multimem_reduce`, with each rank providing its own rank id as root.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164517
Approved by: https://github.com/ngimel
2025-10-08 05:25:16 +00:00
da903b6a8b list_stored_sd_metadata API. (#160610)
Summary:
1\ Certain checkpoint load use cases are not aware of the properties of the data/tensors they want to load.
2\ These usecases include data loader checkpoints, reading data for post processing (when the original model definition is not available).
3\ There, we have to use saved checkpoint  (metadata) as our source of truth.
4\ This RFC proposal exposes the checkpoint metadata using a public API.

In this proposal we expose the stored state-dict metadata  (minus associated storage/chunk metadata).

Chunk/storage details should not be exposed to the users and is a impl detail of the storage writer/reader.

Test Plan:
UT.

Rollback Plan:

Differential Revision: D80231457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160610
Approved by: https://github.com/saumishr
2025-10-08 04:33:51 +00:00
d444384003 [SymmMem] Tiled reduce (#162243)
Added op: `tile_reduce(Tensor input, Tensor(a!) out, int root, str group_name)`

For now supports only:
- NVSHMEM backed symmetric tensor;
- 2D tensor and tile;
- torch.float.

Testing on right-bottom quandrant:
```
rank 0:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0')
PASSED
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162243
Approved by: https://github.com/ngimel
2025-10-08 02:03:04 +00:00
c813617c53 [PP] Migrate other schedules to use PipelineScheduleRuntime (#164777)
Second fix for https://github.com/pytorch/pytorch/issues/164756

This has been a TODO to make the all schedules execute using the same runtime. Now after this change, schedules will use the same logic for `_PipelineScheduleRuntime` where it adds `UNSHARD` and `RESHARD` operations to the schedules which fixes the issue mentioned above.

<img width="920" height="406" alt="image" src="https://github.com/user-attachments/assets/a4d5bcd0-7dac-43cd-96f9-8ca33cfd8b91" />

A test is failing after the conversion:
- Fixed a gradient scaling issue for dWeight

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164777
Approved by: https://github.com/fegin
ghstack dependencies: #164775
2025-10-08 01:45:57 +00:00
e659661ffa [PP] Fix FSDP unshard/reshard (#164775)
First fix for https://github.com/pytorch/pytorch/issues/164756

In the pipeline IR we call `UNSHARD` and `RESHARD`,  but there is a bug because when we call `module.unshard()` these do not recursively call the FSDP modules, hence leading to sometime call allgather before the module forward.

Since we want the pipeline IR to explicitly handle this, we can call `group.unshard` instead which ensures that all the modules are unsharded.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164775
Approved by: https://github.com/weifengpy
2025-10-08 01:45:57 +00:00
c0510dc447 [ContextParallel] add _LoadBalancer classes, and load-balance interface to Context Parallel APIs (#161062)
**Summary**
This PR provides an interface for users to specify how to load-balance the attention
input. The load-balance is essentially a rearrangement of the input tensor(s) over the
seq_dim before sharding and can be specified via an index tensor `rearrange` such
that Q[rearrange] is the balanced Q users want (i.e. `rearrange[i] == j` where `i` is the new
index of `Q[j]` in the balanced Q). An example is the `_generate_round_robin_indices()` added
in https://github.com/pytorch/pytorch/pull/155442.

**New `_LoadBalancer` classes**
New `_LoadBalancer` class (defined in `torch/distributed/tensor/experimental/_load_balancer.py`)
provides one interface for defining load-balance behavior: `_generate_indices(self, restore: bool = False)`.

When `restore == False`, this method should output an index Tensor (namely `rearrange_idx`) such
that QKV will be transformed into Q' K' V' in a way that `Q'[i] == Q[rearrange_idx[i]]` (same applies
to K and V).

When `restore == True`, this method outputs an index Tensor (namely `restore_idx` such that
`Q'[restore_idx] == Q` (same applies to K and V).

**Impact**
2 public CP APIs and 1 private CP API is modified. This PR should be backward-compatible by:
- For uses w/ SDPA, existing users must be using the `context_parallel()` API which does not
take in the extra `load_balancer` argument and solely determines from the global var
`_cp_options.enable_load_balance`.
- For new users including who want to try `flex_attention()`, we require to use the new API
`_context_parallel_buffers` to explicitly shard the QKV input instead of using `context_parallel()`
because we no longer rely on TorchDispatchMode nor TorchFunctionMode for op replacement. And
we also require users to explicitly pass in a `load_balancer` argument if load-balancing is demanded.

**Load-Balance Behavior**
`context_parallel_unshard()`, and `create_cp_block_mask()` APIs now take an extra optional argument
`load_balancer`. This argument is optional because of backward compatibility but we require new users
to explicitly pass in a `load_balancer` if load-balancing is demanded:
- if `load_balancer == None` and `_cp_options.enable_load_balance == False`, CP performs
no load-balancing on input Tensors.
- if `load_balancer == None` and `_cp_options.enable_load_balance ==True`, CP performs
head-tail load-balancing (e.g. split a Tensor into 2*N chunks and first N are called head and
the rest are called tail. Place the first head chunk the last tail chunk on rank 0, and the second
head along with the second last tail chunk on rank 1, and so on).

`_context_parallel_buffers()` also takes the extra optional argument `load_balancer`, but the behavior
is slightly different from the other 2 APIs -- it doesn't branch on `_cp_options.enable_load_balance` :
- if `load_balancer == None`, no load-balancing will be performed
- otherwise, apply load-balancing using `load_balancer._generate_indices()` before sharding.

**Changes**
This PR moves the index Tensor generation logic into a set of LoadBalancer classes and
make LoadBalancer the common interface for Context Parallel APIs that leverages
load-balancing:
* _context_parallel_buffers
* context_parallel_unshard
* create_cp_block_mask

The `_LoadBalancer` classes added are:
- `_LoadBalancer`: the abstract base class that provides “_generate_indices” interface index Tensor generation.
- `_HeadTailLoadBalancer`: Implements head-tail balancing logic.
- `_PerDocumentHeadTailLoadBalancer`: Supports per-document head-tail balancing for batched sequences.

**Test**
`pytest test/distributed/tensor/test_attention.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161062
Approved by: https://github.com/fegin
2025-10-08 01:09:14 +00:00
e3ae80fc03 [PP] Let PP split BlockMask into micro-BlockMask (#164111)
BlockMask has batch dimension information. So PP has to split it as well just like all other tensors. All the tensors in BlockMask have the batch dimension, so we can just split it without too many issues. However, `mask_mod` requires the batch index as the input, which the value is going to be changed after the split. So we have to wrap it inside a closure to modify the batch index.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164111
Approved by: https://github.com/H-Huang
2025-10-07 23:25:34 +00:00
f505caa71b Revert "multimem reduce (#164517)"
This reverts commit d1cbb74fb16406488a174832e1b58b7c242f418d.

Reverted https://github.com/pytorch/pytorch/pull/164517 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/164517#issuecomment-3378529654))
2025-10-07 20:12:38 +00:00
df640df68a Revert "Reapply "C++-accessible Placements via pybind11 (#163030)" (#164519)"
This reverts commit 8c0bc879b97bc580aaa0777b2d266bdd068cb528.

Reverted https://github.com/pytorch/pytorch/pull/164519 on behalf of https://github.com/malfet due to Still breaks internal workflows ([comment](https://github.com/pytorch/pytorch/pull/164519#issuecomment-3378469432))
2025-10-07 19:46:17 +00:00
4725871a81 Return fake mode from export graph capture API (#164730)
This PR is to temporarily unblock various experiments to re-use dynamo create fake mode. Note that this is still not what we want as the end state. The end state should look sth like:
```
out = fulllgraph_capture(mod, inputs)
fake_mode = out.backend_inputs.fake_mode
gm  = out.module()
```
This doesn't work today because export requires we need to wrap the original module to setup a flat module to trace for easier handling of pytree. As a result, we would need to carry export specific flag in fullgraph_capture which seems not ideal.
Regardless, the end state is that we need to give downstream user a graph module and a fake mode in some form, so I think _dynamo_graph_capture_for_export returning a fake mode within graph module itself via gm.meta

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164730
Approved by: https://github.com/avikchaudhuri
2025-10-07 03:42:46 +00:00
8c0bc879b9 Reapply "C++-accessible Placements via pybind11 (#163030)" (#164519)
This makes Placement data representation available in C++ via pybind11. Reapply with fix for internal errors.

D83788896

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164519
Approved by: https://github.com/Skylion007, https://github.com/ezyang
2025-10-06 23:19:14 +00:00
b558c986e8 Add regression test for get_root_mesh with multiple independent meshes (#164731)
Fixes #163330

I tried to reproduce the bug with my 4-GPU setup (the original issue used 8 GPUs). I created several different test scenarios, trying to trigger the bug by:
- creating two different device meshes
- slicing them in various ways
- checking if get_root_mesh() would get confused

but the bug didn't show up! Everything worked correctly in `2.10`. I found that there was a massive refactoring of the `DeviceMesh` code (PR #163213) that landed on October 2nd. That PR completely rewrote how `DeviceMesh` tracks relationships between parent meshes and submeshes using. It seems like this refactoring fixed the bug! But I added a regression test to make sure it doesn't come back. The test (`test_get_root_mesh_multiple_independent_meshes`) does exactly what the bug report described:
  - creates two independent meshes
  - slices them both
  - verifies that each submesh correctly points back to its real parent
  - makes sure submeshes from mesh1 don't incorrectly claim mesh2 as their parent

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164731
Approved by: https://github.com/fduwjj
2025-10-06 18:52:25 +00:00
35f66b83f8 respect aten planned overlap in inductor (#164569)
Now that we have a hop to add implicit deps - use those deps for comm/compute overlap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164569
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164568
2025-10-06 15:47:55 +00:00
660e369a68 [FSDP2] check storage equal and consider data_ptr() == 0 (#164595)
resolve https://github.com/pytorch/pytorch/issues/164554

unit test
* `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict`
* `pytest -s test/distributed/_composable/fsdp/test_fully_shard_init.py -k test_meta_device_1d_init`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164595
Approved by: https://github.com/fegin
2025-10-06 08:44:38 +00:00
5d7360bb03 Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e6026925f6b6e8a36e3a8b7c0295cd7541911.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
2025-10-05 19:32:21 +00:00
321e602692 Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang
2025-10-05 07:38:25 +00:00
5ed4270440 remove more no longer needed torch._check_is_size calls 1 (#164630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164630
Approved by: https://github.com/Skylion007
ghstack dependencies: #164627
2025-10-04 22:06:04 +00:00
b116c51330 torch.cond on DTensor triggers an internal assert, add xfail for this. (#164389)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164389
Approved by: https://github.com/albanD
2025-10-04 18:12:06 +00:00
6b768e1890 Support propagating custom meta field to backward graph nodes (#164174)
# Propagate custom meta data to backward

Support propagating the user annotation tags to backward graph, by extending the `copy_fwd_metadata_to_bw_nodes` utils (recommended by @xmfan , thanks!).

Example annotation API (added in https://github.com/pytorch/pytorch/pull/163673):

```
class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1
            x = x - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2
        x = x / 3
        return x
```

Assumptions (some inherited from https://github.com/pytorch/pytorch/pull/126573):

- I am trusting the seq_nr mapping introduced to aot_autograd nodes in https://github.com/pytorch/pytorch/pull/103129
- I am also trusting that the forward is single threaded, since seq_nr is thread local.  If this isn't always true, we'll need to also plumb thread_id through the same machinery which is populating seq_nr.
- **(This is changed in this PR!) I assume all backward graph nodes has "is_backward" for 'partitioner_tag', and all other nodes are forward graph nodes**.  If we don't run export before `aot_export_join_with_descriptors`, then none of the nodes has "nn_module_stack" in node meta. If we do run export first, then we don't need this change.
- I copy "custom" node meta from forward to backward graph nodes.

Question:
- Is it a good idea to copy all "custom" node meta? Or should we create a dedicated key in custom node meta to be copied? @SherlockNoMad
- Do we expect people to run export before using `aot_export_join_with_descriptors`?
- Can we assume the following for graph produced by `aot_export_join_with_descriptors`? "all backward graph nodes has "is_backward" for 'partitioner_tag', and all other nodes are forward graph nodes". Maybe this is a question for @ezyang

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_
python test/export/test_export.py -k preserve_anno
python test/distributed/tensor/test_dtensor_export.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164174
Approved by: https://github.com/xmfan, https://github.com/SherlockNoMad
2025-10-04 05:03:32 +00:00
1894082000 UT/Examples for resharding checkpoint save/loads for distributed tensors with uneven shards. (#160533)
1\  DTensor abstraction on its own, does not support arbitrary length shards in its distributed tensors representation. It supports a single uneven shard, bit it has to be the last shard in the sharding dimension.

2\ However, DCP supports an API called checkpointable. This API allows you to define your custom shardable tensor structure. I have given a UT example ( look for CheckpointableDistTensor). Therefore, one option is to use CheckpointableDistTensor to save/load uneven shards.

3\ While exploring this path, I also noticed that torch.rec module also encountered a similar problem while working with DTensor. They workaround it by implementing Checkpointable API in DTensor and introducing an auxillary structure called LocalShardsWrapper. This is the second option we can use to unblock data loader resharding work.

In summary;
Use LocalShardWrapper + DTensor as the first option to unblock.
Second preference is to use new implementation of Checkpointable API. ( similar to CheckpointbaleDistTensor I have introduced in this example).

Differential Revision: D80182564

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160533
Approved by: https://github.com/saumishr
2025-10-03 22:15:02 +00:00
3ca09d65f1 [ROCm] Enable several distributed UTs (#164390)
Increase the tolerance for the following UTs as there was a slight mismatch seen on MI200.
    - test_data_parallel.py:test_strided_grad_layout
    - test_c10d_nccl.py:test_grad_layout_1devicemodule_1replicaperprocess

Skip for MI200:
    - test_fully_shard_training.py:test_2d_mlp_with_nd_mesh
    - test_2d_composability.py:test_train_parity_2d_mlp
    - test_fully_shard_overlap.py:test_fully_shard_training_overlap

Fixes #159489
Fixes #159488
Fixes #152700
Fixes #125555
Fixes #134139

Working as is on both MI200 and MI300:
Fixes #125991
Fixes #125918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164390
Approved by: https://github.com/jeffdaily
2025-10-03 19:52:51 +00:00
5b0b4cda4a [dtensor] avoid shape recompilations on DTensorSpec (#163820)
skips DTensorSpec.sizes/strides in metadata guard checks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163820
Approved by: https://github.com/azahed98
2025-10-03 17:18:18 +00:00
5743d731c1 Use torch.testing.test_close instead of torch.testing.test_allclose (#164539)
Because torch.testing.test_allclose is deprecated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164539
Approved by: https://github.com/mlazos
2025-10-03 14:39:10 +00:00
2a760dc51e [DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)
We want to refactor the internal bookkeeping of DeviceMesh so that:
Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible.

Concretely, in this PR, we do the following:
1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims.
2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh.
3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
4. Use the newly added function `check_overlap` to check layout overlap.
5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1).

Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match.

<img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" />

The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor.

With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout.

This is a continue of https://github.com/pytorch/pytorch/pull/161106 (original one got messed with EasyCLA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163213
Approved by: https://github.com/lw, https://github.com/fegin
2025-10-03 05:51:28 +00:00
d1cbb74fb1 multimem reduce (#164517)
Modified `multimem_one_shot_all_reduce_out` function to accept a `root` argument, making it a `multimem_reduce` op.

The original `multimem_one_shot_all_reduce` op becomes a caller of the `multimem_reduce`, with each rank providing its own rank id as root.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164517
Approved by: https://github.com/ngimel
2025-10-03 02:41:10 +00:00
22e219d996 Revert "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)"
This reverts commit b0985144b59db8fb20964829b5e0a9d2f9a3f0d6.

Reverted https://github.com/pytorch/pytorch/pull/163213 on behalf of https://github.com/yangw-dev due to caused internal test failure ([comment](https://github.com/pytorch/pytorch/pull/163213#issuecomment-3363414435))
2025-10-02 22:22:26 +00:00
ece5e0f01b Fake process group Direct construction error (#163665)
Fixes #162129. Added validation in _rank_not_in_group() to check if ```FakeProcessGroup``` is properly initialized before use, raising a clear error message if ```torch.distributed.init_process_group(backend='fake')``` hasn't been called first.
This prevents silent failures and ensures proper dispatch system integration for all distributed operations.

Added test case test_fake_process_group_direct_usage_error() that validates the error is raised for ```all_reduce``` and ```all_to_all_single``` operations.

Please let me know if additional distributed operators should be tested or if any other updates are needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163665
Approved by: https://github.com/ezyang
2025-10-02 22:19:26 +00:00
cc71ab86a6 [DTensor] raise error if the local_tensor argument passed to DTensor.from_local is a DTensor (#164496)
**Summary**
Raise error when the `local_tensor` argument passed to `DTensor.from_local` is
a DTensor, this prevents users from accidentally calling `from_local` over a DTensor
object.

The error message is organized in this way:
```
the local_tensor argument only accepts torch.Tensor but got <class 'torch.distributed.tensor.DTensor'> value.
```

**Test**
`pytest test/distributed/tensor/test_dtensor.py -k test_from_local`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164496
Approved by: https://github.com/ezyang
2025-10-02 21:25:01 +00:00
f6f7676756 Revert "C++-accessible Placements via pybind11 (#163030)"
This reverts commit 3e03deab6f3c268c85c8efd9546e28cdda0fa4cc.

Reverted https://github.com/pytorch/pytorch/pull/163030 on behalf of https://github.com/swolchok due to doesn't pass pyre ([comment](https://github.com/pytorch/pytorch/pull/163030#issuecomment-3362450379))
2025-10-02 18:25:24 +00:00
b0985144b5 [DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)
We want to refactor the internal bookkeeping of DeviceMesh so that:
Simply the bookkeeping logics and make it generic enough so that it is easy to support new transformations like flatten noncontiguous dim, reshape and unflatten. (We leveraged the CuTe layout). This new layout also let us handle non-contiguous slicing, flatten, transpose possible.

Concretely, in this PR, we do the following:
1. Use the `_MeshLayout` to handle all index operations rather use a map to record mesh dims.
2. Removed `flatten_name_to_root_dims`, because now we can directly get layout from a flattened device mesh.
3. Replaced `_get_slice_mesh_dims` with `_get_slice_mesh_layout`.
4. Use the newly added function `check_overlap` to check layout overlap.
5. Use a new function `to_remapping_tensor` to use layout ranks as indices when the mesh tensor is not representable as CuTe. The reason is that layout acts as a backend of mesh tensor bookkeeping (indexing indices), it needs to be used as indices for remap back to the mesh tensor for new DeviceMesh generation and backend init. For example, in the case of 2K to 4K, the underlying layout is (2K, 1) but the actual value of the mesh tensor is [2K, 2K+1, ....,]. While flattening, slicing, we need to remap the layout back to the new mesh tensor so it maps the actual device allocation. For example, in the 2K to 4K case, if the shape is (1K, 1K) with dim_names ("dp", "tp"). Then when slicing "tp", the mesh tensor should be (2K, 2K+1, ..., 3K-1) or (3K, 3K+1, ... 4K-1). not the global ranks generated from the layout. (1K, 1).

Verified that loss curve is very close for DeepSeekV3 on torchtitan, note that exact same match is challenging because even if we run the baseline twice, the loss curve does not exactly match.

<img width="1113" height="490" alt="image" src="https://github.com/user-attachments/assets/7877b5a4-337e-4ad8-b878-2378f4f0f38d" />

The PR looks big indeed but we don't change any existing behavior of DeviceMesh, so it is a pure refactor.

With this refactoring we also enabled the slicing and flatten of non-contiguous dims of a device mesh which is hard to implement without cute layout.

This is a continue of https://github.com/pytorch/pytorch/pull/161106 (original one got messed with EasyCLA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163213
Approved by: https://github.com/lw, https://github.com/fegin
2025-10-02 15:42:03 +00:00
27eb36debb DebugMode add ignore_compile_internals (#164205)
Fixes #164143

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164205
Approved by: https://github.com/albanD
2025-10-02 07:39:54 +00:00
3e03deab6f C++-accessible Placements via pybind11 (#163030)
This makes Placement data representation available in C++ via pybind11.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030
Approved by: https://github.com/ezyang
2025-10-02 02:38:23 +00:00
8b29c59844 [CI][CUDA] Fix distributed tests for b200 (#164345)
This PR fixes the tests that were encountered in #159323.
Namely it fixes #162746 and #162745.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164345
Approved by: https://github.com/eqy
2025-10-02 01:13:49 +00:00
9065364995 Add xfailing test case for inplace mutation of local DTensor (#164355)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164355
Approved by: https://github.com/albanD
2025-10-01 23:16:26 +00:00
3ffaab3bc8 [Replicate][Pipeline Parallelism] integration of new replicate function with pipeline parallelism (#164031)
**Summary:** In order to test numerics for replicate + pp, stage.py needs to be able to call replicate's backward manually as pipeline parallelism doesn't have this feature.

**Test Case**
1.  pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164031
Approved by: https://github.com/weifengpy, https://github.com/H-Huang
ghstack dependencies: #163897
2025-10-01 18:01:16 +00:00
ebd0707578 [SymmMem] Add get_nbi the nonblocking version (#163540)
```Py
@triton.jit
def foo(dest, src):
    nvshmem.get_nbi(dest, src, 100, 0)
    # Some independent computation which overlaps with the get operation
    ...
    # Wait for completion of the get operation
    nvshmem.quiet()
```

Allows us to overlap comm and compute in the same kernel, instead of two kernels + signals.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163540
Approved by: https://github.com/ngimel, https://github.com/fegin
2025-10-01 17:50:24 +00:00
76ddbc2bbb Add option to FakeProcessGroup to raise error if comms are invoked. (#162841)
The current behavior is to do "nothing", which means you will corrupt
data.  If you're doing something similar to LocalTensor, where you're
overriding the behavior of collectives to do something numerically,
this can be unwelcome behavior.  If you can error when this happens
it can help prevent silent numerical incorrectness.

Authored with claude code.

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162841
Approved by: https://github.com/dcci
2025-10-01 17:48:19 +00:00
ed90040d33 Releases multicast object before releasing mapped buffers in CUDASymmetricMemory (#163750)
Fixes: https://github.com/pytorch/pytorch/issues/162429. In B200, cuMulticastUnbind can error if the mapped buffers are free'd before the multicast object is free'd. The only documentation I could find is here: e11d7f77c1/src/transport/nvls.cc (L113).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163750
Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/kwen2501, https://github.com/nWEIdia, https://github.com/cyyever
ghstack dependencies: #163575
2025-10-01 09:07:48 +00:00
8bb71c07c4 Skip symmetric memory tests calling _scaled_mm on CCC < 8.9 (#164251)
This avoids them failing on e.g. A100 GPUs with
> RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
2025-10-01 03:26:21 +00:00
60a4961ff4 [DTensor] Allow redistribute to Partial if src matches (#164253)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164253
Approved by: https://github.com/zpcore
2025-09-30 22:42:49 +00:00
2810977d3a [FSDP][Replicate] tests replicate type casting behavior and edge cases in mixed precision (#162861)
**Summary:** Ensures that replicate can handle the same type casting behavior and edge cases that fully shard can when mixed precision is used

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_float16_on_one_submodule
2. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_submodules_with_external_inputs
3. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_bf16
4. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_norm_modules_fp16
5. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_clamp_reduce_dtype
6. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k test_dataclass_input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162861
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853, #162855
2025-09-30 22:03:23 +00:00
ae4fd4ea75 [FSDP2] support AC(FSDP) for torchtitan's MOE (#164009)
for fsdp2 + EP, titan has fully_shard(AC(layer)) and fully_shard(layer.moe.experts): https://github.com/pytorch/torchtitan/issues/1624

for implicit prefetching, backward order is
* _pre_backward unshard (norm, output)
* _backward_prefetch unshard layers.6
* post_backward reshard (norm, output)
* _pre_backward unshard layers.6 (no-op, unsharded already)
* _backward_prefetch unshard layers.6.moe.experts
* recompute_fn pre_forward unshard layers.6.moe.experts (no-op, unsharded already)
* ~~recompute_fn post_forward reshard layers.6.moe.experts~~ <----- this PR make it a no-op
* _pre_backward unshard layers.6.moe.experts (no-op, unsharded already)
* _backward_prefetch unshard layers.5
* post_backward reshard layers.6.moe.experts
* post_backward reshard layers.6

unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_comm.py -k test_set_modules_to_backward_prefetch_inside_ac`

before fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 11:43:01,714 - root - INFO - step:  1  loss: 12.0162  grad_norm:  1.7315  memory: 45.64GiB(48.05%)  tps: 1,028  tflops: 10.87  mfu: 1.10%
[rank0]:[titan] 2025-09-30 11:43:01,714 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 11:43:35,233 - root - INFO - [GC] Performing periodical GC collection 0.06 seconds
[rank0]:[titan] 2025-09-30 11:43:35,987 - root - INFO - step: 50  loss:  6.9302  grad_norm:  0.9985  memory: 59.66GiB(62.80%)  tps: 11,712  tflops: 123.89  mfu: 12.53%
```

after fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 11:38:57,377 - root - INFO - step:  1  loss: 12.0134  grad_norm:  1.6916  memory: 38.42GiB(40.45%)  tps: 805  tflops: 8.51  mfu: 0.86%
[rank0]:[titan] 2025-09-30 11:38:57,377 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 11:39:28,541 - root - INFO - [GC] Performing periodical GC collection 0.06 seconds
[rank0]:[titan] 2025-09-30 11:39:29,279 - root - INFO - step: 50  loss:  6.9346  grad_norm:  1.1875  memory: 52.58GiB(55.36%)  tps: 12,583  tflops: 133.10  mfu: 13.46%
```

for explicit prefetching, layers.6 backward prefetch layers.5 and layers.5.moe.experts. layers.6.moe.experts does not have explicit prefetch. backward order is like this
* _pre_backward unshard (norm, output)
* _prefetch_unshard layers.6
* post_backward reshard (norm, output)
* _pre_backward unshard layers.6 (no-op, unsharded already)
* _prefetch_unshard layers.5
* _prefetch_unshard layers.5.moe.experts
* recompute_fn pre_forward unshard layers.6.moe.experts
* ~~recompute_fn post_forward reshard layers.6.moe.experts~~ <----- this PR makes it a no-op
* _pre_backward unshard layers.6.moe.expert (no-op, unsharded already)
* post_backward reshard layers.6.moe.expert
* post_backward reshard layers.6

before fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 11:53:24,574 - root - INFO - step:  1  loss: 12.0180  grad_norm:  1.6948  memory: 45.77GiB(48.18%)  tps: 849  tflops: 8.98  mfu: 0.91%
[rank0]:[titan] 2025-09-30 11:53:24,574 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 11:53:57,768 - root - INFO - [GC] Performing periodical GC collection 0.07 seconds
[rank0]:[titan] 2025-09-30 11:53:58,515 - root - INFO - step: 50  loss:  6.9358  grad_norm:  1.0528  memory: 59.80GiB(62.95%)  tps: 11,827  tflops: 125.10  mfu: 12.65%```
```

after fix: `NGPU=4 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.expert_parallel_degree=2`
```
[rank0]:[titan] 2025-09-30 12:08:39,404 - root - INFO - step:  1  loss: 12.0143  grad_norm:  1.7030  memory: 38.55GiB(40.58%)  tps: 988  tflops: 10.45  mfu: 1.06%
[rank0]:[titan] 2025-09-30 12:08:39,404 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-30 12:09:10,482 - root - INFO - [GC] Performing periodical GC collection 0.06 seconds
[rank0]:[titan] 2025-09-30 12:09:11,168 - root - INFO - step: 50  loss:  6.9356  grad_norm:  0.9911  memory: 52.81GiB(55.59%)  tps: 12,637  tflops: 133.68  mfu: 13.52%
```

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164009
Approved by: https://github.com/soulitzer
2025-09-30 22:02:24 +00:00
99e28ffab3 [FSDP][Replicate] tests replicate core functionality with mixed precision (#162855)
**Summary:** Ensures that replicate functionality works the same as fully shard's when mixed precision is used

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_mixed_precision.py -k TestReplicateMixedPrecisionTraining

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162855
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851, #162853
2025-09-30 21:45:58 +00:00
01dd2c2b42 [FSDP][Replicate] tests replicate is composable with tp (#162853)
**Summary:** Proof that new replicate API is composable with TP

**Test Case**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_replicate_tp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162853
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851
2025-09-30 21:29:54 +00:00
d3bdf8c32e [FSDP][Replicate] tests replicate with custom forward method (#162851)
**Summary: tests replicate works when users use custom forward methods**

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_register_fsdp_forward_method

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162851
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839
2025-09-30 21:15:34 +00:00
1ce9563ff6 [FSDP][Replicate] tests replicate gradient accumulation and 1f1b microbatching (#162839)
**Summary:** In order to ensure that replicate acts as intended (a specialized version of hsdp) we need to make sure that it can pass the same tests that fully_shard can for training. The first test verifies Replicate works with gradient accumulation properly. The second verifies that replicate works correctly with a One-Forward-One-Backward (1F1B) pipeline parallelism schedule

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_gradient_accumulation
2. pytest test/distributed/_composable/test_replicate_training.py -k test_1f1b_microbatching

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162839
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836
2025-09-30 21:00:16 +00:00
7f4c3e7d2f distributed/serialization: support zero sized tensors (#164198)
Fixes
```
[4] ValueError: both buffer length (0) and count (-1) must not be 0
```

Test plan:

```
pytest test/distributed/test_serialization.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164198
Approved by: https://github.com/amirafzali
2025-09-30 08:11:29 +00:00
6e5b4249a5 [DTensor][Export] Supporting exporting a model with DTensor params/inputs (#163609)
I experimented with 3 paths to get joint graph for DTensorized module and input

1. strict_export + aot_export_joint_with_descriptors
2. graph_capture + aot_export_joint_with_descriptors
3. aot_export_joint_with_descriptors alone

Added test to guard them.

1 doesn't work, as bw graph region is missing from the joint graph.
I am leaning towards making 2 the recommended path.
If 2 doesn't work going forward, we can fallback to 3.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163609
Approved by: https://github.com/tugsbayasgalan

Co-authored-by: suo <suo@fb.com>
2025-09-30 07:54:13 +00:00
7d59e37434 Add Comm-Compute Preserving Bucketer (#163960)
tl;dr performs bucketing while preserving comm-compute overlap.

In comm-compute overlap we will have a graph with:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.

Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.

We perform bucketing while augmenting the graph with these relationships. This can be done separably from comm-compute overlap, so long as the hiding compute relationships are passed in.

TODO:
- need to instrument fx graph so inductor respects these relationships.
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory aware handling

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
2025-09-30 04:53:58 +00:00
0d7994ca97 [inductor] do comm compute overlap at aten fx level (#163215)
This is first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.

Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results

Other mis:
- Validate accuracy of nccl estimations  ( use ruisi's profiling instead ?)

For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3

bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
2025-09-30 04:53:58 +00:00