reland of https://github.com/pytorch/pytorch/pull/133113
I have to create a new PR because the previously reverted PR could not be rebased or imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* a shim script that redirects old `torch.distributed._tensor` import paths to the new module, to preserve BC for users still on the old path
That BC is preserved is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
We found a corner case: when a tensor dimension is 1, calling `view(1)` results in an unexpected replication (see case 1 below). When the tensor dimension to shard is not 1, no matter whether the tensor dimension is evenly shardable across the mesh dimension, no implicit replication happens behind the scenes as long as `view` doesn't change the size of the given tensor dimension (see cases 2 and 3).
When the tensor dimension to shard is of size 1, it is not added to `shardable_dims` here:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/_view_ops.py#L518
```python
# uneven case where the size of the tensor dimension to shard is 1
p = torch.randn(1, 2)
mesh = init_device_mesh("cuda", (2,))
dtensor = distribute_tensor(p, mesh, [Shard(0)])
t = dtensor.view(1, 2)
# this results in replication, meaning t is now replicated across all ranks.

# uneven case where the size of the tensor dimension to shard is not 1
p = torch.randn(3, 2)
mesh = init_device_mesh("cuda", (2,))
dtensor = distribute_tensor(p, mesh, [Shard(0)])
t = dtensor.view(3, 2)
# this does not result in replication, meaning t stays sharded.

# even case
p = torch.randn(2, 2)
dtensor = distribute_tensor(p, mesh, [Shard(0)])
t = dtensor.view(2, 2)
# this does not result in replication, meaning t stays sharded.
```
Differential Revision: [D62155606](https://our.internmc.facebook.com/intern/diff/D62155606)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135054
Approved by: https://github.com/tianyu-l, https://github.com/wanchaol
Resolves the request [here](https://github.com/pytorch/pytorch/issues/120003#issuecomment-2248805798).
Enable DTensor inputs in the gradient scaler's APIs, especially `.unscale_()`.
A dispatch strategy is added to accept DTensor inputs.
To let `found_inf` be reduced across devices, we add an all-reduce over the relevant args at dispatch, after the dispatch strategy and kernel run.
Since `aten._amp_foreach_non_finite_check_and_unscale_.default` is an in-place op, `grad_scale` as `args[0]` will be mutated in place, so redesigning a strategy or refactoring the kernel would not help.
Tests cover the following under 1-D (dp) and 2-D (dp, tp) cases:
1. whether the non-inf values are unscaled
2. whether each DTensor on every device can detect inf values, even ones not on its own device
3. if no inf is found, whether new parameters are generated
4. if inf is found, whether the scale is updated
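Below is a minimal, hypothetical usage sketch of the enabled flow (assuming a 2-GPU `torchrun` launch and the public `torch.distributed.tensor` import path); it is illustrative, not the PR's test code:
```python
import torch
from torch.amp import GradScaler
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

torch.manual_seed(0)  # same global tensors on every rank
mesh = init_device_mesh("cuda", (2,))  # 1-D (dp) mesh
w = torch.nn.Parameter(distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)]))
opt = torch.optim.SGD([w], lr=0.1)
scaler = GradScaler("cuda")

x = distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)])
loss = (x @ w).pow(2).mean()   # loss is a DTensor
scaler.scale(loss).backward()  # gradients are scaled DTensors
scaler.unscale_(opt)           # the amp unscale op now dispatches on DTensor grads;
                               # found_inf is all-reduced across devices
scaler.step(opt)               # skipped if any rank saw inf/NaN
scaler.update()
```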
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132816
Approved by: https://github.com/XilunWu, https://github.com/weifengpy, https://github.com/wanchaol
**Summary**
reland of https://github.com/pytorch/pytorch/pull/134294
Fixes #131446, fixes #126852, fixes #126868, fixes #126493
The PR was reverted due to a CI red signal in https://github.com/pytorch/pytorch/actions/runs/10537099590/job/29201744658. It seems that the `gaussian_nll_loss` test had been flaky before my original PR #134294. Therefore this PR also removes the `xfail` mark on this specific test to keep the CI signal green.
See the error message below:
```
2024-08-24T13:42:01.3228990Z ==================================== RERUNS ====================================
2024-08-24T13:42:01.3229530Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _
2024-08-24T13:42:01.3229710Z Unexpected success
2024-08-24T13:42:01.3230235Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _
2024-08-24T13:42:01.3230407Z Unexpected success
2024-08-24T13:42:01.3230594Z =================================== FAILURES ===================================
2024-08-24T13:42:01.3231128Z _ TestDTensorOpsCPU.test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 _
2024-08-24T13:42:01.3231296Z Unexpected success
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134509
Approved by: https://github.com/tianyu-l, https://github.com/wz337
Fixes [#134212](https://github.com/pytorch/pytorch/issues/134212)
Currently, when we use 2D FSDP with TP, `optimizer.step()` would fail if the model were not fully tensor parallelized. If we don't have the entire model tensor parallelized when doing 2D, we would have both 1D and 2D DTensor parameters. As foreach is turned on by default, `optimizer.step()` would fail as cross mesh op is not allowed. Error as follows:
```
NotImplementedError: aten._foreach_mul_.Scalar: DTensor does not support cross-mesh operation yet!Got meshes: DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')) DeviceMesh('cuda', [1, 3], mesh_dim_names=('dp',))
```
In this PR, we extend implicit_replication to replicate DTensor in missing dimensions for foreach ops. If users don't want to fully tensor parallelize the model when using 2D, they have the option of using the `implicit_replication()` context manager for `optimizer.step()`. In this case, we would swap out the 1D DTensorSpec and replace it with 2D DTensorSpec. However, we don't want to turn this on by default yet, as we want the users to be aware that the tp dimension is replicated if a layer is not tp-ed.
With implicit replication turned on, replicating the DTensor spec in missing dimensions works for most foreach cases, except when the first DTensor in the list is itself one that needs to be replicated. This is currently a limitation for which I don't have a good solution yet. With this change, we can handle every case except the one where the first DTensor's ndim is not the largest:
```
[2D_DTensor, 1D_DTensor...] ---> Implicit_replication() can handle this.
[1D_DTensor, 2D_DTensor...] ---> Implicit_replication() can't handle this.
```
This change doesn't affect the existing default behavior, as `implicit_replication()` is not turned on by default.
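For illustration, a hedged end-to-end sketch of the opt-in usage (assuming a 4-GPU `torchrun` launch; the FSDP2/TP wiring and import paths here are illustrative assumptions, not this PR's test code):
```python
import torch
from torch.distributed._composable.fsdp import fully_shard  # assumed FSDP2 path
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),  # tp-parallelized: params become 2-D (dp, tp) DTensors
    torch.nn.Linear(8, 8),  # not tp-parallelized: params stay 1-D (dp) DTensors
).cuda()
parallelize_module(model, mesh["tp"], {"0": ColwiseParallel(output_layouts=Replicate())})
fully_shard(model, mesh=mesh["dp"])

optimizer = torch.optim.AdamW(model.parameters(), foreach=True)
model(torch.randn(4, 8, device="cuda")).sum().backward()
with implicit_replication():
    # 1-D params are treated as replicated on the missing tp dim,
    # so cross-mesh foreach ops in the step no longer error
    optimizer.step()
```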
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134551
Approved by: https://github.com/tianyu-l
**Summary**
This PR is a follow-up of #126924 to address reviewer's comments:
1) add a test case to show the use of `local_map` as a function decorator (see the sketch after this list).
2) simplify the logic of handling different data types of `out_placements`.
3) correct variable naming in test cases to match math formulas.
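For illustration, a hedged sketch of decorator-style usage (assuming `functools.partial` is used to pre-bind the placements; the function and placements here are illustrative):
```python
from functools import partial

import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map

@partial(
    local_map,
    out_placements=[Shard(0)],                  # the local result is a row shard
    in_placements=([Shard(0)], [Replicate()]),  # X row-sharded, W replicated
)
def sharded_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # the body runs on local shards; local_map handles (un)wrapping DTensors
    return torch.mm(x, w)
```
Calling `sharded_mm` on DTensor arguments then returns a DTensor with the declared output placements.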
**Test**
see #126924
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127752
Approved by: https://github.com/wanchaol
Fixes #133499
### The issue
Testing a variety of TP `requires_grad` patterns (validating maximally flexible finetuning) revealed that `DTensor` sharding propagation of `aten.native_layer_norm_backward` (default) fails with an `IndexError` for certain `requires_grad` patterns (pattern 1, e.g. `output_mask` `[True, False, False]`) and an `AssertionError` for others (pattern 2, e.g. `output_mask` `[False, True, *]`). Please see issue #133499 for a full description of the observed failure patterns along with reproduction.
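For concreteness, a hedged repro sketch of failure pattern 1 (assuming a 2-GPU `torchrun` launch; freezing the affine params makes `native_layer_norm_backward` see `output_mask` `[True, False, False]`):
```python
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))
# frozen affine params: only the input gradient is requested
weight = distribute_tensor(torch.ones(8), mesh, [Replicate()])
bias = distribute_tensor(torch.zeros(8), mesh, [Replicate()])
x = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0)]).requires_grad_()
F.layer_norm(x, (8,), weight, bias).sum().backward()  # previously raised IndexError
```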
### Use Cases and Remediation
Failure pattern 1 is potentially problematic for a variety of finetuning scenarios. Though failure pattern 2 is really an xfail right now since it's not fully supported, IMHO there are use cases (e.g. especially wrt to mechanistic interpretability research, but certain finetuning scenarios too potentially) that justify supporting this output mask (especially since supporting it is fairly straightforward I think).
In this PR I propose some modest changes that:
* Address the aforementioned failure modes.
* Add a couple of tests that I'm hopeful will help ensure `DTensor` op dispatch (which is so well implemented and such a pleasure to work with btw! 🚀🎉) accommodates a wide variety of (potentially unanticipated) `requires_grad` patterns as it evolves.
To address both failure modes, I'm proposing the following changes:
1. To [`torch.distributed._tensor.ops._math_ops.layer_norm_bwd_strategy`](7b269cc484/torch/distributed/_tensor/ops/_math_ops.py (L873)):
- Refactor conditional `output_mask` handling such that the input and output specs in the `PlacementStrategy`s of the returned `output_strategy.strategies` list remain aligned with the `op_schema.args_spec` (whose definition does not change at runtime based upon unused optional args).
2. To [`torch.distributed._tensor._sharding_prop.propagate_op_sharding_non_cached`](7b269cc484/torch/distributed/_tensor/_sharding_prop.py (L256-L262)):
- When iterating through the active `op_schema.args_spec` to build the relevant `expected_input_specs` list, filter any `None` `desired_specs`.
3. To [`torch/distributed/_tensor/_op_schema.OpSchema._inplace_rewrap_schema_suggestion`](7b269cc484/torch/distributed/_tensor/_op_schema.py (L418))
- When inputs need a redistribute, for runtime-unrequired arguments (`None` in the aligned `suggestion_args_schema`), ignore the associated `suggestion_args_spec`.
### Implementation considerations:
- Regarding `1`, to avoid changing the op strategy return args ([`op_strategy`](cf81180007/torch/distributed/_tensor/_sharding_prop.py (L234))), the change in `1` allows `None` elements to exist temporarily in `PlacementStrategy.input_specs` (treating it as `Sequence[DTensorSpec | None] | None` when it's declared `Sequence[DTensorSpec] | None`). This could be addressed in any number of ways but I thought it best to leave that for a subsequent PR since it could have broader ramifications (e.g. allowing op strategies to return an `output_strategy.input_specs` mask explicitly, explicitly allowing `None`s in `PlacementStrategy.input_specs`, creating a `Null` DTensorSpec, etc.). That's why I'm using an ignore arg-type directive there for now.
- Regarding `2` and `3` above, I don't introspect `op_schema.op._schema.arguments` to verify any `None` arguments are `torch.OptionalType`, leaving adherence to the schema contract the responsibility of the given op. Regarding `2`, I assume any `desired_spec` will be either a `DTensorSpec` or `None`, so only `None` can be Falsy in this context.
- I considered altering the active `args_schema`, which could be inspected and aligned with the active `output_strategy.input_specs` in some cases and avoid the changes in `3`, but I think that would rely on one of (among other possibilities):
  - all supported op signatures having optional Tensor (`DTensorSpec`) args after required tensors (which isn't a planned requirement as far as I know),
  - (somewhat brittle) heuristic-driven arg alignment, or
  - only supporting kwargs, etc.
### Added Tests
To facilitate detection of future `requires_grad` pattern op failure modes as `DTensor` evolves, I added the following two tests:
1. `test/distributed/_tensor/test_math_ops.py DistMathOpsTest.test_layer_norm_bwd_req_grad`
- Tests `native_layer_norm_backward` specifically, with 20 subtests that sweep valid `output_mask` patterns across different LayerNorm dimensionalities and `elementwise_affine` configurations.
2. `test/distributed/tensor/parallel/test_tp_examples.py DistTensorParallelExampleTest.test_transformer_req_grad`
- Samples a subset of `requires_grad` patterns in a more realistic (relative to the `LayerNorm`-specific test) Transformer usage context with different `dtype` and `is_seq_parallel` configurations. Note since there was substantial overlap with the existing `test_transformer_training` test, I took the opportunity to refactor that test to allow relevant code-sharing. I also added an `ExpCommCounts` `NamedTuple` to facilitate the addition of additional `requires_grad` patterns that we may want to test in the future which may result in different comm counts. I created the separate `requires_grad` test to allow decoupling the multi-iteration `test_transformer_training` test and allow addition of new `requires_grad` scenarios as desired while being mindful of resources.
Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133502
Approved by: https://github.com/XilunWu
For `aten.any`, we can use `reduce_op="sum"` as the linear reduction op.
When we do `all_reduce` with `reduce_op="sum"` on a bool tensor, if any rank returns `torch.Tensor([True])`, then the reduction result is `torch.Tensor([True])`. Only when all ranks return `torch.Tensor([False])` would the reduction result be `torch.Tensor([False])`. This matches `any`'s behavior.
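A small standalone sketch of the equivalence (simulating `all_reduce` with `reduce_op="sum"` locally; this is not DTensor internals):
```python
import torch

# per-rank partial results of `aten.any`
local_results = [torch.tensor([False]), torch.tensor([True]), torch.tensor([False])]

# summing bools (promoted to int64) simulates all_reduce with reduce_op="sum":
# the sum is nonzero iff at least one rank contributed True
summed = sum(t.to(torch.int64) for t in local_results)
assert bool(summed.item()) == any(t.item() for t in local_results)
```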
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134206
Approved by: https://github.com/tianyu-l, https://github.com/chuanhaozhuge
Fixes #134050
### The issue
The current `DTensor` sharding propagation caching policy for `aten.scaled_dot_product_efficient_attention` (default) can result in silently incorrect gradients or trigger an IMA after cuda kernel launch in mixed `requires_grad` configurations. Please see issue #134050 for a full description of the observed failure patterns along with reproduction. Note that `aten.scaled_dot_product_flash_attention` presents a similar concern, so this PR addresses both, [as discussed here.](https://github.com/pytorch/pytorch/issues/134050#issuecomment-2299887602)
### Remediation
While there are a number of ways this could be addressed, the most straightforward remediation is to modify the sharding propagation caching policy of [`aten._scaled_dot_product_efficient_attention.default`](b03381cac2/torch/distributed/_tensor/ops/_matrix_ops.py (L337-L340)), registering it with `schema_info=RuntimeSchemaInfo(4)` to prevent cache sharing between differing `compute_log_sumexp` values i.e.
```python
@register_op_strategy(aten._scaled_dot_product_efficient_attention.default, schema_info=RuntimeSchemaInfo(4))
def scaled_dot_product_efficient_attention_strategy(
...
```
[As discussed here](https://github.com/pytorch/pytorch/issues/134050#issuecomment-2299887602), since `aten::_scaled_dot_product_flash_attention` could be affected by a similar issue wrt `return_debug_mask`, this PR adjusts the sharding propagation caching policy for that op as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134146
Approved by: https://github.com/tianyu-l
Part of #134054.
This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So we are landing these `type: ignore`s for pytorch in advance of them actually being needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* a shim script that redirects old `torch.distributed._tensor` import paths
to the new module, to preserve BC for users still on the old path
That BC is preserved is evidenced by the fact that all DTensor tests still
pass without changing the public imports, so it's safe to land the
changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
**What does this PR achieve**
1. Rewrites the ring attention backward algorithm to fuse the all-to-alls and overlap the gradient communication with computation.
2. Enables memory-efficient attention with CP by templating the ring attention backward; verifying the accuracy in fp32 gives us higher confidence in the implementation's correctness.
3. Provides some experimental APIs to enable context parallelism (see the sketch below).
4. Ensures CP works with torch.compiler. The combination of causal masking and torch.compiler has not yet worked.
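A hedged sketch of what using the experimental CP API could look like (the `context_parallel` import path and signature are assumed from the `_tensor.experimental` namespace and may differ from this PR's exact version; assumes a 2-GPU `torchrun` launch):
```python
import torch
import torch.nn.functional as F
from torch.distributed._tensor.experimental import context_parallel  # assumed path
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2,))
q = torch.randn(2, 8, 16, 32, device="cuda", requires_grad=True)
k = torch.randn(2, 8, 16, 32, device="cuda", requires_grad=True)
v = torch.randn(2, 8, 16, 32, device="cuda", requires_grad=True)

# shard q/k/v along the sequence dim (dim 2) and run ring attention under CP
with context_parallel(mesh, buffers=(q, k, v), buffer_seq_dims=(2, 2, 2)):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()  # the rewritten backward overlaps grad comms with compute
```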
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131351
Approved by: https://github.com/wanchaol
# Summary
Changes the stance of SDPA on what to do for fully masked out rows
## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963
These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617
Can be paraphrased as follows:
When passing in fully masked out rows, attention becomes ambiguous. We have two main options:
1. Uniformly attend to all values:
```python
scores[masked_out_rows] = 1 / len(row)
out[masked_out_rows] = 1 / len(row) * value
```
2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
```python
output[fully_masked_rows] = NaN
```
We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happens when you call backward?
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
[ nan, nan, nan, nan]])
```
Those pesky NaNs are back!
## Why do we see NaNs today?
The core of the problem revolves around using softmax function in sdpa:
```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```
## Quick Aside: Masking in Attention
Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.
We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.
## Alternative Approaches
If we use a very large negative number instead of -inf:
```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However, if users always remembered to "slice" out their outputs, i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564, 0.1613, -0.0486],
[ 0.0000, 0.0000, 0.0000, 0.0000]])
```
This would bring us back into a better state.
## A Third Option
We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.
This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```
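A hedged illustration of the new semantic (the exact kernels covered are listed under Details below; shapes and values here are arbitrary):
```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 8)
k = torch.randn(1, 1, 2, 8)
v = torch.randn(1, 1, 2, 8)
# boolean attn_mask where the second query row attends to no keys at all
mask = torch.tensor([[True, True], [False, False]]).view(1, 1, 2, 2)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 1])  # with this stack: zeros for the fully masked row, not NaNs
```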
**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.
## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel
_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for -inf rows, we instead "cheat" and use nan_to_num.
Why do I think this is okay? (Please find a counterpoint if available.)
There are multiple ways NaNs can emerge. For the fully-masked-out-rows case, nan_to_num works. But what if there were other NaNs; wouldn't this silently remove them?
The only case where this can happen is if the input itself had a NaN or an Inf.
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`
Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`
If we don't want to even allow for the possibility of "inf" or "NaN" attention scores being converted to 0, then we could implement it with something like this:
```Python
# standard numerically stable softmax: subtract the row max before exponentiating
row_max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - row_max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
# a row max of -inf means the entire row was masked out: emit 0 instead of NaN
softmax = torch.where(row_max.values == float('-inf'), 0.0, softmax)
```
However, we would be paying for this in math performance.
## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060
Approved by: https://github.com/jbschlosser
As titled, this PR rewrites the current redistribute algorithm to make
the multi-mesh-dim redistribute logic more sound. The previous algorithm
works numerically, but it could incur additional unnecessary steps
when transforming shardings in a multi-dimension device mesh, i.e.
let's say we want to transform from (S(1), S(1)) -> (S(1), S(2)). The
previous algorithm yields the following steps:
* mesh_dim 1: S(1) -> R, mesh_dim 0: S(1) -> R
* mesh_dim 0: R -> S(1), mesh_dim 1: R -> S(2)
Although it works semantically, it incurs two allgather
transformations, where it should really only incur an S(1) -> S(2) on
mesh dim 1.
The rewritten algorithm takes a more principled approach:
1. We check if src_spec has sharding; if not, we don't need to worry about the nested sharding case, as shardings would always be in order, so we just go from left to right in the placements and add the transform steps.
2. If src_spec has sharding, there could be either nested or misaligned shardings. So we first traverse from right to left to check whether there's a misaligned sharding, as the above example shows; if there is, we replicate that mesh dimension so that it unshards the nested sharding.
3. We traverse again from left to right to generate the transformation after we unshard the nested sharding.
This should also fix https://github.com/pytorch/pytorch/issues/132751.
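For illustration, the motivating transform expressed with the public API (a sketch assuming a 4-GPU `torchrun` launch and a 2x2 mesh):
```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2, 2))
t = distribute_tensor(torch.randn(4, 8, 8), mesh, [Shard(1), Shard(1)])
# with this PR, this lowers to a single Shard(1) -> Shard(2) transform
# (an all-to-all) on mesh dim 1, instead of two allgathers plus re-sharding
u = t.redistribute(mesh, [Shard(1), Shard(2)])
```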
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131210
Approved by: https://github.com/tianyu-l
**Summary**
1. change `compute_local_shape_and_global_offset` to correctly compute shape and offset for strided sharding placement (currently it only handles 2D and some 3D+ sharding).
2. Add a new property `num_shards_map` to `DTensorSpec`, denoting how many shards each tensor dimension has. This is necessary for constructing the `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`: the `split_factor` argument will just be the number of shards on that sharded tensor dim.
**Test**
`test/distributed/_tensor/test_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132391
Approved by: https://github.com/wanchaol
ghstack dependencies: #126697, #130239
**Summary**
This PR adds a new private placement type `_StridedShard` for FSDP2 + TP style tensor sharding. The previously used `Shard` placement type cannot produce a correct `full_tensor()` result, because it assumes the tensor was first sharded over the `dp` mesh dimension and then the `tp` mesh dimension, which does not hold true in the FSDP2 + TP case (see the sketch below).
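A hedged sketch of the flow this enables (assuming a 4-GPU `torchrun` launch; the `full_tensor()` round-trip is the property this placement exists to preserve):
```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
tp_mesh, dp_mesh = mesh_2d["tp"], mesh_2d["dp"]

t = torch.arange(16, dtype=torch.float32)
d_tp = distribute_tensor(t, tp_mesh, [Shard(0)])     # shard on tp first (as TP does)
d_2d = distribute_tensor(d_tp, dp_mesh, [Shard(0)])  # then shard on dp (as FSDP2 does)
# a plain Shard placement would reassemble the shards in dp-then-tp order and
# return a permuted tensor; the internally constructed _StridedShard records
# the actual shard order so the round-trip holds:
assert d_2d.full_tensor().equal(t.to(d_2d.device))
```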
**Test**
`pytest test/distributed/_tensor/test_utils.py -s -k strided_sharding`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126697
Approved by: https://github.com/wanchaol
fixes https://github.com/pytorch/pytorch/issues/132016.
Right now, if you run an op for which DTensor has no sharding prop rule, **and** that op accepts non-trivial pytrees of input tensors as arguments, DTensor can end up infinite-looping before it has the chance to error out due to the missing sharding prop rule.
This PR doesn't fix the problem, but adds rules for the culprit ops (missing foreach ops)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132066
Approved by: https://github.com/wanchaol
**Summary**
I created functions that reduce repeated code in the console and JSON APIs, which also improves their readability for future developers.
**Test Plan**
1. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e transformer_json_dump
2. torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/comm_mode_features_example.py -e transformer_operation_tracing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132070
Approved by: https://github.com/XilunWu
`register_sharding` is an experimental API that allows users to register sharding strategies for an operator when the tensor inputs and outputs are `DTensor`s. It can be useful when: (1) there doesn't exist a default sharding strategy for `op`, e.g. when `op` is a custom operator that is not supported by `DTensor`; (2) users would like to overwrite the default sharding strategies of existing operators.
Here's an example:
```python
import torch
# import paths assume the pre-rename `_tensor` experimental namespace
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import register_sharding

aten = torch.ops.aten

@register_sharding(aten._softmax.default)
def custom_softmax_sharding(x, dim, half_to_float):
    softmax_dim = dim if dim >= 0 else dim + x.ndim
    acceptable_shardings = []

    # fully replicated: output placements, then input placements (x, dim, half_to_float)
    all_replicate = ([Replicate()], [Replicate(), None, None])
    acceptable_shardings.append(all_replicate)

    # sharded on any tensor dim other than the softmax dim
    for sharding_dim in range(x.ndim):
        if sharding_dim != softmax_dim:
            all_sharded = (
                [Shard(sharding_dim)],
                [Shard(sharding_dim), None, None],
            )
            acceptable_shardings.append(all_sharded)

    return acceptable_shardings
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131108
Approved by: https://github.com/wanchaol
Try to unblock https://github.com/pytorch/pytorch/issues/131991
- `nn.init.orthogonal_` uses `tensor.new`, which is the legacy factory function. We change this to `tensor.new_empty` (empty is okay since it will be immediately followed by `.normal_()` to fill the tensor) so that it preserves `DTensor`-ness.
- `nn.init.orthogonal_` uses QR decomposition (`aten.linalg_qr.default`) and `torch.diag` (calling into `aten.diagonal_copy.default`). For simplicity, we use naive replicate strategies for now. `aten.diagonal_copy.default` could do something more sophisticated for sharded inputs, but I would rather defer that to later due to the complexity. For `orthogonal_` support specifically, since the result of the QR decomp will be replicated, the input to `aten.diagonal_copy.default` will be replicated.
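A minimal sketch of what this unblocks (assuming a 2-GPU `torchrun` launch):
```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

mesh = init_device_mesh("cuda", (2,))
w = distribute_tensor(torch.empty(8, 8), mesh, [Replicate()])
torch.nn.init.orthogonal_(w)   # previously lost DTensor-ness via `tensor.new`
assert isinstance(w, DTensor)  # in-place init preserves DTensor-ness
```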
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132104
Approved by: https://github.com/albanD, https://github.com/wanchaol
Previously, using `_MaskPartial` with multiple embeddings had the following issues:
1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`, and there is more than one embedding sharing the same `vocab_size` but with different `emb_size`s. Then they would not share an `OpStrategy`, since each, when involved in computation, would have a different `OpSchema`; however, there would be a cache hit for redistribute (specifically `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when doing `Replicate` -> `_MaskPartial`), as `_MaskPartial` only has `vocab_size` as `logical_dim_size` but not `emb_size` as an attribute. This cache hit is undesirable and causes trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`. The error was reported in #130725. In this PR, we introduce `offset_shape` to represent the embedding's full shape, to avoid cache hits from embeddings of different shapes (see the sketch after this list).
2. The second issue arises when we have two `nn.Embedding`s `emb1` and `emb2` with the same shape. There will be a cache hit not only in `_gen_transform_infos` but also in `OpStrategy` generation. Previously, if we sequentially did `Replicate` -> `_MaskPartial` for both `emb1` and `emb2` and then did a reduction on the `_MaskPartial` of `emb1`, it would destroy the `MaskBuffer` and `emb2` would hit an error. This PR adds a `refcount` to the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s.
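A hedged repro sketch of issue 1 (assuming a 2-GPU `torchrun` launch; `RowwiseParallel` shards each embedding on the vocab dim, so lookups produce `_MaskPartial` outputs):
```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (2,))
model = torch.nn.ModuleDict({
    "emb1": torch.nn.Embedding(100, 16),
    "emb2": torch.nn.Embedding(100, 32),  # same vocab_size, different emb_size
}).cuda()
plan = {
    "emb1": RowwiseParallel(input_layouts=Replicate()),
    "emb2": RowwiseParallel(input_layouts=Replicate()),
}
model = parallelize_module(model, mesh, plan)

ids = torch.randint(0, 100, (4, 8), device="cuda")
# both lookups go Replicate -> _MaskPartial; before this PR the redistribute
# cache was keyed only on vocab_size (logical_dim_size), so the second
# embedding could hit a stale entry despite its different full shape
out1, out2 = model["emb1"](ids), model["emb2"](ids)
```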
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol