499 Commits

Author SHA1 Message Date
0b0eea2229 [dtensor] move pad/unpad_tensor to separate utils (#124871)
As titled: 1. pad/unpad is a general utility, not specific to the Shard placement; 2. for the purpose of the next PR, move these two out of the Shard placement itself and give them an additional pad_dim argument.
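
For illustration, a minimal sketch of what a dim-aware pad/unpad pair could look like (helper names and signatures here are illustrative, not necessarily the ones landed in this PR):
```py
import torch
import torch.nn.functional as F

def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    # F.pad takes the pad spec from the last dim backwards; build it so only
    # the trailing side of pad_dim gets padded.
    pad = [0, 0] * (tensor.ndim - pad_dim)
    pad[-1] = pad_size
    return F.pad(tensor, pad)

def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    # Drop the padded slice back off along pad_dim.
    return tensor.narrow(pad_dim, 0, tensor.size(pad_dim) - pad_size)
```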

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124871
Approved by: https://github.com/awgu, https://github.com/wz337
2024-04-25 03:36:16 +00:00
7809b34288 [DTensor][Easy] Update OpSchema __repr__ to show args_schema in format print (#124812)
When printing op_schema with `print(f"{op_schema=}")`:

Before -- can't view into the OpStrategy/TupleStrategy in format print:
```
# A pointwise strategy
op_schema=OpSchema(op=aten.relu.default, args_schema=(<torch.distributed._tensor.op_schema.OpStrategy object at 0x7f4e763e0520>,), kwargs_schema={})
# A pointwise strategy
pointwise_strategy -- op_schema=OpSchema(op=aten.threshold_backward.default, args_schema=(<torch.distributed._tensor.op_schema.OpStrategy object at 0x7f4e763e1540>, <torch.distributed._tensor.op_schema.OpStrategy object at 0x7f4e763e1510>, 0), kwargs_schema={})
# A tuple strategy
op_schema=OpSchema(op=aten._foreach_lerp_.Scalar, args_schema=(<torch.distributed._tensor.op_schema.TupleStrategy object at 0x7f4e763e31f0>, <torch.distributed._tensor.op_schema.TupleStrategy object at 0x7f4e763e3460>, 0.09999999999999998), kwargs_schema={})
```

After -- printing out the OpStrategy/TupleStrategy string:
```
# A pointwise strategy
op_schema=OpSchema(op=aten.relu.default, args_schema=(OpStrategy:[None -> R] @ mesh: (4,)), kwargs_schema={})
# A pointwise strategy
op_schema=OpSchema(op=aten.threshold_backward.default, args_schema=(OpStrategy:[None -> R] @ mesh: (4,), OpStrategy:[None -> R] @ mesh: (4,), 0), kwargs_schema={})
# A tuple strategy
op_schema=OpSchema(op=aten._foreach_lerp_.Scalar, args_schema=(TupleStrategy(OpStrategy:[None -> S(0)] @ mesh: (4,)), TupleStrategy(OpStrategy:[None -> S(0)] @ mesh: (4,)),0.09999999999999998), kwargs_schema={})
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124812
Approved by: https://github.com/wanchaol
2024-04-24 21:34:39 +00:00
29cc293725 [BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods (e.g. update, difference_update) instead of manually reimplementing them.
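
For illustration, the kind of rewrite this rule suggests (example code, not taken from the diff):
```py
seen = {1, 2}
items = [2, 3, 4]
stale = [1]

# Before: manually reimplementing a set mutation with a loop
for item in items:
    seen.add(item)

# After: use the built-in set mutation method
seen.update(items)

# Likewise, prefer difference_update over looping with discard
seen.difference_update(stale)
print(seen)  # {2, 3, 4}
```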

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
2024-04-21 14:12:33 +00:00
ddd0ed1b43 distributed: templated ring attention (#124215)
This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. This doesn't add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.

This templating is also a POC of how to support other attention ops, such as jagged/nested tensors, as well as how to implement striped attention in a scalable way.

Misc changes:

* Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
2024-04-19 00:57:08 +00:00
b3f88317ec [dtensor][5/N] have table-wise sharding use LocalShardsWrapper on participating ranks only (#122853)
**Summary**
We wrap DTensor's local tensor in `LocalShardsWrapper` for torchrec's table-wise sharding. The exception is non-participating ranks: on those ranks, the local tensor is an empty torch.Tensor object. The reason for this design is to avoid the complexity of supporting the empty-tensor case in `LocalShardsWrapper`.

**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e table-wise`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122853
Approved by: https://github.com/wz337
ghstack dependencies: #120265, #121392, #122843
2024-04-16 22:27:30 +00:00
d419fcd19f [dtensor][4/N] have row-wise sharding always use LocalShardsWrapper (#122843)
**Summary**
Always wrap the local tensor in a `LocalShardsWrapper`. This is for uniformity and makes it easier to adopt DTensor as a wrapper for the local shard(s) representation. To support more tensor ops on `LocalShardsWrapper`, users need to extend its `__torch_dispatch__`.

**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e row-wise-even`

**Result**
```
Row-wise even sharding example in DTensor
         Col 0-15
-------  ----------
Row 0-1  cuda:0
Row 2-3  cuda:1
Row 4-5  cuda:2
Row 6-7  cuda:3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122843
Approved by: https://github.com/wz337
ghstack dependencies: #120265, #121392
2024-04-16 22:27:30 +00:00
1d7ac7baa0 [dtensor][3/N] add torchrec row-wise uneven sharding example (#121392)
**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e row-wise-uneven`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121392
Approved by: https://github.com/wanchaol
ghstack dependencies: #120265
2024-04-16 22:27:28 +00:00
9d3543df9a [dtensor][2/N] add torchrec table-wise sharding example (#120265)
**Summary**
This PR serves as a start of this effort by adding an example test that represents TorchRec's `ShardingType.TABLE_WISE` using DTensor.

**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e table-wise`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120265
Approved by: https://github.com/wanchaol
2024-04-16 22:27:24 +00:00
9d88339b53 Revert "make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347)"
This reverts commit 63dcb5b0f2ef3578e81841fd8a2166e732c0ca99.

Reverted https://github.com/pytorch/pytorch/pull/123347 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/123347#issuecomment-2059994989))
2024-04-16 22:08:24 +00:00
63dcb5b0f2 make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347)
Fixes https://github.com/pytorch/pytorch/issues/122459, https://github.com/pytorch/torchtrain/issues/61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above, with logs showing that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node example values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the __new__ at all. Hence the `torch._dynamo.disable`.
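
For reference, a minimal sketch (not DTensor's actual code) of where the described decorator goes on a generic wrapper subclass:
```py
import torch

class MySubclass(torch.Tensor):
    @staticmethod
    @torch._dynamo.disable  # keep dynamo from compiling the partially-initialized subclass
    def __new__(cls, elem: torch.Tensor):
        # Standard wrapper-subclass boilerplate; DTensor's real __new__ also
        # stashes its local tensor and spec on the instance.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )
```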

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`.

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
2024-04-15 17:23:20 +00:00
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.
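
An example of the C419 rewrite (illustrative, not lifted from the diff):
```py
values = [3, 1, 4, 1, 5]

# Before: builds an intermediate list just to reduce over it
total = sum([v * v for v in values])

# After: the generator expression is consumed directly, no temporary list
total = sum(v * v for v in values)
```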

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
68cffd19f6 DTensor: add ring attention for _scaled_dot_product_flash_attention (#122460)
Ring attention support for _scaled_dot_product_flash_attention with DTensor.

This assumes the query and key/value are sharded along the sequence length dimension. See the tests for example usage with PT Transformer as well as direct usage with _scaled_dot_product_flash_attention.

## Notable caveats
* Numerical accuracy: The backward pass doesn't match the non-chunked version numerically, but the forward pass does. I assume this is due to accumulated errors. I've added a chunked version that uses autograd to verify that the distributed version matches the chunked version.
* nn.Linear has incorrect behavior when running on a sharded tensor of size (bs, heads, seq_len, dim) with `Shard(2)`: it does an unnecessary accumulate, which requires using `Replicate()` on QKV as a workaround when using `nn.MultiheadAttention`.
* If enabled, it forces sequence parallelism and doesn't interoperate with tensor parallelism.

## SDPA usage

```py
with attention_context_parallel(), sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    dquery = distribute_tensor(query, device_mesh, [Shard(2)])
    dkey = distribute_tensor(key, device_mesh, [Shard(2)])
    dvalue = distribute_tensor(value, device_mesh, [Shard(2)])

    dout: DTensor = torch.nn.functional.scaled_dot_product_attention(
        dquery, dkey, dvalue, is_causal=is_causal
    )
    out = dout.to_local()
```

## Transformer usage

```py
with attention_context_parallel(), sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=dim,
        nhead=nheads,
        dim_feedforward=dim,
        batch_first=True,
    ).to(dtype)
    encoder_layer = parallelize_module(
        module=encoder_layer,
        device_mesh=device_mesh,
        parallelize_plan={
            "self_attn": ContextParallel(),
        },
    )
    model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
```

## Test plan

```
pytest test/distributed/_tensor/test_attention.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122460
Approved by: https://github.com/drisspg, https://github.com/wanchaol
2024-04-03 06:45:00 +00:00
102c676418 [DTensor] Added some more foreach ops (#123214)
These ops should already work with the existing strategy. We need these for precomputing fp32 -> fp8 casts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123214
Approved by: https://github.com/wz337
ghstack dependencies: #123142
2024-04-03 02:07:45 +00:00
d7a274e1b0 [dtensor] switch aten.t to use op strategy (#122950)
as titled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122950
Approved by: https://github.com/awgu, https://github.com/tianyu-l
ghstack dependencies: #122929, #122949
2024-04-01 17:39:43 +00:00
9e1447dad6 [dtensor] make sure expected input spec have correct tensor meta (#122949)
As titled. Previously we could return an expected input spec that is shared by multiple args. This is not OK, since different args might have different tensor metas; it worked before only because redistribute in these cases becomes a no-op.

This PR fixes it by making each expected input spec shallow-clone the corresponding input metadata.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122949
Approved by: https://github.com/tianyu-l
ghstack dependencies: #122929
2024-04-01 17:39:42 +00:00
afee5bea92 [dtensor] refactor schema suggestions in output sharding (#122929)
This PR refactors schema_suggestions in OutputSharding to be a single OpSchema instead of a list of schemas. In practice we only ever have one, and the multiple-resharding case has also moved to OpStrategy, so there's no case that needs it to be a list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122929
Approved by: https://github.com/tianyu-l
2024-04-01 17:39:39 +00:00
47e8d60627 [dtensor] add op support for view_as_complex and view_as_real (#122569)
This PR will unblock DTensor computations for [rotary embeddings](https://github.com/meta-llama/llama/blob/main/llama/model.py#L132) used in LLaMa training.
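
For context, a simplified sketch of the rotary-embedding pattern these ops unblock (broadcast reshaping of `freqs_cis` omitted; with this PR the same calls should also work when the inputs are DTensors):
```py
import torch

def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Pair up the last dim into complex numbers, rotate by freqs_cis, unpack back.
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_c * freqs_cis).flatten(-2)
    return xq_out.type_as(xq)
```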

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122569
Approved by: https://github.com/wanchaol
ghstack dependencies: #122541
2024-03-26 03:32:04 +00:00
4e0b5d59fa [dtensor] add backward support for scaled dot product attention (flash-attention) (#122541)
As titled, as a followup to the forward part #120298.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122541
Approved by: https://github.com/wanchaol
2024-03-26 01:50:24 +00:00
e7fa3f7812 AOTDispatch: allow subclasses to correct when we guess metadata of tangents incorrectly (#118670)
This PR is enough to fix https://github.com/pytorch/pytorch/issues/118600.

More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like:

"We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents"

Here, the problem is similar:

"We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass".

This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial).

One option is to solve this problem via backward guards. This is needed to unblock an internal use case though, so I figured handling it similarly to how we handle non-contiguous tangents would be reasonable. I did this by:

(1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error

(2) In the error message, provide the name of an optional method that the subclass must implement to handle this case:

`def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement.

`__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement.

`__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time.
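
To make the proposal concrete, a rough sketch of what a subclass might implement (hook names as proposed above; the bodies are placeholders, not DTensor's real conversion logic):
```py
import torch

class MySubclass(torch.Tensor):
    # Runtime hook: return a version of `self` whose subclass metadata matches
    # `metadata_tensor`'s (the tangent we guessed at trace time).
    def __force_same_metadata__(self, metadata_tensor):
        # Placeholder: a real subclass would convert e.g. Shard -> Replicate here.
        return self

    # Trace-time hook: canonicalize forward-output metadata so the generated
    # tangents never carry metadata the subclass cannot convert to (e.g. _Partial).
    def __force_standard_metadata__(self):
        # Placeholder: a real subclass would rewrite its metadata here.
        return self
```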

I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118670
Approved by: https://github.com/ezyang
2024-03-22 23:16:08 +00:00
11e64b4ba8 [dtensor] aten.cat to use stack strategy approach (#122209)
This PR switches aten.cat to use the strategy approach similar to
aten.stack, as these two ops share similar semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122209
Approved by: https://github.com/wz337
2024-03-20 04:19:25 +00:00
3bd38928ba [export] Improve consistency for nn_module_stack metadata, add checks to _trace.py (#120661)
We would like to improve consistency for nn_module_stack metadata in torch.export.

This PR ensures that all tests in test/export/test_export.py has the following constraints:
- Remove nn_module_stack for all placeholder & output nodes, for all modules and submodules
- Ensure nn_module_stack is present for all other node types for the top-level module (there is still an issue with torch.cond submodules having empty fields)
- Add these checks to _export() in _trace.py (we would add this in the Verifier, but downstream apps construct ExportedPrograms separate from _export(), and metadata may not be maintained there)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120661
Approved by: https://github.com/avikchaudhuri
2024-03-16 21:44:52 +00:00
256c0ec1e5 [docs] Added comment on replicate -> partial for _NormPartial (#121976)
Add a version of https://github.com/pytorch/pytorch/pull/121945#discussion_r1525697167 as a comment in the code

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121976
Approved by: https://github.com/wanchaol
ghstack dependencies: #121747, #121869, #121945
2024-03-15 23:04:06 +00:00
b92daff6e9 [DTensor] Enable ASGD foreach optimizer and add the associated unit test (#121942)
Enable ASGD foreach optimizer and add DTensor optimizer unit test for ASGD.

Note that we need to investigate why, when using ASGD, we need higher atol and rtol when comparing model parameters. Listing it as a TODO for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121942
Approved by: https://github.com/wanchaol
2024-03-15 20:21:27 +00:00
f4dd2fda51 [DTensor] Supported 2D clip_grad_norm_ (#121945)
This PR adds support for 2D `clip_grad_norm_` (`foreach=True`).
- This PR changes `OpSchema.args_spec` to use pytree if the runtime schema info specifies it.
- This PR includes a unit test for 2D FSDP2 + SP with `clip_grad_norm_` enabled, which serves as a complete numerics test for 2D.

Note: With this PR patched, 2-way SP + 4-way FSDP matches 8-way FSDP numerics on Llama-7B (doubling local batch size for the 2-way SP run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121945
Approved by: https://github.com/wanchaol
ghstack dependencies: #121747, #121869
2024-03-15 20:11:24 +00:00
710446b1eb [dtensor] refactor and generalize stack strategy (#121869)
This PR rewrites the stack strategy to be more generalized. Basically, the follow pattern for stack/cat-like strategies needs to be smarter, i.e. it should be able to identify:
1. PR, PP, RP -> follow PP
2. RR, SR, RS -> follow SS

So this PR refactors how the follow strategy should work and makes sure we start following the strategy that incurs the lowest cost, i.e. for multiple PR, RP placements, we should be able to further delay the pending sum reductions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121869
Approved by: https://github.com/awgu
2024-03-15 00:34:25 +00:00
a26480a4d1 [dtensor] move early return check into redistribute autograd function (#121653)
This PR fixes a redistribute bug by moving the early-return check into the redistribute autograd function: even when we redistribute to the same placement, the grad_placements from the `to_local` call might be different, so the redistribute backward still needs to happen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121653
Approved by: https://github.com/awgu
2024-03-12 17:37:30 +00:00
605c0a28aa [dtensor][debug] force visualize_sharding not to print for empty tensors (#121217)
**Summary**
Current `visualize_sharding` code cannot print for empty DTensor objects, which leads to an exception. This PR skips the print logic if the DTensor passed in has 0 elements.
(See the exception screenshot attached to the PR.)
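
For context, a typical call looks roughly like this (import path as in this prototype tree; it may move as DTensor graduates):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding

mesh = init_device_mesh("cuda", (4,))  # run under torchrun with 4 ranks
dt = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])
visualize_sharding(dt)  # prints a table of which rank holds which rows
```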

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121217
Approved by: https://github.com/wanchaol
ghstack dependencies: #121385, #121382
2024-03-11 09:22:49 +00:00
3a5ab17bdc [dtensor][debug] visualize_sharding skip if the current rank is not in mesh (#121382)
**Summary**
We should skip the `visualize_sharding()` function on ranks that are not part of the DTensor's mesh. Otherwise, the current visualization logic throws an exception.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121382
Approved by: https://github.com/wanchaol
ghstack dependencies: #121385
2024-03-11 09:22:49 +00:00
b383123e37 [dtensor][debug] visualize_sharding only compute offset on the first rank in mesh (#121385)
**Summary**
Avoid computing sharding offsets on ranks where we do not plan to visualize the DTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121385
Approved by: https://github.com/wanchaol
2024-03-11 09:22:31 +00:00
242e03ba86 [dtensor] add async_op option to redistribute and some refactor (#121477)
The async output option was previously only available via the `full_tensor()` call, but it's generally useful to expose it in the `redistribute` call directly so that the user can control it.

This PR adds an async_op option to the redistribute call, allowing the user to control whether the tensor redistribution is performed asynchronously.

By default we set this to False, to follow the semantics of the c10d collectives.
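
A small usage sketch of the new flag (sizes are illustrative):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh("cuda", (4,))  # run under torchrun with 4 ranks
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# async_op=True kicks off the resharding collectives without blocking;
# the result wraps an async tensor that waits when first used.
replicated = dt.redistribute(mesh, [Replicate()], async_op=True)

# Default (async_op=False) keeps the synchronous c10d-style semantics.
replicated_sync = dt.redistribute(mesh, [Replicate()])
```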

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121477
Approved by: https://github.com/wz337
2024-03-09 06:17:23 +00:00
6791b0c09e Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
2024-03-09 01:08:37 +00:00
bc02fca358 [dtensor] to_local backward grad placement passthrough (#121474)
to_local accepts `grad_placements` if the user chooses to pass it; previously we enforced the grad_out to have the "same" placement as the current DTensor for safety.

But I realized that we DO NOT need to enforce this constraint. Why? The backward placement does not need to be the same as the forward tensor placement; this is already the case for param vs. param.grad (i.e. the param can be replicate and the grad can be partial), so we should not impose this restriction on activation vs. activation grad either.
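
A small sketch of the pattern this relaxes (shapes illustrative; the point is only that `grad_placements` may now differ from the DTensor's own placements):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh("cuda", (2,))  # run under torchrun with 2 ranks
dt = distribute_tensor(torch.randn(4, 4, requires_grad=True), mesh, [Replicate()])

# Tell autograd that the gradient flowing back into the local tensor is
# sharded on dim 0, even though the DTensor itself is replicated.
local = dt.to_local(grad_placements=[Shard(0)])
```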

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121474
Approved by: https://github.com/awgu, https://github.com/yoyoyocmu, https://github.com/yifuwang
2024-03-08 23:11:49 +00:00
08460f4bae [tp] remove deprecated tp_mesh_dim arg (#121432)
This PR removes the deprecated tp_mesh_dim arg to prepare for the release.
Since this arg has been deprecated for a while (by emitting deprecation
warnings), we should remove it before the release.

#suppress-api-compatibility-check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121432
Approved by: https://github.com/wz337
ghstack dependencies: #121431
2024-03-08 17:46:44 +00:00
f7ec984b1b [DTensor][XLA] support XLA backend in distribute_module API (#121355)
Addresses #92909

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121355
Approved by: https://github.com/wanchaol
2024-03-08 15:47:33 +00:00
4f9d4e1ab0 [DTensor][XLA] refactor DTensor _xla API (#113214)
In response to the change pytorch/xla#5776 and #92909

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113214
Approved by: https://github.com/wanchaol
2024-03-07 06:18:05 +00:00
df2ad1fecc [dtensor][debug] have visualize_sharding correctly print for sub-mesh DTensor (#121216)
**Summary**
In `visualize_sharding` we chose to only print on rank 0 (the global rank), which means calling `visualize_sharding` will never print anything when the DTensor object's mesh doesn't include rank 0 (i.e. a sub-mesh). This PR has `visualize_sharding` always print on the rank whose mesh coordinate is (0, 0, ..., 0) instead of the rank whose global rank is 0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121216
Approved by: https://github.com/wanchaol
ghstack dependencies: #121179, #120260
2024-03-07 04:50:15 +00:00
77873f6fe5 [dtensor][1/N] add torchrec even row-wise sharding example (#120260)
**Summary**
Our goal is to demonstrate DTensor's capability to represent TorchRec's parameter sharding. Currently this is done with `ShardedTensor`, and theoretically `DTensor` can replace it with minor changes.

This PR serves as the start of this effort by adding an example test that represents TorchRec's `ShardingType.ROW_WISE` using DTensor. Note that this PR only covers the even sharding case.

**Test Run**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e row-wise`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120260
Approved by: https://github.com/wanchaol
ghstack dependencies: #121179
2024-03-07 04:50:15 +00:00
9cc0f23e5c [dtensor][debug] allow visualize_sharding to print header (#121179)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121179
Approved by: https://github.com/wanchaol
2024-03-07 04:50:06 +00:00
a88356f45c [dtensor] make add_.Tensor/div_.Scalar to be linear pointwise instead (#121294)
add_.Tensor and div_.Scalar should support linearity so that we can delay the partial
results.

This fixes the additional collective we have seen in the layernorm layer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121294
Approved by: https://github.com/tianyu-l
2024-03-06 22:52:18 +00:00
372f192050 [DTensor] Initialized RNG tracker if needed (#121328)
Since we are already checking if the RNG tracker is initialized, there is no real performance difference between erroring vs. just initializing a default RNG tracker (which we choose to be the `OffsetBasedRNGTracker`).

```
pytest test/distributed/_composable/fsdp/test_fully_shard_init.py -k test_meta
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121328
Approved by: https://github.com/wanchaol
ghstack dependencies: #120351
2024-03-06 22:21:44 +00:00
2e50566722 [dtensor] change distribute_module input/output_fn to accept module (#120895)
This is a BC-breaking change to distribute_module. The underlying rationale
for this change is that sometimes in the input_fn/output_fn, the user wants
access to the current module for some attributes. This might not be
common, but in some cases it's worth having access to the module.

An outstanding use case we want to support is float8: if we want to make
float8 work with the TP API, the input_fn/output_fn of the TP parallel
styles would need access to the module, where the module might
encapsulate a `dynamic_linear.emulate` attribute that is useful for
input/output casting.

Since this is needed for fp8 and DTensor is still under prototype release,
I feel it's worth the change, and it's better that we make the change
early.

Right now this is a soft BC break, which means we still maintain BC
but throw deprecation messages.
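
Roughly, the new callback shape looks like the sketch below, with the module passed as the first argument (the fp8 attribute check and the bfloat16 cast are stand-ins, not the real float8 integration):
```py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_module

mesh = init_device_mesh("cuda", (4,))  # run under torchrun with 4 ranks

def input_fn(module, inputs, device_mesh):
    # The module is now passed in, e.g. so fp8-related attributes can be read here.
    if getattr(module, "emulate_fp8", False):  # hypothetical attribute
        inputs = tuple(x.to(torch.bfloat16) for x in inputs)  # stand-in for fp8 casting
    return inputs

def output_fn(module, outputs, device_mesh):
    return outputs

dist_mod = distribute_module(nn.Linear(8, 8), mesh, input_fn=input_fn, output_fn=output_fn)
```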

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120895
Approved by: https://github.com/tianyu-l
2024-03-04 07:22:32 +00:00
af5376c444 [dtensor] add support for loss parallel (#119877)
Loss parallel is the last piece of sequence parallelism to enable. It enables efficient distributed cross-entropy computation when the input is sharded on the class dimension (in a classification problem with many classes). The implementation is via a context manager `loss_parallel`; after enabling it, users can directly use `torch.nn.functional.cross_entropy` or `torch.nn.CrossEntropyLoss` without modifying other parts of their code.
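
A minimal usage sketch (sizes illustrative; the logits DTensor must be sharded on the class dimension):
```py
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed.tensor.parallel import loss_parallel

mesh = init_device_mesh("cuda", (4,))  # run under torchrun with 4 ranks
logits = distribute_tensor(torch.randn(8, 1024, requires_grad=True), mesh, [Shard(1)])
labels = torch.randint(0, 1024, (8,), device="cuda")

with loss_parallel():
    loss = F.cross_entropy(logits, labels)  # sharded softmax + nll under the hood
    loss.backward()
```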

Here are the underlying rationales why we are going through these op replacements:

1. `nn.functional.cross_entropy` is the common method that OSS users use for things like transformer training. To avoid changing user code, we want users to keep using this function for loss calculation if they are already using it.
2. `nn.functional.cross_entropy` boils down to `aten.log_softmax` and `aten.nll_loss_forward/backward`, and DTensor now supports those ops already (#117723 #119255 #118917 #119256). They do the computation with the input *replicated* on the class dimension.
3. However, when the input of this loss calculation is **sharded on the class dimension**, to run the sharded computation efficiently we need to run both `aten.log_softmax` and `aten.nll_loss_forward` with multiple all-reduce collectives **in the middle of** those aten ops. This is not possible if we just override these two ops, so we need some way to **decompose** them into smaller ops and run the collectives in between.
4. We explored the existing decompositions (#118950). They seem to work, except that `log_softmax_backward` and `nll_loss_backward` combined together in aten are implemented in an inefficient way, which would trigger an additional expensive collective. Recently some users also reported similar issues https://github.com/pytorch/pytorch/issues/119261.
5. Therefore, currently we are doing our own decomposition inside a context manager for sequence parallelism specifically. Once we have a better decomposition in core, we can possibly take that instead of reinventing the wheel here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119877
Approved by: https://github.com/wanchaol
2024-03-02 05:06:26 +00:00
7c71d7f32b [DTensor] Supported foreach=True for clip_grad_norm_ (#120910)
This PR adds support for `clip_grad_norm_(foreach=True)` by implementing `aten._foreach_norm.Scalar` and `aten._foreach_mul_.Tensor`. `foreach=True` is required to get competitive performance with `DTensor`.

`foreach=True` reduces CPU overhead for Llama-7B from 388 ms to 63 ms. Existing flat-parameter FSDP's `clip_grad_norm_` takes 3 ms on CPU 😢 .
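
Usage stays the standard API; with DTensor parameters (e.g. from 2D FSDP + TP), the call now goes through the foreach path. A sketch, assuming `model`'s parameters and grads are DTensors:
```py
import torch

# Gradients on DTensor parameters are DTensors themselves; with this PR the
# fused foreach path works (previously these foreach ops were unsupported for DTensor).
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
```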

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120910
Approved by: https://github.com/wanchaol, https://github.com/janeyx99
ghstack dependencies: #120238
2024-03-02 00:28:09 +00:00
f0e8e7cf43 [DTensor] Supported foreach=False for clip_grad_norm_ (#120238)
This PR adds `DTensor` support for `aten.linalg_vector_norm.default` and `aten.stack.default` so that we can run `clip_grad_norm_` (with `foreach=False`).

To implement `linalg_vector_norm`, we introduce a `_NormPartial` placement since the reduction op for norm is the norm itself.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120238
Approved by: https://github.com/wanchaol
2024-03-02 00:25:16 +00:00
09aefe1502 Fix ouput typos (#120870)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120870
Approved by: https://github.com/clee2000
2024-02-29 08:29:14 +00:00
9e0631cc8a get CommsDebugMode to work with DTensor (#118769)
Tested with Wanchao's repro:
```
from typing import Tuple, List, Dict, cast
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, DTensor, Shard, Placement, Replicate

mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))
x = torch.randn(4, 8, requires_grad=True)
y = torch.randn(4, 32, requires_grad=True)
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
from torch.distributed._tensor.debug import CommDebugMode
comm_mode = CommDebugMode()
with comm_mode:
    z = torch.mm(x_dtensor, y_dtensor)
print(comm_mode.get_comm_counts())
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118769
Approved by: https://github.com/wanchaol
2024-02-29 01:11:05 +00:00
0c8bb6f70c [dtensor] standardize tuple strategy handling for foreach ops (#120695)
This PR refactors the tuple strategy handling logic and allows
TupleStrategy to have both input/output specs for each OpStrategy child,
so that we can further enable operators like foreach norm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120695
Approved by: https://github.com/awgu
2024-02-27 18:23:11 +00:00
cf6df886a0 Remove hard numpy dependency from experimental_ops.py (#119520)
Based on similar code in the codebase

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119520
Approved by: https://github.com/albanD
2024-02-27 02:46:13 +00:00
5a3e19578f Make tests using CommDebugMode work for both legacy and native funcol (#120070)
We have many tests that use CommDebugMode to verify the occurrence of collectives. These tests do so by querying comm_counts with legacy funcol ops as keys. For the purpose of the native funcol migration, we need these tests to work for both legacy and native funcol. To avoid modifying all tests to accommodate the two implementations, we make CommDebugMode translate native funcol ops into legacy funcol ops until the migration finishes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120070
Approved by: https://github.com/wconstab, https://github.com/wanchaol
ghstack dependencies: #120042, #120043
2024-02-22 20:24:15 +00:00
65627cfd6a [dtensor] implement scaled dot product attention (flash-attention) (#120298)
As titled, this PR implements the SDPA flash-attention op in DTensor.

Flash attention is added first, but efficient attention and other attention
ops should be similar.

fixes https://github.com/pytorch/pytorch/issues/120333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120298
Approved by: https://github.com/XilunWu
ghstack dependencies: #120297
2024-02-22 17:53:47 +00:00