Using `placements` and `shard_order` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for the **placements** and **shard_order** parameters of the `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}`, or with mesh dim names, `{0: [1], 1: ['dp', 'tp']}` (see the sketch below)
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
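For illustration, here is a minimal sketch of both keying styles on a named mesh. The mesh shape, the dim names `("dp", "tp", "pp")`, the 8-rank CUDA setup, and the `torch.distributed.tensor` import path are assumptions made for this sketch, not requirements of the API:
```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor

# Assumes an 8-rank distributed run; the mesh shape and dim names are illustrative.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "tp", "pp"))
x = torch.randn(8, 8)

# Tensor dim 0 on mesh dim 1 ("tp"); tensor dim 1 on mesh dims 0 then 2 ("dp" then "pp").
dt_by_index = distribute_tensor(x, mesh, shard_order={0: [1], 1: [0, 2]})
dt_by_name = distribute_tensor(x, mesh, shard_order={0: ["tp"], 1: ["dp", "pp"]})
```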
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (a tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
    ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
    ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
)
```
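As a rough sketch (not the actual PyTorch source), the normalization from the user-facing dict to this internal tuple could look like the following; the `ShardOrderEntry` definition below is assumed for illustration, and resolving mesh dim names to indices is omitted:
```python
from typing import NamedTuple

class ShardOrderEntry(NamedTuple):  # assumed shape of the internal entry
    tensor_dim: int
    mesh_dims: tuple[int, ...]

def to_shard_order(user_spec: dict[int, list[int]]) -> tuple[ShardOrderEntry, ...]:
    # Sort by tensor dim so the internal representation is canonical.
    return tuple(
        ShardOrderEntry(tensor_dim=t, mesh_dims=tuple(m))
        for t, m in sorted(user_spec.items())
    )

assert to_shard_order({0: [1], 1: [0, 2]}) == (
    ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
    ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
)
```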
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, the default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dims 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
    tensor,
    device_mesh,
    placements=[Shard(0), Shard(1), Shard(1)],
)  # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (Preferred)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
    tensor,
    device_mesh,
    shard_order={0: [1], 1: [2, 0]},  # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
    tensor,
    device_mesh,
    placements=[Shard(1), Shard(0), Shard(1), Replicate()],
    shard_order={0: [1], 1: [2, 0]},
)
```
(In theory, **shard_order** alone should be expressive enough without **placements**. We keep the option to pass both because special placement types, e.g., uneven sharding, may be introduced in the future.)
**Important**: When both are specified, they must be consistent; the function validates that `placements` and `shard_order` describe the same distribution. Note that `shard_order` can be a **sparse specification**: it does not need to mention every tensor dimension that is sharded. Tensor dimensions not mentioned in `shard_order` use the default left-to-right mesh dimension order. However, when a tensor dimension *is* mentioned in `shard_order`, its entry must list *all* mesh dimensions that shard that tensor dimension in `placements`; you cannot specify only a subset. Both cases are illustrated in the sketch below.
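For example, reusing `tensor` and `device_mesh` from the earlier examples, a sketch of a valid sparse specification and of one that should be rejected:
```python
# Sparse shard_order: tensor dim 0 is sharded on mesh dim 1 in `placements` but is
# not mentioned in `shard_order`, so it falls back to the default left-to-right order.
dt = distribute_tensor(
    tensor,
    device_mesh,
    placements=[Shard(1), Shard(0), Shard(1), Replicate()],
    shard_order={1: [2, 0]},  # only tensor dim 1 is reordered
)

# Invalid: tensor dim 1 is sharded on mesh dims 0 and 2 in `placements`, but
# `shard_order` lists only a subset of them, so validation should reject this.
# dt = distribute_tensor(
#     tensor,
#     device_mesh,
#     placements=[Shard(1), Shard(0), Shard(1), Replicate()],
#     shard_order={1: [2]},
# )
```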
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
The same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
    device_mesh,
    placements=[Shard(1), Shard(0), Shard(1), Replicate()],
    shard_order={1: [0, 2]},
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]})
```
### Constraints
1. **Consistency**: When both are specified, **placements** and **shard_order** must describe the same distribution (see the sketch after this list)
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: Users are not allowed to specify `Partial` placements, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
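A sketch of the consistency check in constraint 1, again reusing `tensor` and `device_mesh` from the earlier examples; the exact error type is an implementation detail:
```python
# `placements` shards tensor dim 0 on mesh dim 0, while `shard_order` claims
# mesh dim 1, so validation is expected to reject the call.
try:
    dt = distribute_tensor(
        tensor,
        device_mesh,
        placements=[Shard(0), Replicate(), Replicate(), Replicate()],
        shard_order={0: [1]},
    )
except Exception as err:  # exact error type depends on the implementation
    print(f"Rejected inconsistent spec: {err}")
```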
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution. Note that ``shard_order`` can be a **sparse specification** - it does not need to mention every tensor dimension that is sharded. For tensor dimensions not mentioned in ``shard_order``, the default left-to-right mesh dimension order is used. When a tensor dimension IS mentioned in ``shard_order``, it must include ALL mesh dimensions that shard that tensor dimension in ``placements``. You cannot specify only a subset of the mesh dimensions for a given tensor dimension.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution. Note that ``shard_order`` can be a **sparse specification** - it does not need to mention every tensor dimension that is sharded. For tensor dimensions not mentioned in ``shard_order``, the default left-to-right mesh dimension order is used. When a tensor dimension IS mentioned in ``shard_order``, it must include ALL mesh dimensions that shard that tensor dimension in ``placements``. You cannot specify only a subset of the mesh dimensions for a given tensor dimension.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution. Note that ``shard_order`` can be a **sparse specification** - it does not need to mention every tensor dimension that is sharded. For tensor dimensions not mentioned in ``shard_order``, the default left-to-right mesh dimension order is used. When a tensor dimension IS mentioned in ``shard_order``, it must include ALL mesh dimensions that shard that tensor dimension in ``placements``. You cannot specify only a subset of the mesh dimensions for a given tensor dimension.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution. Note that ``shard_order`` can be a **sparse specification** - it does not need to mention every tensor dimension that is sharded. For tensor dimensions not mentioned in ``shard_order``, the default left-to-right mesh dimension order is used. When a tensor dimension IS mentioned in ``shard_order``, it must include ALL mesh dimensions that shard that tensor dimension in ``placements``. You cannot specify only a subset of the mesh dimensions for a given tensor dimension.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution. Note that ``shard_order`` can be a **sparse specification** - it does not need to mention every tensor dimension that is sharded. For tensor dimensions not mentioned in ``shard_order``, the default left-to-right mesh dimension order is used. When a tensor dimension IS mentioned in ``shard_order``, it must include ALL mesh dimensions that shard that tensor dimension in ``placements``. You cannot specify only a subset of the mesh dimensions for a given tensor dimension.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Using `**placements**` and `**shard_order**` in DTensor Distribution
-----------------
This PR updates the documentation and implementation for **placements** and **shard_order** parameters in `distribute_tensor()` and `redistribute()` methods.
### Overview
DTensor supports two complementary ways to specify tensor distribution:
* **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
* **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)
### API Format
#### User-Facing API
* **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
* **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}` or using mesh dim string name `{0: [1], 1: ['dp', 'tp']}`
* Keys: tensor dimensions
* Values: mesh dimensions (or names) in execution order
#### Internal Representation
Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
)
```
### Usage Patterns
#### 1. Using Only **placements**
When only **placements** is specified, default left-to-right ordering is used:
```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dim 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(0), Shard(1), Shard(1)]
) # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
#### 2. Using Only **shard_order** (preferred way)
When only **shard_order** is specified, **placements** are inferred:
```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
tensor,
device_mesh,
shard_order={0: [1], 1: [2, 0]} # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
#### 3. Using Both (Maximum Control)
Specify both when you need:
* Explicit control over multi-mesh-dimension sharding order
```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
tensor,
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={0: [1], 1: [2, 0]} )
```
(in theory **shard_order** should be expressive enough without **placements**. We leave the option there considering special placement types (e.g., uneven sharding) may be introduced in the future)
**Important**: When both are specified, they must be consistent. The function validates that placements and shard_order describe the same distribution.
#### 4. Neither Specified (Default Replication)
```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```
### Use in `redistribute()`
Same API applies to `redistribute()`:
```python
# Redistribute with new sharding order
dt_redistributed = dt.redistribute(
device_mesh,
placements=[Shard(1), Shard(0), Shard(1), Replicate()],
shard_order={1: [0, 2]}
)
# Or just change shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]} )
```
### Constraints
1. **Consistency**: When both are specified, placements and shard_order must describe the same distribution
2. **No _StridedShard with shard_order**: `_StridedShard` placement cannot be used with explicit **shard_order**
3. **No Partial inference**: User doesn't allow to specify `Partial` placement, and `Partial` placements cannot be inferred from **shard_order**
4. **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape
--------
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
(This extracts the algorithm from https://github.com/pytorch/pytorch/pull/160266.)
Build a graph and search it for a path from the source placement to the destination placement (taking device order into account). The current solution introduces too many all-gathers and misses opportunities to use all-to-all during redistribution, especially once device order is considered.
### How to build the graph:
For Shard placements, think of each collective op as an operation on a stack of device axes:
- I, J are tensor dimensions;
- X, Y, Z, W are ordered mesh dimensions.
<img width="357" height="253" alt="image" src="https://github.com/user-attachments/assets/23bb3cc3-0506-4071-9053-3c525cf0e526" />
The detailed collective-op transitions are implemented in `DTensorRedistributePlanner.get_next_state`.
### How to find the min cost path:
Assign a weight to each type of collective op and run Dijkstra's algorithm on the graph we build to find the minimum-cost path.
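To make the search concrete, below is a minimal, self-contained Python sketch of the idea. It is not the actual `DTensorRedistributePlanner`: the state encoding (tensor dim -> ordered mesh dims), the three transition rules, and the edge weights are simplified assumptions chosen only to illustrate building the implicit graph and running Dijkstra over it.
```python
# A minimal sketch of the graph search (NOT the real DTensorRedistributePlanner).
# Assumptions for illustration only: a state maps each tensor dim to the ordered
# tuple of mesh dims that shard it (unused mesh dims are replicated); transitions
# and costs are simplified stand-ins.
import heapq
from itertools import count

MESH_NDIM = 3  # assume a 3-D device mesh for this example


def freeze(state):
    # canonical, hashable form of a state
    return tuple(sorted((t, tuple(m)) for t, m in state.items()))


def thaw(frozen):
    return {t: tuple(m) for t, m in frozen}


def next_states(state):
    """Yield (next_state, op_description, cost) for simplified transitions."""
    used = {m for mdims in state.values() for m in mdims}
    # 1) all-gather: drop the innermost (last-applied) mesh dim of a tensor dim
    for tdim, mdims in state.items():
        if mdims:
            new = dict(state)
            new[tdim] = mdims[:-1]
            yield freeze(new), f"all-gather(mesh {mdims[-1]} from tensor {tdim})", 4.0
    # 2) local chunk: append an unused mesh dim to a tensor dim (no communication)
    for m in range(MESH_NDIM):
        if m in used:
            continue
        for tdim in state:
            new = dict(state)
            new[tdim] = state[tdim] + (m,)
            yield freeze(new), f"chunk(mesh {m} onto tensor {tdim})", 1.0
    # 3) all-to-all: move the innermost mesh dim from one tensor dim to another
    for src, mdims in state.items():
        if not mdims:
            continue
        for dst in state:
            if dst == src:
                continue
            new = dict(state)
            new[src] = mdims[:-1]
            new[dst] = state[dst] + (mdims[-1],)
            yield freeze(new), f"all-to-all(mesh {mdims[-1]}: tensor {src} -> {dst})", 2.0


def min_cost_path(src, dst):
    """Dijkstra over the implicit transition graph; returns (cost, op sequence)."""
    start, goal = freeze(src), freeze(dst)
    tie = count()  # tie-breaker so the heap never compares states/paths
    heap = [(0.0, next(tie), start, [])]
    best = {start: 0.0}
    while heap:
        cost, _, node, path = heapq.heappop(heap)
        if node == goal:
            return cost, path
        if cost > best.get(node, float("inf")):
            continue
        for nxt, op, weight in next_states(thaw(node)):
            if cost + weight < best.get(nxt, float("inf")):
                best[nxt] = cost + weight
                heapq.heappush(heap, (cost + weight, next(tie), nxt, path + [op]))
    return float("inf"), []


# e.g. a 2-D tensor on a 3-D mesh: search a collective sequence that takes
# shard order {0: (1,), 1: (2, 0)} to {0: (), 1: (0, 2)}
cost, ops = min_cost_path({0: (1,), 1: (2, 0)}, {0: (), 1: (0, 2)})
print(cost, ops)
```
In the real planner, `get_next_state` enumerates the legal collective transitions, and the weights would come from an actual communication cost model rather than the made-up constants above.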
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
Enable DebugMode to print how `placements` and `shard_order` are updated as we execute the `transform_infos` that transform the source placement into the target placement.
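A hedged sketch of what that could look like in practice; the `DebugMode` import path and the `debug_string()` accessor are assumptions and may differ from the actual change, while the `distribute_tensor`/`redistribute` calls follow the `shard_order` API described above.
```python
# Hypothetical usage sketch only: DebugMode's import path and accessor below
# are assumptions, not confirmed by this PR.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor

from torch.utils._debug_mode import DebugMode  # assumed import path

# requires a distributed launch with 8 ranks (2 * 2 * 2 mesh)
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))
dt = distribute_tensor(torch.randn(8, 8), mesh, shard_order={0: [1], 1: [2, 0]})

with DebugMode() as dm:
    # each transform step should now be logged with the intermediate
    # placements and shard_order it produces
    dt2 = dt.redistribute(mesh, shard_order={0: [0], 1: [1, 2]})

print(dm.debug_string())  # assumed accessor for the recorded log
```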
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
Diff excerpt from `class Redistribute(torch.autograd.Function)`:
```diff
@@ -926,6 +938,10 @@ class Redistribute(torch.autograd.Function):
        if output.dtype != ctx.original_dtype:
            output = output.to(ctx.original_dtype)
        # TODO(zpcore): During backward, some Partial related transform ops got
        # silently skipped. This will be an issue for the graph-based
        # redistribute planner. Need fix.
        # normalize the target placement to replicate if it is partial
        normalized_placements: list[Placement] = []
        for previous_placement in previous_spec.placements:
@@ -943,6 +959,8 @@ class Redistribute(torch.autograd.Function):
                stride=grad_output.stride(),
                dtype=output.dtype,
            ),
            # this is subject to be wrong if we skip Partial() transform
            shard_order=previous_spec.shard_order,
        )
        output_dtensor = dtensor.DTensor(
            output,
@@ -957,4 +975,5 @@ class Redistribute(torch.autograd.Function):
            None,
            None,
            None,
            None,
        )
```