116 Commits

82fa2aa269 DTensor: Fix trivial as_strided case, add alias support (#166867)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166867
Approved by: https://github.com/albanD
ghstack dependencies: #166868
2025-11-04 07:18:32 +00:00
2de4cf2102 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-30 12:22:25 +00:00
369f2d6951 [3/N] fix typo in other folders (#166606)
Fix typos in other folders.

#166374
#166126

_typos.toml
```toml
[files]
extend-exclude = ["tools/linter/dictionary.txt"]
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
2025-10-30 10:30:40 +00:00
56a809aa07 [DTensor] Fix torch.all() using incorrect reduction operator (#165924)
Fixes #165923
Corrects the reduction operation to be product.

Enables "all" in the boolean tensor tests.
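
A minimal single-process sketch (plain tensors, hypothetical two-rank shards) of why a product reduction over the per-rank partial results reproduces the global `all()`:

```python
import torch

# Hypothetical rank-local boolean shards of one logical tensor.
shard_a = torch.tensor([True, True])
shard_b = torch.tensor([True, False])

# Each rank reduces locally, then the partial results are combined.
partials = torch.stack([shard_a.all(), shard_b.all()])
combined = bool(partials.long().prod())  # product == logical AND over partials

assert combined == bool(torch.cat([shard_a, shard_b]).all())
```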

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165924
Approved by: https://github.com/malfet, https://github.com/Skylion007
2025-10-29 20:58:35 +00:00
1dd6b76914 Revert "[1/N] Remove unused loop variables (#166258)"
This reverts commit 76b2c37045e52540ec51e967aa7b6436a6b9b174.

Reverted https://github.com/pytorch/pytorch/pull/166258 on behalf of https://github.com/atalman due to breaks test/distributed/test_serialization.py::TestSerialization::test_weights_only [GH job link](https://github.com/pytorch/pytorch/actions/runs/18894311802/job/53929321703) [HUD commit link](76b2c37045) ([comment](https://github.com/pytorch/pytorch/pull/166258#issuecomment-3460964612))
2025-10-29 11:10:37 +00:00
76b2c37045 [1/N] Remove unused loop variables (#166258)
This PR removes unused loop variables.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-10-29 01:34:15 +00:00
572cc12b42 Move MaskPartial to placement_types to improve discoverability (#164414)
Had trouble finding this one myself in #163030.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164414
Approved by: https://github.com/ezyang
2025-10-28 21:56:02 +00:00
7d16fcf2df Re-re-re-re-apply "C++-accessible Placements via pybind11 (#163030)" (#166132)
Was reverted (again!) due to a merge conflict that crept in sometime during the "export to github -> land internally -> merge on github" process.

D85096233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166132
Approved by: https://github.com/Skylion007, https://github.com/ezyang, https://github.com/malfet
2025-10-27 21:19:32 +00:00
483845a9c4 [DTensor][Op] fix for DTensor ops with Partial placements (#165962)
**Summary:** When operations are performed on Partial placements, the sharding logic incorrectly decides whether the tensor should be redistributed to Replicate. Because the redistribution is delayed, the operation runs first and the partial reduction happens afterwards, which produces incorrect results for max, min, gradient norm clipping, and more. We fix this by setting reduction_linear to False whenever a Partial placement is present, forcing the redistribution before the op runs.
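
A minimal single-process sketch (plain tensors, hypothetical two-rank partial values) of why running the op before the partial reduction is wrong for `max`:

```python
import torch

# Hypothetical rank-local values of a DTensor with Partial(sum) placement;
# the true tensor is the elementwise sum of the two local tensors.
local_rank0 = torch.tensor([1.0, 3.0])
local_rank1 = torch.tensor([2.0, 1.0])
full = local_rank0 + local_rank1               # tensor([3., 4.])

wrong = local_rank0.max() + local_rank1.max()  # op first, then sum-reduce -> 5.0
right = full.max()                             # reduce first, then op -> 4.0
print(wrong.item(), right.item())
```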

**Test Cases**
1. pytest test/distributed/tensor/test_math_ops.py -k test_partial_reduction_ops
2. pytest test/distributed/tensor/test_math_ops.py -k test_matching_partial_reduction_ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165962
Approved by: https://github.com/wconstab
2025-10-27 21:17:13 +00:00
f863550192 [dtensor] fix incorrect norm calculation for Partial DTensors (#159856)
The sharding strategies for `aten.linalg_vector_norm` and the optimized `aten._foreach_norm.Scalar` incorrectly assume that the norm operation is always "reduction linear" with respect to its inputs. This bug causes the norm to be computed on local, incomplete data for DTensors with a `Partial(sum)` placement, leading to an inflated result (a sum of norms rather than the correct norm of the sum).

The error can be reproduced with the following script:
```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard

def setup_distributed():
    """Initializes the distributed environment."""
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)

    print(f"Initialized process {rank}/{world_size} on GPU {local_rank}")
    return rank, world_size

rank, world_size = setup_distributed()
assert world_size == 2, "Please run with exactly 2 GPUs for this minimal repro."

mesh = init_device_mesh("cuda", (world_size,))

if rank == 0:
    local_partial = torch.tensor([1.0, 3.0], dtype=torch.float32)
else:
    local_partial = torch.tensor([2.0, 1.0], dtype=torch.float32)

partial_dtensor = DTensor.from_local(local_partial, mesh, [Partial("sum")])
partial_result = torch.linalg.vector_norm(partial_dtensor)
print(
    f"[Rank {rank}] partial_result: {partial_result}, full_tensor: {partial_result.full_tensor()}"
)

shard_dtensor = partial_dtensor.redistribute(mesh, [Shard(0)])
shard_result = torch.linalg.vector_norm(shard_dtensor)
print(
    f"[Rank {rank}] shard_result: {shard_result}, full_tensor {shard_result.full_tensor()}"
)

replicate_dtensor = partial_dtensor.redistribute(mesh, [Replicate()])
replicate_result = torch.linalg.vector_norm(replicate_dtensor)
print(
    f"[Rank {rank}] replicate_result: {replicate_result}, full_tensor {replicate_result.full_tensor()}"
)

full_tensor = partial_dtensor.full_tensor()
full_result = torch.linalg.vector_norm(full_tensor)
print(f"[Rank {rank}] correct_result: {full_result}")
```

Run results show that the norm is `sqrt(1**2 + 3**2) + sqrt(2**2 + 1**2) = sqrt(10) + sqrt(5) = 5.398` instead of `sqrt(3**2 + 4**2) = 5`.
```
$ torchrun --local-ranks-filter 0 --nproc-per-node 2 script.py
Initialized process 0/2 on GPU 0
[Rank 0] partial_result: DTensor(local_tensor=3.1622776985168457, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),)), full_tensor: 5.398345947265625
[Rank 0] shard_result: DTensor(local_tensor=3.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 0] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Replicate(),)), full_tensor 5.0
[Rank 0] correct_result: 5.0
$ torchrun --local-ranks-filter 1 --nproc-per-node 2 script.py
Initialized process 1/2 on GPU 1
[Rank 1] partial_result: DTensor(local_tensor=2.2360680103302, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),)), full_tensor: 5.398345947265625
[Rank 1] shard_result: DTensor(local_tensor=4.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 1] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Replicate(),)), full_tensor 5.0
[Rank 1] correct_result: 5.0
```

This fix simply forces `reduction_linear=False` for partial placements. The output becomes:
```
$ python -m torch.distributed.run --local-ranks-filter 0 --nproc-per-node 2 script.py
Initialized process 0/2 on GPU 0
[Rank 0] partial_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor: 5.0
[Rank 0] shard_result: DTensor(local_tensor=3.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 0] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor 5.0
[Rank 0] correct_result: 5.0
$ python -m torch.distributed.run --local-ranks-filter 1 --nproc-per-node 2 script.py
Initialized process 1/2 on GPU 1
[Rank 1] partial_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor: 5.0
[Rank 1] shard_result: DTensor(local_tensor=4.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(_NormPartial(reduce_op='sum', norm_type=2),)), full_tensor 5.0
[Rank 1] replicate_result: DTensor(local_tensor=5.0, device_mesh=DeviceMesh((2,), device: 'cuda', stride: (1,)), placements=(Replicate(),)), full_tensor 5.0
[Rank 1] correct_result: 5.0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159856
Approved by: https://github.com/ezyang
2025-10-26 05:58:44 +00:00
8f80892359 Use correct pyrefly syntax in suppressions distributed/... (#166241)
Updates the pyrefly ignores in the torch/distributed directory to use the correct syntax. No functional changes.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166241
Approved by: https://github.com/oulgen
2025-10-26 04:16:41 +00:00
c4f6619330 Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch, local tensor is supposed to collect the RNG state from the CPU and CUDA
devices so that it can be reset before executing the op for each rank, such that ops
with randomness produce the same result for all ranks (note that we are planning a
separate change to add support for per-rank RNG state). Previously we relied on the
op's input arguments to deduce which devices to get the RNG state from, which doesn't work
for factory functions such as torch.randn. Hence this change switches to unconditionally
collecting the RNG state from all devices.

- Fix per-rank-specific computations in _MaskedPartial and Shard placements discovered
during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
2025-10-18 23:33:24 +00:00
beb6b62e8c Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"
This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb.

Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
2025-10-18 09:15:49 +00:00
1b397420f2 Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch, local tensor is supposed to collect the RNG state from the CPU and CUDA
devices so that it can be reset before executing the op for each rank, such that ops
with randomness produce the same result for all ranks (note that we are planning a
separate change to add support for per-rank RNG state). Previously we relied on the
op's input arguments to deduce which devices to get the RNG state from, which doesn't work
for factory functions such as torch.randn. Hence this change switches to unconditionally
collecting the RNG state from all devices.

- Fix per-rank-specific computations in _MaskedPartial and Shard placements discovered
during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
2025-10-17 23:28:22 +00:00
9b6be53326 [distributed] Replace 94 assert statements in tensor ops files (#165229)
Replace assert statements with explicit if/raise patterns in:
- _math_ops.py (43 asserts)
- _matrix_ops.py (27 asserts)
- _view_ops.py (24 asserts)

This prevents assertions from being disabled with Python -O flag.
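
An illustrative sketch of the pattern (the helper name and message are hypothetical, not code from this PR):

```python
def require_tensor_meta(spec):
    # Equivalent to `assert spec.tensor_meta is not None, "..."`, except the
    # check still runs under `python -O`, which strips assert statements.
    if spec.tensor_meta is None:
        raise ValueError("expected the DTensorSpec to carry tensor_meta")
    return spec.tensor_meta
```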

Partially fixes #164878.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165229
Approved by: https://github.com/albanD
2025-10-14 17:28:06 +00:00
45b8c0f75c [distributed] Replace 54 assert statements in tensor/_ops/_tensor_ops.py (#165226)
Replace assert statements with explicit if/raise patterns to prevent assertions from being disabled with Python -O flag.

Partially fixes #164878.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165226
Approved by: https://github.com/albanD
2025-10-14 15:10:03 +00:00
770e6b910c [DTensor] Extend conv ops to 3D (#165241)
The current implementation hardcodes 4D input and output tensor shapes.

Change that by computing `output_conv_shape` for any number of input dims and replacing `[.., .., .., slice]` indexing with `[..., slice]`.
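
A small sketch of why the `[..., slice]` form generalizes (example shapes are hypothetical):

```python
import torch

t4 = torch.randn(2, 3, 6, 8)     # e.g. conv2d-style 4D activations
t5 = torch.randn(2, 3, 4, 6, 8)  # e.g. conv3d-style 5D activations

# Ellipsis indexing addresses the last dim regardless of rank, unlike a
# hardcoded `[:, :, :, slice]`, which assumes exactly 4 dims.
assert torch.equal(t4[..., :4], t4[:, :, :, :4])
assert torch.equal(t5[..., :4], t5[:, :, :, :, :4])
```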

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165241
Approved by: https://github.com/ezyang
2025-10-14 02:30:46 +00:00
5e58420dff LocalTensor (#164537)
A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks.  A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally.  When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards.  A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.

NB, this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
autograd needs multithreading to keep up!). (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs,
but this is currently a non-goal.)

In order to handle MPMD, we provide a helper decorator that allows you to
run a function with no side effects for each LocalTensor shard and combine
results back into LocalTensor or LocalIntNode.
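
A toy, non-distributed sketch of the per-shard semantics described above (this is not the actual LocalTensor API, just an illustration):

```python
import torch

class ToyLocalTensor:
    """Holds one plain tensor per simulated rank; ops apply shard-wise."""

    def __init__(self, shards: dict[int, torch.Tensor]):
        self.shards = shards  # rank -> local tensor

    def __add__(self, other: "ToyLocalTensor") -> "ToyLocalTensor":
        # A plain PyTorch op runs independently on every rank's shard.
        return ToyLocalTensor({r: t + other.shards[r] for r, t in self.shards.items()})

    def all_reduce_sum(self) -> "ToyLocalTensor":
        # A collective becomes the mathematically equivalent local computation.
        total = sum(self.shards.values())
        return ToyLocalTensor({r: total.clone() for r in self.shards})

a = ToyLocalTensor({0: torch.tensor([1.0]), 1: torch.tensor([2.0])})
b = ToyLocalTensor({0: torch.tensor([10.0]), 1: torch.tensor([20.0])})
print((a + b).all_reduce_sum().shards)  # every rank ends up holding tensor([33.])
```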

Note: This PR converts all DTensor ops and some DTensor tests to illustrate the
intended usage and ensure correctness. In subsequent PRs more tests will be
converted. During test conversion we aim to share as much test logic as possible
between multi-process / multi-threaded and local tensor tests.
We would like developers to be able to run both flavors of the tests.

Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang
2025-10-12 20:06:41 +00:00
7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00
df640df68a Revert "Reapply "C++-accessible Placements via pybind11 (#163030)" (#164519)"
This reverts commit 8c0bc879b97bc580aaa0777b2d266bdd068cb528.

Reverted https://github.com/pytorch/pytorch/pull/164519 on behalf of https://github.com/malfet due to Still breaks internal workflows ([comment](https://github.com/pytorch/pytorch/pull/164519#issuecomment-3378469432))
2025-10-07 19:46:17 +00:00
8c0bc879b9 Reapply "C++-accessible Placements via pybind11 (#163030)" (#164519)
This makes Placement data representation available in C++ via pybind11. Reapply with fix for internal errors.

D83788896

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164519
Approved by: https://github.com/Skylion007, https://github.com/ezyang
2025-10-06 23:19:14 +00:00
35c4130fd1 [2/N] Fix ruff warnings (#164460)
Apply ruff `SIM` rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164460
Approved by: https://github.com/ezyang
2025-10-04 03:40:32 +00:00
f6f7676756 Revert "C++-accessible Placements via pybind11 (#163030)"
This reverts commit 3e03deab6f3c268c85c8efd9546e28cdda0fa4cc.

Reverted https://github.com/pytorch/pytorch/pull/163030 on behalf of https://github.com/swolchok due to doesn't pass pyre ([comment](https://github.com/pytorch/pytorch/pull/163030#issuecomment-3362450379))
2025-10-02 18:25:24 +00:00
3e03deab6f C++-accessible Placements via pybind11 (#163030)
This makes Placement data representation available in C++ via pybind11.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030
Approved by: https://github.com/ezyang
2025-10-02 02:38:23 +00:00
85012fe167 Remove unnecessary list comprehensions (#164103)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164103
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
2025-09-30 03:56:54 +00:00
da003d7b95 [3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is the follow-up of #164054.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
2025-09-30 00:28:53 +00:00
c58e096cd0 [DTensor] implement logsumexp (#163879)
As titled, mostly copypasta from internal. I am a DTensor noob, so please scrutinize my added test.
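
A quick single-process sketch of the reduction property the op relies on (not the DTensor implementation itself): combining per-shard `logsumexp` results with another `logsumexp` reproduces the full reduction.

```python
import torch

x = torch.randn(4, 8)
shards = x.chunk(2, dim=1)  # simulate two rank-local shards along the reduced dim

partial = torch.stack([s.logsumexp(dim=1) for s in shards])
assert torch.allclose(partial.logsumexp(dim=0), x.logsumexp(dim=1))
```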

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163879
Approved by: https://github.com/XilunWu
2025-09-25 23:08:30 +00:00
92f7361e27 [DTensor] fix uneven _StridedShard (#163843)
The previous uneven `_StridedShard` from https://github.com/pytorch/pytorch/pull/150490 seems to fail on cases like sharding `tensor = torch.arange(6)` with FSDP 2, TP 2.

This PR attempts to reinvent `_StridedShard`.

I didn't test nested `_StridedShard`, because there shouldn't be any use cases. I think it will become quite messy when it comes to **nested uneven** `_StridedShard`. We are probably going to deprecate it anyway after @zpcore's work https://github.com/pytorch/pytorch/pull/160266 on ordered sharding, so IMO it's not worth making it too general.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163843
Approved by: https://github.com/ezyang
2025-09-25 22:12:29 +00:00
bb7c9a2d41 [DTensor] Fix DTensor.mean with uneven sharding (#163241)
Fixes #162692

When the input is unevenly sharded, redistribute it as Replicate.
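
A plain-tensor sketch (hypothetical uneven two-rank split) of why a shard-local mean is wrong when shard sizes differ, which is why the input is replicated first:

```python
import torch

shard0, shard1 = torch.tensor([1.0, 2.0]), torch.tensor([6.0])
full = torch.cat([shard0, shard1])

wrong = torch.stack([shard0.mean(), shard1.mean()]).mean()  # 3.75, overweights the small shard
right = full.mean()                                         # 3.0
print(wrong.item(), right.item())
```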

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163241
Approved by: https://github.com/dcci
2025-09-18 19:53:51 +00:00
821458d97a [dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
2025-09-17 09:32:38 +00:00
e7c3f802ff Revert "[dynamo][hop] Introduce Local Map HOP (#161458)"
This reverts commit 505458db803e1ffabac08a2fc150b566d3ea3a57.

Reverted https://github.com/pytorch/pytorch/pull/161458 on behalf of https://github.com/jeffdaily due to broke rocm tests ([comment](https://github.com/pytorch/pytorch/pull/161458#issuecomment-3299230458))
2025-09-16 15:14:36 +00:00
505458db80 [dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
2025-09-16 00:37:40 +00:00
8590c3a66b [DTensor] Add _foreach_pow to sharding propagation list. (#162895)
Fixes #152696

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162895
Approved by: https://github.com/ezyang
2025-09-15 21:14:06 +00:00
435c18fb4a [DTensor] add op support for aten.unbind.int (#162560)
As titled.

It seems unbind returns views of the original tensor. E.g. see https://stackoverflow.com/questions/78910951/does-unbind-return-the-views-of-tensors-in-pytorch

So we error out when `shard_dim == unbind_dim`. This is similar to why we error out in view ops.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L544-L546
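
A quick check that unbind returns views (illustrative only):

```python
import torch

t = torch.tensor([[1, 2, 3], [4, 5, 6]])
row0, row1 = t.unbind(0)

t[0, 0] = 100
assert row0[0].item() == 100  # the unbound row sees the in-place update: it's a view
```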

This PR also refactors some other tensor ops code by creating two util functions, `shift_shard_dims_after_insert` and `shift_shard_dims_after_remove`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162560
Approved by: https://github.com/zpcore
2025-09-11 00:58:23 +00:00
e60ad4f628 [DTensor] fix copy_ strategy to support linearity (#162460)
Fixing issue introduced in https://github.com/pytorch/pytorch/pull/158538
where `aten.copy_.default` is registered as a pointwise op, but without linearity.

In particular, when both `src` and `dst` tensors have same `Partial` placements, direct copy should happen without redistribute, instead of redistributing both to `Replicate` before making the copy.

This was discovered from silent incorrect results e.g. on `torch.einsum` backward.
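
A plain-tensor sketch (hypothetical two-rank Partial(sum) shards) of why copy_ is linear, i.e. copying the local partial shards directly gives the same reduced result as replicating first:

```python
import torch

src0, src1 = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])  # Partial(sum) shards of src
dst0, dst1 = torch.zeros(2), torch.zeros(2)                       # Partial(sum) shards of dst

dst0.copy_(src0)
dst1.copy_(src1)
assert torch.equal(dst0 + dst1, src0 + src1)  # reduced dst matches reduced src
```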

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162460
Approved by: https://github.com/zpcore
2025-09-10 00:47:14 +00:00
e381d4b020 [DTensor] forbid view ops to redistribute when local split is impossible (#161950)
This PR is a followup to https://github.com/pytorch/pytorch/pull/149764.

In that PR, it only forbids illegal view due to `Flatten`; this PR also forbids illegal view caused by `Split`.

This PR also updates the error message to be less about internal implementation details, which users may find confusing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161950
Approved by: https://github.com/ezyang
2025-09-03 04:40:11 +00:00
6c05ea6475 [DTensor] add op support: aten.squeeze_.dim (#159532)
**Summary**
This PR enables the in-place op `aten.squeeze_.dim` on DTensor with a change to
DTensor dispatch logic: when processing an in-place operator, we should assign
`output_sharding.output_spec` back to the first argument. This is because
the in-place op_call on `arg._local_tensor` could also shift the tensor meta.

**Test**
`pytest test/distributed/tensor/test_view_ops.py -s -k  test_squeeze_`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159532
Approved by: https://github.com/zpcore
2025-08-14 18:01:19 +00:00
b4596895b9 [DTensor] Registers sharding rule for rms_norm (#159692)
Reduces collective calls in the forward pass from 2 to 1

In #158716 I added the sharding rule for the backward pass but didn't add the forward pass as it didn't get dispatched. After #159324 this should get properly dispatched hence I am adding it now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159692
Approved by: https://github.com/tianyu-l
2025-08-12 21:05:24 +00:00
8a37f0c903 improve gather and scatter_add strategy (#160140)
As title.

This PR made a small fix on top of https://github.com/meta-pytorch/autoparallel/pull/81.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160140
Approved by: https://github.com/fmassa
2025-08-08 15:06:24 +00:00
e248719ac0 [DTensor] support _StridedShard in view op (#159656)
**Summary**
Some thoughts on view-op and `_StridedShard` interaction:
1. `_StridedShard` has no impact on sharding (i.e. how tensor is partitioned)
compared to `Shard`. It only changes how shards permute across the devices.
2. `view()` op on DTensor strictly forbids shard redistribution which means if
`view()` may cause shard permutation across devices, it should be rejected.
This is enforced in today's sharding prop for `view()`.
3. Since DTensor `view()` won't introduce any redistribution, it's certain that
`placements` won't change except the inner `dim` attribute of `Shard`
or `_StridedShard`.

Therefore, to support `_StridedShard` in `view()` op, the only change required
is to keep `_StridedShard` as `_StridedShard` in the output spec.

**Test**
`pytest test/distributed/tensor/test_view_ops.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159656
Approved by: https://github.com/wconstab
2025-08-07 15:59:25 +00:00
9f753f8c0d [DTensor] Improve sort strategy (#159189)
- Sort strategy now supports sharding on non sorted dim.
~~- Fix histc xfail.~~
  - ~~Previously `python test/distributed/tensor/test_dtensor_ops.py TestDTensorOpsCPU.test_dtensor_op_db_histc_cpu_float32` will fail with `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=18`. However, if we run `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=18 python test/distributed/tensor/test_dtensor_ops.py TestDTensorOpsCPU.test_dtensor_op_db_histc_cpu_float32`, the test will pass. This kind of error is due to DTensor reuses the strategy schema hashing. It turns out that not only the strategy,  the result correctness also depends on `static_argnum` or the op will reuse the previous args from hashed schema and output wrong results. I updated the document also.~~ (fixed in https://github.com/pytorch/pytorch/pull/159289)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159189
Approved by: https://github.com/XilunWu
2025-07-31 21:52:42 +00:00
5cf77a0ea2 Fix redistribution costs for slice_scatter (#159223)
We were previously assuming that `input_strategy == src_strategy`, which is not true in all cases.

This PR should fix that.

On the side, I also realized that for `slice_scatter` some DTensorSpecs don't have TensorMeta, e.g., https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_tensor_ops.py#L524

It would be good to fix it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159223
Approved by: https://github.com/ezyang, https://github.com/wconstab
2025-07-29 12:00:39 +00:00
7dafab6a93 Fix SDPA sharding when return_debug_mask is False (#159205)
If `return_debug_mask` is False (which is the default value for SDPA), the attention tensor returned is an empty tensor (which has 0 dimensions). This means that the shardings passed for the batch and CP cases can yield invalid dimensions.

This PR fixes it for `scaled_dot_product_flash_attention_strategy`. Note that `scaled_dot_product_cudnn_attention_strategy` doesn't have this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159205
Approved by: https://github.com/wconstab
2025-07-26 17:41:42 +00:00
7f266020de add softmax_backward_strategy missing field (#159167)
Add input_specs to softmax_backward_strategy, as needed by AutoParallel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159167
Approved by: https://github.com/XilunWu
2025-07-26 00:53:53 +00:00
5be7e187ba Support sort and scatter_add strategy (#159022)
Add `sort` and `scatter_add` strategies. I am reusing the strategy for `scatter`-related ops for quick support. The strategy can be potentially improved after we fix the index-related strategies.

Minor fix: fix `replicate_op_strategy` to support outputting multiple tensors, which is required by aten.sort.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159022
Approved by: https://github.com/XilunWu, https://github.com/wconstab
2025-07-24 18:33:18 +00:00
d8425e9c75 [1/N] support of replication fallback strategy (#158046)
#### 1. Provide a default fallback strategy that can apply to an arbitrary operator whose output is a single tensor.

We can call register_op_strategy to register an op using the `fallback_op_strategy`:
- For ops without List[Tensor] inputs, call:
```
register_op_strategy(op_overload)(replicate_op_strategy)
```
- For ops that contain List[Tensor] inputs, call:
```
register_op_strategy(op_overload, schema_info=RuntimeSchemaInfo(needs_pytree=True))(replicate_op_strategy)
```
The strategy will force all inputs and outputs to be replicated, with the corresponding redistribute_cost.

#### 2. Add a test function that checks a necessary condition for a strategy function.
```
detect_exists_identical_opspec(*args, op, mesh, strategy_function)
```
This function detects if identical strategies will be produced given the sample `args`. It will iterate all combinations of placements for each arg and produce the output strategy from the registered `strategy_function`.

#### 3. Provide a context manager `op_strategy_context` to easily register/unregister strategies for testing.
E.g.,
```
with op_strategy_context(test_op.default, replicate_op_strategy):
    ...
```
#### 4. Fix a bug where TupleStrategy never gets flattened as expected:
9df0176408/torch/distributed/tensor/_op_schema.py (L286)
Basically we need to 1) call register_pytree_node for TupleStrategy, and 2) propagate the schema_info to `strategy_schema` after `strategy_schema = _wrap_with_op_strategy(op_schema)`.

This is the first implementation. We plan to add support for sharding on the batch dim as the output strategy next.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158046
Approved by: https://github.com/wanchaol, https://github.com/wconstab
2025-07-23 21:14:20 +00:00
04a393507b Fused RMSNorm implementation (#153666)
Relevant #72643

Benchmarked versus the unfused torch implementation and the torch.compile implementation. Around 9x speedup vs the unfused implementation on CUDA, and slightly faster than the inductor-compiled version on a 5090.

```py
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype))
        x_normed = x / (rms_x + self.eps)
        return self.scale * x_normed

def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16):
    rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype)
    input_data = torch.randn(input_shape, device='cuda', dtype=dtype)

    for _ in range(warmup_iterations):
        _ = rms_norm_layer(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = rms_norm_layer(input_data)

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations

    print(f"--- RMSNorm CUDA Benchmark ---")
    print(f"Input Shape: {input_shape}")
    print(f"Normalized Dimension: {normalized_dim}")
    print(f"Benchmark Iterations: {num_iterations}")
    print(f"--- Fused Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda()
    for _ in range(warmup_iterations):
        _ = compiled_rms_norm(input_data)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iterations):
        _ = compiled_rms_norm(input_data)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iterations

    print(f"--- TorchCompile Implementation ---")
    print(f"Average Time per Iteration: {avg_time_ms:.4f} ms")
    print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms")

    print("-" * 50)

if __name__ == '__main__':
    parameter_sets = [
        {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16},
        {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16},
        {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32},
        {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16},
    ]

    num_benchmark_iterations = 200
    num_warmup_iterations = 20

    for params in parameter_sets:
        batch_size = params['batch_size']
        sequence_length = params['sequence_length']
        hidden_features = params['hidden_features']
        data_type = params.get('dtype', torch.float16)

        shape = (batch_size, sequence_length, hidden_features)
        norm_dim_to_normalize = hidden_features

        print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}")
        benchmark_rmsnorm_cuda(input_shape=shape,
                               normalized_dim=norm_dim_to_normalize,
                               num_iterations=num_benchmark_iterations,
                               warmup_iterations=num_warmup_iterations,
                               dtype=data_type)
```

Here are the triton compile tests ran on a 5090 (comparing this branch vs main)
```py
import torch
import torch.nn as nn
from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code

torch.manual_seed(0)

device = torch.device("cuda")

for batch in range(0, 9):
    for i in range(9, 16):
        normalized_shape_arg = (2**batch, 2**i)
        input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True)
        weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True)

        model = torch.nn.functional.rms_norm
        compiled_model = torch.compile(model)
        loss = torch.randn_like(input_tensor)

        num_iter = 5
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        num_iter = 10
        for j in range(num_iter):
            output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor)
            output.backward(loss)

        end_event.record()
        torch.cuda.synchronize()

        elapsed_time_ms = start_event.elapsed_time(end_event)
        avg_time_ms = round(elapsed_time_ms / num_iter, 5)
        print(2**batch, 2**i, avg_time_ms)
```
main
```
32 512 0.1812
32 1024 0.19021
32 2048 0.18871
32 4096 0.17019
32 8192 0.21944
32 16384 0.38871
32 32768 0.83282
64 512 0.14705
64 1024 0.13987
64 2048 0.14111
64 4096 0.21699
64 8192 0.43141
64 16384 0.90652
64 32768 2.18573
128 512 0.19361
128 1024 0.1963
128 2048 0.20122
128 4096 0.38888
128 8192 0.93795
128 16384 2.23437
128 32768 5.50079
256 512 0.16722
256 1024 0.22856
256 2048 0.39421
256 4096 0.96621
256 8192 2.48746
256 16384 5.53571
256 32768 11.97932
```
current branch
```
32 512 0.16328
32 1024 0.18104
32 2048 0.15508
32 4096 0.14356
32 8192 0.20111
32 16384 0.45974
32 32768 0.94799
64 512 0.16874
64 1024 0.18701
64 2048 0.16107
64 4096 0.20152
64 8192 0.46568
64 16384 0.96599
64 32768 2.21661
128 512 0.14982
128 1024 0.15565
128 2048 0.22241
128 4096 0.46128
128 8192 0.88883
128 16384 2.3097
128 32768 5.84448
256 512 0.14346
256 1024 0.2007
256 2048 0.45927
256 4096 0.87876
256 8192 2.10571
256 16384 5.73948
256 32768 12.98581
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666
Approved by: https://github.com/ngimel, https://github.com/albanD
2025-07-22 22:25:44 +00:00
187c2deb40 Fix clamp(min/max) strategy (#158619)
Part of plan https://github.com/pytorch/pytorch/issues/157495.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158619
Approved by: https://github.com/wanchaol
2025-07-21 23:26:08 +00:00
a3aacd6cb2 [DTensor] fix copy_ strategy (#158538)
The previous strategy directly used 'self' input strategy for 'src'
input.  The fixed strategy correctly maps the self dim to src dim
so that it works even if the src input is broadcast.

E.g. for this program, broadcasting will occur on dims 0,1,3 of self.

```
self = torch.ones((2,3,4,5))
src = torch.ones((4,1))
self.copy_(src)
```

These are the correct sharding combinations:

|   self   |     src |
|-------|------|
| Shard(0)  |   Replicate() |
| Shard(1)  |   Replicate() |
| Shard(2)  |   Shard(0) |
| Shard(3)  |   Shard(1) |
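
A small sketch of the dim alignment behind this table (broadcasting right-aligns shapes):

```python
import torch

self_t = torch.ones((2, 3, 4, 5))
src = torch.ones((4, 1))

# src dim 0 lines up with self dim 2, src dim 1 with self dim 3;
# self dims 0 and 1 (and the size-1 src dim) are broadcast.
assert torch.broadcast_shapes(self_t.shape, src.shape) == self_t.shape
self_t.copy_(src)
```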

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158538
Approved by: https://github.com/zpcore, https://github.com/XilunWu, https://github.com/wanchaol
ghstack dependencies: #158490
2025-07-18 23:44:43 +00:00
36bddcd18c [DTensor] Fix default_strategy and rename for clarity (#158490)
Fixes several bugs in the original.
- foremost, fixes a serious bug where we returned incorrect strategies
  by mixing input_specs that were frozen from
  select_strategy.strategies[0] with output_specs that varied across
  select_strategy.strategies[0..N] (e.g. we could create a nonsense
  strategy like input: Shard(0), output: Replicate() for an op like clone)
- fixes the redistribute costs: they should not actually be 0, they
  should be the cost of redistributing our single input from another
  strategy to the current strategy, in our list of output strategies
- adds a note, wondering if we should have just literally returned the
  input strategy instead of creating this new object
- Currently, using default_strategy is incorrect because it maps 'self'
  tensor's strategies directly onto 'src' tensor without accounting for
  the fact that copy_ supports broadcasting a smaller rank tensor into a
  larger one.

Separates out the copy_ op from the default strategy, adds a missing test case,
but does not fix the underlying issue with copy_, leaves that for future
PR

Renames to `propagate_single_input_strategy` since that's more
descriptive

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158490
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
2025-07-18 23:44:42 +00:00