Compare commits

...

57 Commits

Author SHA1 Message Date
fd1ff815fb Update on "[6/N][DTensor device order] User API for device ordered distribution"
Using `placements` and `shard_order` in DTensor Distribution
-----------------

This PR updates the documentation and implementation for the **placements** and **shard_order** parameters of `distribute_tensor()` and `redistribute()`.

### Overview

DTensor supports two complementary ways to specify tensor distribution:

*   **placements**: Describes _what_ placement each mesh dimension has (PyTorch-style, mesh-centric)
*   **shard_order**: Describes _the order_ in which shardings are applied (JAX-style, tensor-centric)

### API Format

#### User-Facing API

*   **placements**: `Sequence[Placement]` - e.g., `[Shard(0), Shard(1), Replicate()]`
*   **shard_order**: `dict[int, Sequence[int | str]]` - e.g., `{0: [1], 1: [0, 2]}`, or `{0: [1], 1: ['dp', 'tp']}` using mesh dim names
    *   Keys: tensor dimensions
    *   Values: mesh dimensions (or names) in execution order

#### Internal Representation

Internally, **shard_order** is converted to `ShardOrder` (tuple of `ShardOrderEntry` objects):
```python
shard_order: ShardOrder = (
    ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
    ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
)
```
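
To make the mapping concrete, below is a minimal, self-contained sketch of how a user-facing **shard_order** dict (including mesh dim names such as `'dp'`/`'tp'`) could be normalized into the tuple form shown above. The locally declared `ShardOrderEntry` and the `dict_to_shard_order` helper are illustrative stand-ins that only mirror the behavior documented here, not the actual internals.

```python
# Illustrative sketch only: re-declares ShardOrderEntry locally and uses a hypothetical
# dict_to_shard_order helper to mirror the documented dict -> ShardOrder conversion.
from typing import NamedTuple, Sequence, Union


class ShardOrderEntry(NamedTuple):
    tensor_dim: int
    mesh_dims: tuple[int, ...]


ShardOrder = tuple[ShardOrderEntry, ...]


def dict_to_shard_order(
    shard_order: dict[int, Sequence[Union[int, str]]],
    mesh_dim_names: Sequence[str] = (),
) -> ShardOrder:
    """Resolve mesh dim names to indices and build the ShardOrder tuple."""
    entries = []
    for tensor_dim, mesh_dims in sorted(shard_order.items()):
        resolved = tuple(
            mesh_dim_names.index(d) if isinstance(d, str) else d for d in mesh_dims
        )
        entries.append(ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=resolved))
    return tuple(entries)


# {0: [1], 1: [0, 2]} -> (ShardOrderEntry(0, (1,)), ShardOrderEntry(1, (0, 2)))
print(dict_to_shard_order({0: [1], 1: [0, 2]}))
# Same result using mesh dim names, assuming a mesh named ("dp", "tp", "pp"):
print(dict_to_shard_order({0: [1], 1: ["dp", "pp"]}, mesh_dim_names=("dp", "tp", "pp")))
```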

### Usage Patterns

#### 1. Using Only **placements**

When only **placements** is specified, the default left-to-right ordering is used:

```python
# Shard tensor dim 0 on mesh dim 0, tensor dim 1 on mesh dims 1 and 2 (mesh dim order is fixed)
dt = distribute_tensor(
    tensor,
    device_mesh,
    placements=[Shard(0), Shard(1), Shard(1)],
)  # Internally: shard_order = (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
```
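
The default order can be read directly off **placements** by walking mesh dimensions left to right. The sketch below (with locally declared `Shard`/`ShardOrderEntry` stand-ins, not the real DTensor types) shows one way this derivation might look.

```python
# Sketch of how the default left-to-right order could be derived from `placements`
# alone. `Shard` and `ShardOrderEntry` are stand-ins declared here for illustration.
from typing import NamedTuple, Sequence


class Shard(NamedTuple):
    dim: int  # tensor dimension being sharded


class ShardOrderEntry(NamedTuple):
    tensor_dim: int
    mesh_dims: tuple[int, ...]


def default_shard_order(placements: Sequence[object]) -> tuple[ShardOrderEntry, ...]:
    """Walk mesh dims left to right and group them by the tensor dim they shard."""
    per_tensor_dim: dict[int, list[int]] = {}
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            per_tensor_dim.setdefault(placement.dim, []).append(mesh_dim)
    return tuple(
        ShardOrderEntry(tensor_dim=t, mesh_dims=tuple(m))
        for t, m in sorted(per_tensor_dim.items())
    )


# [Shard(0), Shard(1), Shard(1)] -> (ShardOrderEntry(0, (0,)), ShardOrderEntry(1, (1, 2)))
print(default_shard_order([Shard(0), Shard(1), Shard(1)]))
```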

#### 2. Using Only **shard_order** (preferred way)

When only **shard_order** is specified, **placements** are inferred:

```python
# Shard tensor dim 0 on mesh dim 1, tensor dim 1 on mesh dims 2 then 0 (mesh dim order can be specified)
dt = distribute_tensor(
    tensor,
    device_mesh,
    shard_order={0: [1], 1: [2, 0]},  # Inferred placements: [Shard(1), Shard(0), Shard(1), Replicate()]
)
```
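
The inference in the other direction is also straightforward: every mesh dimension named in **shard_order** becomes a `Shard` of the corresponding tensor dimension, and every remaining mesh dimension stays `Replicate`. Below is a hedged sketch with local stand-in types; note that the resulting **placements** list alone cannot record that mesh dim 2 applies before mesh dim 0, which is exactly the information **shard_order** adds.

```python
# Sketch of inferring per-mesh-dim placements from `shard_order` and the mesh rank.
# `Shard` and `Replicate` are local stand-ins, not the real DTensor placement types.
from typing import NamedTuple, Sequence


class Shard(NamedTuple):
    dim: int


class Replicate(NamedTuple):
    pass


def infer_placements(
    shard_order: dict[int, Sequence[int]], mesh_ndim: int
) -> list[object]:
    """Every mesh dim not claimed by shard_order stays Replicate()."""
    placements: list[object] = [Replicate()] * mesh_ndim
    for tensor_dim, mesh_dims in shard_order.items():
        for mesh_dim in mesh_dims:
            placements[mesh_dim] = Shard(tensor_dim)
    return placements


# {0: [1], 1: [2, 0]} on a 4-D mesh -> [Shard(1), Shard(0), Shard(1), Replicate()]
# The "2 before 0" ordering for tensor dim 1 is not visible in this list.
print(infer_placements({0: [1], 1: [2, 0]}, mesh_ndim=4))
```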

#### 3. Using Both (Maximum Control)

Specify both when you need:

*   Explicit control over multi-mesh-dimension sharding order

```python
# Tensor dim 1 sharded across mesh dims 2 and 0 (in that order)
dt = distribute_tensor(
    tensor,
    device_mesh,
    placements=[Shard(1), Shard(0), Shard(1), Replicate()],
    shard_order={0: [1], 1: [2, 0]},
)
```

(In theory, **shard_order** alone should be expressive enough without **placements**. We keep the **placements** option because special placement types (e.g., uneven sharding) may be introduced in the future.)

**Important**: When both are specified, they must be consistent; the function validates that **placements** and **shard_order** describe the same distribution. Note that **shard_order** can be a **sparse specification**: it does not need to mention every tensor dimension that is sharded. For tensor dimensions not mentioned in **shard_order**, the default left-to-right mesh dimension order is used. When a tensor dimension _is_ mentioned in **shard_order**, its entry must include _all_ mesh dimensions that shard that tensor dimension in **placements**; specifying only a subset of those mesh dimensions is not allowed.
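
As an illustration of that rule (not the actual validation code), the hypothetical `check_consistent` helper below accepts a sparse **shard_order** that omits a sharded tensor dimension entirely, but rejects one that lists only a subset of the mesh dimensions sharding a mentioned tensor dimension.

```python
# Hypothetical sketch of the documented consistency rule; `Shard` is a local stand-in.
from typing import NamedTuple, Sequence


class Shard(NamedTuple):
    dim: int


def check_consistent(
    placements: Sequence[object], shard_order: dict[int, Sequence[int]]
) -> None:
    # Which mesh dims shard each tensor dim, according to placements.
    sharded_by: dict[int, set[int]] = {}
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            sharded_by.setdefault(placement.dim, set()).add(mesh_dim)
    # A mentioned tensor dim must cover exactly its mesh dims; unmentioned dims are fine.
    for tensor_dim, mesh_dims in shard_order.items():
        if set(mesh_dims) != sharded_by.get(tensor_dim, set()):
            raise ValueError(
                f"shard_order for tensor dim {tensor_dim} must cover exactly mesh dims "
                f"{sorted(sharded_by.get(tensor_dim, set()))}, got {list(mesh_dims)}"
            )


placements = [Shard(1), Shard(0), Shard(1)]
check_consistent(placements, {1: [2, 0]})  # OK: sparse spec, tensor dim 0 omitted
try:
    check_consistent(placements, {1: [0]})  # only a subset of the mesh dims -> rejected
except ValueError as e:
    print(e)
```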


#### 4. Neither Specified (Default Replication)

```python
dt = distribute_tensor(tensor, device_mesh) # Results in: placements=[Replicate()] * mesh.ndim
```


### Use in `redistribute()`

The same API applies to `redistribute()`:

```python
# Redistribute with a new sharding order
dt_redistributed = dt.redistribute(
    device_mesh,
    placements=[Shard(1), Shard(0), Shard(1), Replicate()],
    shard_order={1: [0, 2]},
)
# Or just change the shard order
dt_redistributed = dt.redistribute(shard_order={0: [1], 1: [2, 0]})
```
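
For context, here is a hypothetical end-to-end script showing how these calls might fit together, assuming the **shard_order** keyword behaves as described in this PR. It targets a 3-D mesh over 8 ranks (e.g., `torchrun --nproc_per_node=8`); the mesh dim names, tensor shape, and file name are arbitrary.

```python
# Hypothetical end-to-end usage, assuming the `shard_order` kwarg from this PR.
# Run with e.g.: torchrun --nproc_per_node=8 shard_order_demo.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# 2 x 2 x 2 mesh over 8 ranks; the names are arbitrary.
mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("dp", "tp", "pp"))

tensor = torch.arange(64, dtype=torch.float32).reshape(8, 8)

# Default left-to-right order: tensor dim 1 is split by mesh dim 1 first, then mesh dim 2.
dt = distribute_tensor(tensor, mesh, placements=[Shard(0), Shard(1), Shard(1)])

# Keep the same placements but flip the order in which tensor dim 1 is sharded.
dt_reordered = dt.redistribute(shard_order={0: [0], 1: [2, 1]})

if dist.get_rank() == 0:
    print(dt_reordered.placements)

dist.destroy_process_group()
```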

### Constraints

1.  **Consistency**: When both are specified, **placements** and **shard_order** must describe the same distribution
2.  **No `_StridedShard` with shard_order**: The `_StridedShard` placement cannot be used together with an explicit **shard_order**
3.  **No `Partial` inference**: Users are not allowed to specify `Partial` placements, and `Partial` placements cannot be inferred from **shard_order**
4.  **Dimension validation**: All tensor and mesh dimensions must be valid for the given tensor rank and mesh shape




--------



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-30 13:55:48 -07:00
45f00e42fc Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-30 13:55:48 -07:00
356cc87ab8 Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-16 01:00:25 -07:00
0b5293b978 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-16 01:00:25 -07:00
9bc0591e09 Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 23:24:03 -07:00
cda68d6ff3 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 23:24:03 -07:00
70ccc4d31a Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 16:48:46 -07:00
5a769daa65 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 16:48:46 -07:00
8ed10cd23b Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 15:09:24 -07:00
571ac2bd77 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 15:09:24 -07:00
6bc1c74646 Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 14:25:39 -07:00
f6887cbc6c Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-15 14:25:38 -07:00
b16c099072 Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-14 14:16:40 -07:00
3d042833fc Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
2025-10-14 14:16:39 -07:00
99a68c0f16 Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-14 13:34:35 -07:00
ba7c00d73e Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-14 13:34:35 -07:00
d5480cfb0d Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-13 11:14:52 -07:00
131b421df4 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-13 11:14:51 -07:00
49314d2459 Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-13 10:12:02 -07:00
47512f0a39 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-13 10:12:02 -07:00
d6133b9ecf Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-10 17:20:07 -07:00
3f9d597374 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-10 17:20:07 -07:00
3515e04a18 Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-10 15:57:33 -07:00
c594c1b09d Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-10 15:57:33 -07:00
ddd7c4f4bf Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-09 15:58:25 -07:00
01807d80e6 Update base for Update on "[6/N][DTensor device order] User API for device ordered distribution"
[ghstack-poisoned]
2025-10-09 15:58:25 -07:00
26e7302ce7 Update on "[6/N][DTensor device order] User API for device ordered distribution"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-09 12:56:17 -07:00
d600d9a16d [6/N][DTensor device order] User API for device ordered distribution
[ghstack-poisoned]
2025-10-09 11:49:47 -07:00
48a3a84138 Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
(This extracts the algorithm from https://github.com/pytorch/pytorch/pull/160266.)

Build a graph to search for a path from the source placement to the destination placement (with device order). The current solution introduces too many all-gathers and misses opportunities to use all-to-all during redistribution, especially when device order is taken into account.

### How to build the graph:
For Shard placements, think of each collective op as an operation on a stack of device axes:
- I, J are tensor dimensions;
- X, Y, Z, W are ordered mesh dimensions.
<img width="357" height="253" alt="image" src="https://github.com/user-attachments/assets/23bb3cc3-0506-4071-9053-3c525cf0e526" />

The detailed collective-op transitions are implemented in `DTensorRedistributePlanner.get_next_state`.

### How to find the min cost path:
Assign a weight to each type of collective op and use Dijkstra's algorithm to find the min-cost path in the graph we build.
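
For illustration, below is a minimal, self-contained Dijkstra sketch over an abstract state graph with weighted collective-op edges. The states, edges, and costs are assumptions chosen for demonstration and do not reflect the actual `DTensorRedistributePlanner` implementation.

```python
# Minimal Dijkstra sketch over an abstract redistribution graph (illustrative only).
import heapq

# Edge list: state -> [(next_state, collective op, cost), ...]
# Costs loosely encode "all-to-all is cheaper than all-gather followed by re-sharding".
GRAPH = {
    "S(0)S(1)": [("S(0)R", "all-gather(mesh=1)", 4.0), ("S(1)S(0)", "all-to-all", 2.0)],
    "S(0)R":    [("S(1)R", "all-to-all", 2.0)],
    "S(1)R":    [("S(1)S(0)", "shard(mesh=1)", 1.0)],
    "S(1)S(0)": [],
}

def min_cost_path(src: str, dst: str) -> tuple[float, list[str]]:
    """Return (total cost, collective ops) for the cheapest transition from src to dst."""
    pq = [(0.0, src, [])]            # (cost so far, state, ops taken)
    best = {src: 0.0}
    while pq:
        cost, state, ops = heapq.heappop(pq)
        if state == dst:
            return cost, ops
        if cost > best.get(state, float("inf")):
            continue                 # stale queue entry
        for nxt, op, weight in GRAPH.get(state, []):
            if cost + weight < best.get(nxt, float("inf")):
                best[nxt] = cost + weight
                heapq.heappush(pq, (cost + weight, nxt, ops + [op]))
    return float("inf"), []

print(min_cost_path("S(0)S(1)", "S(1)S(0)"))  # -> (2.0, ['all-to-all'])
```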





cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-09 01:02:10 -07:00
f3bb4de7f4 Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
[ghstack-poisoned]
2025-10-09 01:02:09 -07:00
2f7011cf92 Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
[ghstack-poisoned]
2025-10-09 00:24:17 -07:00
93a704d269 Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
[ghstack-poisoned]
2025-10-09 00:24:17 -07:00
8888e8c1a9 Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
[ghstack-poisoned]
2025-10-08 16:54:24 -07:00
e8035d42d2 Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
[ghstack-poisoned]
2025-10-08 16:54:24 -07:00
8acc8e022f Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-08 11:23:21 -07:00
45ff858ae0 Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-08 11:23:21 -07:00
407912a3a4 Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-08 11:12:17 -07:00
64fb7c6548 Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-08 11:12:17 -07:00
d5e0fa71e5 Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-08 10:59:13 -07:00
1f045012bb Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-08 10:59:13 -07:00
b26c4379c0 Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 23:33:54 -07:00
67e83b184e Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 23:33:54 -07:00
63affab34a Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 17:24:23 -07:00
965ae294f8 Update base for Update on "[5/N][DTensor device order] Implement graph based redistribution algorithm"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 17:24:23 -07:00
80783f0fc8 [5/N][DTensor device order] Implement graph based redistribution algorithm
[ghstack-poisoned]
2025-10-07 17:14:10 -07:00
7da1778906 Update on "[4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path"
Enable DebugMode to print out how `placements` and `shard_order` are updated as `transform_infos` is executed to transform from the source placement to the target placement.
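
For context, here is a minimal sketch (assuming a 4-rank job started via `torchrun`; it uses the existing `CommDebugMode` from `torch.distributed.tensor.debug`, not the DebugMode output added in this stack) of wrapping a device-ordered redistribute in a debug mode to see which collectives it triggers. The DebugMode support described above additionally reports the intermediate `placements`/`shard_order` along the transform path.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor
from torch.distributed.tensor.debug import CommDebugMode

# Assumes torch.distributed is initialized with 4 ranks.
mesh = init_device_mesh("cuda", (2, 2))
dt = distribute_tensor(torch.randn(8, 8), mesh, shard_order={0: [1, 0]})

comm_mode = CommDebugMode()
with comm_mode:
    # Reorder the two shardings of tensor dim 0; the planner decides which
    # sequence of collectives is required for the reordering.
    dt2 = dt.redistribute(mesh, shard_order={0: [0, 1]})
print(comm_mode.get_comm_counts())  # per-collective counts issued by the redistribute
```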




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 10:40:55 -07:00
61c54e119a Update base for Update on "[4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path"
Enable DebugMode to print out how `placements` and `shard_order` are updated as `transform_infos` is executed to transform from the source placement to the target placement.




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 10:40:55 -07:00
a90cef063b Update on "[4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path"
Enable DebugMode to print out how `placements` and `shard_order` are updated as `transform_infos` is executed to transform from the source placement to the target placement.




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 09:47:24 -07:00
b08a11bd92 Update base for Update on "[4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path"
Enable DebugMode to print out how `placements` and `shard_order` are updated as `transform_infos` is executed to transform from the source placement to the target placement.




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-07 09:47:24 -07:00
87a72fc603 Update on "[4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path"
Enable DebugMode to print out how `placements` and `shard_order` are updated as `transform_infos` is executed to transform from the source placement to the target placement.




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-06 23:14:54 -07:00
5528e5352a Update base for Update on "[4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path"
Enable DebugMode to print out how `placements` and `shard_order` are updated as `transform_infos` is executed to transform from the source placement to the target placement.




cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-06 23:14:54 -07:00
941536ba4f [4/N] [DTensor device order] Support debugmode to show dtensor distribution transform path
[ghstack-poisoned]
2025-10-06 23:01:58 -07:00
da307d6317 [3/N] [DTensor device order] Make some placement type class method static
[ghstack-poisoned]
2025-10-06 23:01:40 -07:00
0d8c18f4c1 Update on "[2/N] [DTensor device order] Add shard_order attribute in DTensorSpec"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-06 18:04:42 -07:00
5f15b6d9aa Update base for Update on "[2/N] [DTensor device order] Add shard_order attribute in DTensorSpec"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-10-06 18:04:42 -07:00
c0c4862443 [2/N] [DTensor device order] Add shard_order attribute in DTensorSpec
[ghstack-poisoned]
2025-10-06 18:02:03 -07:00
a36b2da37a Device mesh util function to support device order placement
[ghstack-poisoned]
2025-10-06 16:14:39 -07:00
6 changed files with 1063 additions and 82 deletions

View File

@ -1,11 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import contextlib
import tempfile
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
@ -16,8 +18,14 @@ from torch.distributed.tensor import (
Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorContinuousTestBase,
DTensorTestBase,
with_comms,
)
@ -120,7 +128,7 @@ class DTensorAPITest(DTensorTestBase):
distribute_tensor(tensor_to_distribute, device_mesh, shard_spec)
with self.assertRaisesRegex(RuntimeError, "distribute leaf tensor"):
shard_spec = [Shard(0)]
shard_spec = [Shard(0), Shard(0)]
global_tensor = torch.randn(*tensor_shape, requires_grad=True)
global_tensor_to_distribute = global_tensor + 2
distribute_tensor(global_tensor_to_distribute, device_mesh, shard_spec)
@ -388,5 +396,441 @@ class DTensorAPITest(DTensorTestBase):
dcp.save({"fqn": dtensor}, checkpoint_id=tempfile.mkdtemp())
class DTensorDeviceOrderAPITest(DTensorContinuousTestBase):
world_size = 4
@property
def device(self):
return f"{DTensorContinuousTestBase.device_type()}:{self.rank}"
def build_device_mesh(self, mesh_shape=None) -> DeviceMesh:
if mesh_shape is None:
mesh_shape = (2, self.world_size // 2)
return init_device_mesh(DTensorContinuousTestBase.device_type(), mesh_shape)
def test_shard_order_property(self):
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
input_tensor_dt = distribute_tensor(input_tensor, mesh, shard_order={0: [1, 0]})
# check if we can reuse the shard_order property
input_tensor_dt_reuse = distribute_tensor(
input_tensor, mesh, shard_order=input_tensor_dt.shard_order
)
self.assertEqual(input_tensor_dt._spec, input_tensor_dt_reuse._spec)
def test_neither_placements_nor_shard_order(self):
"""Test that neither placements nor shard_order, use default"""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
input_tensor_dt = distribute_tensor(input_tensor, mesh)
self.assertEqual(
input_tensor_dt.placements, [Replicate() for _ in range(mesh.ndim)]
)
self.assertEqual(input_tensor_dt.shard_order, {})
input_tensor_dt.redistribute(mesh, (Shard(0), Shard(0)))
input_tensor_dt.redistribute(mesh)
self.assertEqual(
input_tensor_dt.placements, [Replicate() for _ in range(mesh.ndim)]
)
self.assertEqual(input_tensor_dt.shard_order, {})
@parametrize(
"placements, shard_order_dict, should_pass",
[
[(Shard(0), Shard(0)), {0: [0], 1: [1]}, False],
[(Shard(0), Shard(0)), {0: [0]}, False],
[(Shard(0), Shard(0)), {0: [0, 1]}, True],
[(Shard(0), Shard(0)), {0: [1, 0]}, True],
[(Shard(1), Shard(0)), {0: [1], 1: [0]}, True],
[(Shard(1), Shard(0)), {0: [0], 1: [1]}, False],
[(Shard(1), Shard(2)), {1: [0], 2: [1]}, True],
[(Replicate(), Shard(2)), {2: [1]}, True],
[(Replicate(), Replicate()), {}, True],
[(Shard(0), Shard(0)), None, True],
# shard_order need not mention every Shard() from `placements`; it still works
[(Shard(0), Shard(0)), {}, True],
[
(Shard(0), Shard(0)),
{0: [0]},
False,
], # need to specify all mesh dims for Shard(0)
[(Shard(1), Shard(0)), {0: [1]}, True],
[(Shard(1), Shard(0)), {1: [0]}, True],
],
)
def test_conflict_placements_and_shard_order(
self, placements, shard_order_dict, should_pass
):
"""Test that providing conflict placements and shard_order raises an error."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
test_context = (
contextlib.nullcontext()
if should_pass
else self.assertRaisesRegex(
ValueError,
r"`shard_order` for tensor dim \d+ must include ALL mesh dimensions "
r"that shard this tensor dimension in `placements`",
)
)
with test_context:
distribute_tensor(
input_tensor, mesh, placements=placements, shard_order=shard_order_dict
)
@parametrize(
"placements, expected_shard_order_tuple",
[
[
(Shard(0), Shard(1)),
{0: [0], 1: [1]},
],
[(Shard(0), Shard(0)), {0: [0, 1]}],
[
(Shard(1), Shard(2)),
{1: [0], 2: [1]},
],
[(Replicate(), Shard(2)), {2: [1]}],
[(Replicate(), Replicate()), {}],
],
)
def test_only_placements_provided(self, placements, expected_shard_order_tuple):
"""Test that providing only placements works correctly."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
input_tensor_dt = distribute_tensor(input_tensor, mesh, placements)
self.assertEqual(input_tensor_dt.placements, tuple(placements))
self.assertEqual(input_tensor_dt.full_tensor(), input_tensor)
self.assertEqual(input_tensor_dt.shard_order, expected_shard_order_tuple)
@parametrize(
"expected_placements, shard_order_dict",
[
[(Shard(0), Shard(1)), {0: [0], 1: [1]}],
[(Shard(0), Shard(0)), {0: [0, 1]}],
[(Shard(0), Shard(0)), {0: [1, 0]}],
[(Shard(1), Shard(2)), {1: [0], 2: [1]}],
[(Replicate(), Shard(2)), {2: [1]}],
[(Replicate(), Replicate()), {}],
[(Replicate(), Replicate()), {0: []}], # allow empty shard_order sequences
],
)
def test_only_shard_order_provided(self, expected_placements, shard_order_dict):
"""Test that providing only shard_order works correctly."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
input_tensor_dt = distribute_tensor(
input_tensor, mesh, shard_order=shard_order_dict
)
self.assertEqual(input_tensor_dt.placements, expected_placements)
self.assertEqual(input_tensor_dt.full_tensor(), input_tensor)
# all replicate tensor, test for redistribution
input_tensor_dt = distribute_tensor(input_tensor, mesh)
input_tensor_dt = input_tensor_dt.redistribute(
mesh, shard_order=shard_order_dict
)
self.assertEqual(input_tensor_dt.placements, expected_placements)
self.assertEqual(input_tensor_dt.full_tensor(), input_tensor)
@parametrize(
"placements, shard_order_dict, should_pass",
[
[(Shard(0), Shard(0)), {0: [1, 0]}, True],
[None, {0: [1], 1: [0]}, True],
[(Shard(1), Shard(2)), {1: [0], 2: [2]}, False],
[(Shard(1), Shard(2)), {1: [0], 2: [-1]}, False],
[(Shard(1), Shard(2)), {1: [0], -1: [1]}, True],
[None, {1: [0, 1]}, True],
[None, {1: [1, -3]}, False],
],
)
def test_out_of_range_shard_order(self, placements, shard_order_dict, should_pass):
"""Test that providing only shard_order works correctly."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
test_context = (
contextlib.nullcontext()
if should_pass
else self.assertRaisesRegex(
IndexError,
"`shard_order` is out of range for placements",
)
)
with test_context:
distribute_tensor(
input_tensor, mesh, placements=placements, shard_order=shard_order_dict
)
# all replicate tensor, test for redistribution
input_tensor_dt = distribute_tensor(input_tensor, mesh)
with test_context:
input_tensor_dt.redistribute(
mesh, placements=placements, shard_order=shard_order_dict
)
@parametrize(
"placements, shard_order_dict, should_pass",
[
[(Shard(0), Shard(0)), {-3: [1, 0]}, True],
[(Shard(0), Shard(0)), {0: [0], -3: [1]}, False],
[(Shard(0), Shard(0)), {0: [0, 1], -3: []}, False],
[(Shard(0), Shard(0)), {0: [1, 0]}, True],
],
)
def test_duplicated_tensor_dim_shard_order(
self, placements, shard_order_dict, should_pass
):
"""Test that providing only shard_order works correctly."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
test_context = (
contextlib.nullcontext()
if should_pass
else self.assertRaisesRegex(
ValueError,
r"both normalized tensor dim * and un-normalized tensor dim * are specified in shard_order",
)
)
with test_context:
distribute_tensor(
input_tensor, mesh, placements=placements, shard_order=shard_order_dict
)
# all replicate tensor, test for redistribution
input_tensor_dt = distribute_tensor(input_tensor, mesh)
with test_context:
input_tensor_dt.redistribute(
mesh, placements=placements, shard_order=shard_order_dict
)
@parametrize(
"placements, shard_order_dict, should_pass",
[
[(Shard(0), Shard(0)), {0: [1, 0]}, True],
[(Shard(0), Shard(0)), {3: [1, 0]}, False],
[None, {3: [1, 0]}, False],
],
)
def test_shard_order_out_of_tensor_rank_spec(
self, placements, shard_order_dict, should_pass
):
"""Test that providing only shard_order works correctly."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
test_context = (
contextlib.nullcontext()
if should_pass
else self.assertRaisesRegex(
ValueError,
"`shard_order` is out of range for tensor_rank",
)
)
with test_context:
distribute_tensor(
input_tensor, mesh, placements=placements, shard_order=shard_order_dict
)
# all replicate tensor, test for redistribution
input_tensor_dt = distribute_tensor(input_tensor, mesh)
with test_context:
input_tensor_dt.redistribute(
mesh, placements=placements, shard_order=shard_order_dict
)
def test_placement_length_validation_edge_cases(self):
"""Test edge cases for placement length validation."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
# Empty placements
with self.assertRaisesRegex(
ValueError,
"`placements` must have the same length",
):
distribute_tensor(input_tensor, mesh, placements=[])
# Too many placements
with self.assertRaisesRegex(
ValueError,
"`placements` must have the same length",
):
distribute_tensor(
input_tensor,
mesh,
placements=[
Shard(0),
Shard(1),
Replicate(),
], # mesh.ndim = 2, but 3 placements
)
@parametrize(
"placements, shard_order_dict, should_pass",
[
[(Shard(0), Shard(2)), {0: [0], 2: [1]}, True],
[(Shard(0), Shard(2)), None, True],
[(Shard(-3), Shard(2)), None, True],
[(Shard(-4), Shard(2)), None, False],
[(Shard(-4), Shard(3)), None, False],
],
)
def test_placement_out_of_tensor_rank_spec(
self, placements, shard_order_dict, should_pass
):
"""Test that providing only shard_order works correctly."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
test_context = (
contextlib.nullcontext()
if should_pass
else self.assertRaisesRegex(
ValueError,
"`placements` is out of range for tensor_rank",
)
)
with test_context:
distribute_tensor(
input_tensor, mesh, placements=placements, shard_order=shard_order_dict
)
# all replicate tensor, test for redistribution
input_tensor_dt = distribute_tensor(input_tensor, mesh)
with test_context:
input_tensor_dt.redistribute(
mesh, placements=placements, shard_order=shard_order_dict
)
def test_empty_shard_order_creates_replicated_dtensor(self):
"""Test that empty shard_order creates a replicated DTensor."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
empty_shard_order = {}
dt_empty_shard_order = distribute_tensor(
input_tensor, mesh, shard_order=empty_shard_order
)
expected_default_placements = (Replicate(), Replicate())
self.assertEqual(dt_empty_shard_order.placements, expected_default_placements)
self.assertEqual(dt_empty_shard_order.full_tensor(), input_tensor)
# test for redistribution
dt_empty_shard_order = dt_empty_shard_order.redistribute(mesh, shard_order={})
self.assertEqual(dt_empty_shard_order.placements, expected_default_placements)
self.assertEqual(dt_empty_shard_order.full_tensor(), input_tensor)
@parametrize(
"placements, expected_shard_order_tuple",
[
[
(Shard(0), Shard(1)),
{0: [0], 1: [1]},
],
[(Shard(0), Shard(0)), {0: [0, 1]}],
[
(Shard(1), Shard(2)),
{1: [0], 2: [1]},
],
[(Replicate(), Shard(2)), {2: [1]}],
[(Replicate(), Replicate()), {}],
],
)
def test_redistribute_with_placements_only(
self, placements, expected_shard_order_tuple
):
"""Test redistribution using placements only."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
dt_default = distribute_tensor(
input_tensor, mesh, placements=(Replicate(), Replicate())
)
dt_redist_placements = dt_default.redistribute(mesh, placements)
self.assertEqual(dt_redist_placements.placements, placements)
self.assertEqual(dt_redist_placements.full_tensor(), input_tensor)
self.assertEqual(dt_redist_placements.shard_order, expected_shard_order_tuple)
@parametrize(
"expected_placements, shard_order_dict",
[
[(Shard(0), Shard(1)), {0: [0], 1: [1]}],
[(Shard(0), Shard(0)), {0: [0, 1]}],
[(Shard(0), Shard(0)), {0: [1, 0]}],
[(Shard(1), Shard(2)), {1: [0], 2: [1]}],
[(Replicate(), Shard(2)), {2: [1]}],
[(Replicate(), Replicate()), {}],
],
)
def test_redistribute_with_shard_order_only(
self, expected_placements, shard_order_dict
):
"""Test redistribution using shard_order only."""
mesh = self.build_device_mesh((2, self.world_size // 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
dt_default = distribute_tensor(
input_tensor, mesh, placements=(Replicate(), Replicate())
)
dt_redist_shard_order = dt_default.redistribute(
mesh, shard_order=shard_order_dict
)
self.assertEqual(dt_redist_shard_order.placements, expected_placements)
self.assertEqual(dt_redist_shard_order.full_tensor(), input_tensor)
def test_special_placement_with_shard_order(self):
"""Test special placement when specify shard_order together."""
mesh = self.build_device_mesh((self.world_size // 2, 2))
input_tensor = torch.randn(8, 6, 5, device=self.device)
# test _StridedShard
dt_default = distribute_tensor(
input_tensor,
mesh,
placements=(_StridedShard(0, split_factor=2), Replicate()),
)
# _StridedShard doesn't have shard_order
self.assertEqual(dt_default.shard_order, {})
with self.assertRaisesRegex(
RuntimeError,
"Cannot specify both `placements` and `shard_order` when `placements` contains `_StridedShard`!",
):
distribute_tensor(
input_tensor,
mesh,
placements=(_StridedShard(0, split_factor=2), Replicate()),
shard_order={0: [0]},
)
# all replicate tensor, test for redistribution
input_tensor_dt = distribute_tensor(input_tensor, mesh)
with self.assertRaisesRegex(
RuntimeError,
"Cannot specify both `placements` and `shard_order` when `placements` contains `_StridedShard`!",
):
input_tensor_dt.redistribute(
mesh,
placements=(_StridedShard(0, split_factor=2), Shard(1)),
shard_order={0: [0]},
)
# test Partial
gathered_tensor = DTensor.from_local(
input_tensor, mesh, placements=(Partial(), Shard(0))
)
self.assertEqual(gathered_tensor.placements, (Partial(), Shard(0)))
self.assertEqual(gathered_tensor.shard_order, {0: [1]})
# can redistribute to Partial from Partial
dt_redist_shard_order = gathered_tensor.redistribute(
mesh, placements=(Partial(), Shard(1)), shard_order={1: [1]}
)
# creating a new Partial is not allowed
with self.assertRaisesRegex(
RuntimeError,
"redistributing to Partial is for internal use only",
):
gathered_tensor.redistribute(mesh, placements=(Partial(), Partial()))
# can redistribute from Partial
dt_redist_shard_order = gathered_tensor.redistribute(mesh, shard_order={1: [0]})
self.assertEqual(dt_redist_shard_order.placements, (Shard(1), Replicate()))
instantiate_parametrized_tests(DTensorDeviceOrderAPITest)
if __name__ == "__main__":
run_tests()
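
The `DTensorDeviceOrderAPITest` cases above exercise, among other things, the new read-only `DTensor.shard_order` property. A minimal usage sketch (assuming a 4-rank job with `torch.distributed` already initialized) of reading the property back and feeding it into another distribution:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor

# Assumes torch.distributed is initialized with 4 ranks.
mesh = init_device_mesh("cuda", (2, 2))
dt = distribute_tensor(torch.randn(8, 6, 5), mesh, shard_order={0: [1, 0]})

# The property returns a plain dict: tensor dim -> mesh dims in execution order.
assert dt.shard_order == {0: [1, 0]}

# It can be fed straight back into distribute_tensor()/redistribute().
dt_again = distribute_tensor(torch.randn(8, 6, 5), mesh, shard_order=dt.shard_order)
```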

View File

@ -219,7 +219,7 @@ class DTensorTest(DTensorTestBase):
dtensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
with self.assertRaisesRegex(
RuntimeError, "the local_tensor argument only accepts torch.Tensor"
ValueError, "the local_tensor argument only accepts torch.Tensor"
):
DTensor.from_local(dtensor, device_mesh, shard_spec)
@ -263,7 +263,7 @@ class DTensorTest(DTensorTestBase):
)
with self.assertRaisesRegex(
RuntimeError, "Please pass both shape and stride at the same time."
ValueError, "Please pass both shape and stride at the same time."
):
DTensor.from_local(
map_local_tensor_for_rank(tensor_list, self.rank, lambda tl, r: tl[r]),
@ -273,7 +273,7 @@ class DTensorTest(DTensorTestBase):
)
with self.assertRaisesRegex(
RuntimeError, "Please pass both shape and stride at the same time."
ValueError, "Please pass both shape and stride at the same time."
):
DTensor.from_local(
map_local_tensor_for_rank(tensor_list, self.rank, lambda tl, r: tl[r]),

View File

@ -249,7 +249,9 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
def test_dropout_errors(self):
device_mesh = self.build_device_mesh()
with self.assertRaisesRegex(RuntimeError, "supported"):
with self.assertRaisesRegex(
RuntimeError, "redistributing to Partial is for internal use only"
):
self._run_sharded_elementwise_ops(
device_mesh=device_mesh,
placements=[Partial("sum")],

View File

@ -5,6 +5,7 @@ import contextlib
import copy
import itertools
import unittest
from unittest.mock import patch
import torch
from torch.distributed.device_mesh import init_device_mesh
@ -18,7 +19,10 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor._redistribute import (
DTensorRedistributePlanner,
redistribute_local_tensor,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
@ -1162,6 +1166,105 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
)
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
@unittest.expectedFailure
@with_comms
@patch.object(DTensorRedistributePlanner, "find_min_cost_path")
def test_with_partial_in_backward(self, mock_find_path):
"""
Test a backward redistribution path that contains Partial().
"""
# generate ground truth
input_tensor = torch.randn(8, 8, requires_grad=True)
input_tensor_orig = input_tensor.clone().detach().requires_grad_(True)
loss = input_tensor_orig.sum()
loss.backward()
forward_path = [
DTensorRedistributePlanner.DistState(
placements=(Replicate(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
DTensorRedistributePlanner.DistState(
placements=(Shard(0), Replicate()),
tensor_dim_to_mesh_dim=(),
),
]
backward_path = [
DTensorRedistributePlanner.DistState(
placements=(Replicate(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
# Note: The current DTensor implementation silently skips the
# transition from Replicate to Partial (R->P) during the backward
# pass. This can lead to numerical correctness issues if any
# operations are performed on that mesh dimension after the
# Partial() conversion. If we change the Partial() to something like
# Shard(0), the issue will be resolved. This will not be an issue for the
# greedy solution, because it won't generate a path like P->R, but
# this may happen in graph-based redistribution.
DTensorRedistributePlanner.DistState(
placements=(Partial(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
DTensorRedistributePlanner.DistState(
placements=(Replicate(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
]
# set side_effect with a list - first call gets first item, second call gets second item
mock_find_path.side_effect = [forward_path, backward_path]
import torch.distributed.tensor._redistribute as redistribute_module
original_redistribute = redistribute_module.redistribute_local_tensor
def force_graph_based(*args, **kwargs):
kwargs["use_graph_based_transform"] = True
return original_redistribute(*args, **kwargs)
def disable_graph_based(*args, **kwargs):
kwargs["use_graph_based_transform"] = False
return original_redistribute(*args, **kwargs)
device_mesh = init_device_mesh(self.device_type, (4, 2))
# disable the graph based path finding in `distribute_tensor`
with patch.object(
redistribute_module,
"redistribute_local_tensor",
side_effect=disable_graph_based,
):
dtensor = distribute_tensor(
input_tensor, device_mesh, [Replicate(), Replicate()]
)
assert mock_find_path.call_count == 0
# enable the graph based path finding in `distribute_tensor`, so that
# mock_find_path.side_effect will be used
with patch.object(
redistribute_module,
"redistribute_local_tensor",
side_effect=force_graph_based,
):
dtensor_sharded = dtensor.redistribute(
device_mesh,
[Shard(0), Replicate()],
)
loss = dtensor_sharded.sum()
assert type(loss) is DTensor
# loss.placements is supposed to be [Partial(), Replicate()], but at
# some point it silently gets updated to [Replicate(), Replicate()].
loss.backward()
# verify forward and backward paths in mock_find_path.side_effect were used
assert mock_find_path.call_count == 2, (
f"Run {mock_find_path.call_count} calls to find_min_cost_path"
)
self.assertEqual(input_tensor_orig.grad, dtensor.grad.full_tensor())
if __name__ == "__main__":
run_tests()

View File

@ -3,8 +3,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import inspect
import warnings
from collections.abc import Callable, Sequence
from typing import Any, cast, Optional
from collections import defaultdict
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast, Optional, TypeAlias
from typing_extensions import deprecated
import torch
@ -14,7 +15,12 @@ import torch.nn as nn
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._dtensor_spec import (
DTensorSpec,
ShardOrder,
ShardOrderEntry,
TensorMeta,
)
from torch.distributed.tensor._redistribute import (
Redistribute,
redistribute_local_tensor,
@ -48,6 +54,45 @@ __all__ = [
aten = torch.ops.aten
TensorShardingDict: TypeAlias = dict[int, Sequence[int | str]]
r"""
.. _shard_order:
Shard Order
-----------
The ``shard_order`` parameter specifies the mapping of tensor dimensions to the order of device mesh
dimensions they are sharded over. It is a dictionary where keys are tensor dimensions and values
are sequences of mesh dimensions (or mesh dimension names) that the tensor dimension is sharded
across, in execution order.
Internally, this is converted to a ``ShardOrder`` (tuple of ``ShardOrderEntry`` objects) where each
``ShardOrderEntry`` contains a ``tensor_dim`` and ``mesh_dims`` tuple.
**Sparse Specification**
``shard_order`` can be a **sparse specification** - it does not need to mention every tensor dimension
that is sharded. For tensor dimensions not mentioned in ``shard_order``, the default left-to-right mesh
dimension order is used.
**IMPORTANT**: When a tensor dimension IS mentioned in ``shard_order``, it must include ALL mesh
dimensions that shard that tensor dimension in ``placements``. You cannot specify only a subset of
the mesh dimensions for a given tensor dimension.
**Examples**
For ``placements=[Shard(0), Shard(0), Shard(1), Shard(1)]``:
- Valid: ``shard_order={0: [1, 0]}`` - customizes dim 0's order (both mesh dims included), dim 1 defaults to [2, 3]
- Invalid: ``shard_order={0: [1]}`` - ERROR! Tensor dim 0 is sharded on mesh dims [0, 1] but only [1] is specified
- Valid: ``shard_order={0: [1, 0], 1: [3, 2]}`` - fully customizes both dims
If not specified, a default left-to-right sharding order is used.
"""
# NOTE [Autograd interaction between torch.Tensor]
#
# The autograd functions defined below are being used by the public
@ -141,9 +186,9 @@ class _FromTorchTensor(torch.autograd.Function):
)
tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
else:
raise RuntimeError(
f"Found shape:{shape}, stride:{stride}.",
"Please pass both shape and stride at the same time.",
raise ValueError(
f"Found shape:{shape}, stride:{stride}. "
"Please pass both shape and stride at the same time."
)
if device_mesh.get_coordinate() is None:
@ -212,6 +257,241 @@ class _FromTorchTensor(torch.autograd.Function):
return grad_output.to_local(), None, None, None, None, None
def _prepare_placements_and_shard_order(
device_mesh: DeviceMesh,
tensor_rank: int,
placements: Optional[Sequence[Placement]] = None,
shard_order: Optional[TensorShardingDict] = None,
) -> tuple[
tuple[Placement, ...],
ShardOrder,
]:
"""
This function places `placements` and `shard_order` in a redundant but
canonical form, whereas in the input users can elide redundant information
when specifying `placements`. For example, when there is never a tensor
dimension that is sharded by multiple mesh dims, `shard_order` can be elided
(this is traditional "PyTorch" style). Similarly, if a user specifies a
`shard_order` that has an entry for every device mesh dim, the `placements`
can be inferred (this is traditional "JAX" style).
You can also specify both arguments, which may be necessary in some
situations as `placements` and `shard_order` have different expressivity.
For example, to express that a placement is something other than `Shard`
(e.g., `Partial`), this can only be specified via the `placements` kwarg. To
express the order of multiple shardings applied to a single tensor
dimension, you must use `shard_order`. If you want to express both of these
things, you will need to use both arguments.
`placements` and `shard_order` must be consistent with each other. When both
are set:
1. For each entry in `shard_order` (tensor_dim: [mesh_dims...]), each
mesh_dim must correspond to a `Shard(tensor_dim)` placement at that
position in `placements`.
2. `shard_order` can be a **sparse specification** - it does not need to
mention every tensor dimension that is sharded. For tensor dimensions not
mentioned in `shard_order`, the default left-to-right mesh dimension
order is used.
**IMPORTANT**: When a tensor dimension IS mentioned in `shard_order`, it
must include ALL mesh dimensions that shard that tensor dimension in
`placements`. You cannot specify only a subset of the mesh dimensions for
a given tensor dimension.
Example: For `placements=[Shard(0), Shard(0), Shard(1), Shard(1)]`:
- Valid: `shard_order={0: [1, 0]}` - customizes dim 0's order (both mesh
dims included), dim 1 defaults to [2, 3]
- Invalid: `shard_order={0: [1]}` - ERROR! Tensor dim 0 is sharded on mesh
dims [0, 1] but only [1] is specified
- Valid: `shard_order={0: [1, 0], 1: [3, 2]}` - fully customizes both dims
3. A mesh dim mentioned in `shard_order` must be a Shard placement; no other
placement types are supported.
In the returned canonical representation, no information is omitted. In
particular, `shard_order` is no longer a dict, it is a sparse tuple with
each inner tuple corresponding to a sharded tensor dimension as the first
element and remaining elements as the indices of device mesh dimensions that
this tensor dimension is sharded over. Check the "Returns" section for more
details.
Args:
device_mesh (:class:`DeviceMesh`): DeviceMesh to place the tensor.
tensor_rank (int): The rank (number of dimensions) of the tensor to be
distributed or redistributed.
placements (Sequence[:class:`Placement`], optional): the placements that
describe how to place the local torch.Tensor on DeviceMesh. Must
have the same number of elements as ``device_mesh.ndim``.
shard_order (dict[int, Sequence[int | str]], optional): See :ref:`shard_order`
Returns:
Tuple:
- placements (Tuple[:class:`Placement`, ...]): The computed
placements as a tuple.
- shard_order (ShardOrder): The computed shard order as a tuple
of ShardOrderEntry objects. Each ShardOrderEntry contains:
* tensor_dim (int): The tensor dimension being sharded
* mesh_dims (tuple[int, ...]): The device mesh dimensions that
this tensor dimension is sharded across, in execution order.
For example, ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2))
means tensor dimension 1 is sharded first over mesh dimension 0,
then mesh dimension 2. If a tensor dimension is not sharded, it
won't have a corresponding ShardOrderEntry in the tuple.
Raises:
ValueError: If the length of `placements` does not match
`device_mesh.ndim`, if `shard_order` contains invalid tensor or mesh
dimensions, if both normalized and un-normalized tensor_dim are
specified in `shard_order`, or if a tensor_dim or mesh_dim in
`shard_order` is out of range.
RuntimeError: If attempting to redistribute from a non-Partial to a
Partial placement, from one Partial type to a different Partial
type, or use _StridedShard with `shard_order`.
AssertionError: If a placement's shard dim normalization would result in
a negative value, or if there is a conflict between placements and
shard_order for sharding annotation.
"""
def _normalize_tensor_dim(unnormalized_tensor_dim: int) -> int:
if (
unnormalized_tensor_dim < -tensor_rank
or unnormalized_tensor_dim >= tensor_rank
):
raise ValueError(
f"tensor dim {unnormalized_tensor_dim} is out of range for tensor_rank {tensor_rank}."
)
return (
unnormalized_tensor_dim + tensor_rank
if unnormalized_tensor_dim < 0
else unnormalized_tensor_dim
)
def _from_dict_to_ShardOrder(
shard_order_map: dict[int, list[int]],
) -> ShardOrder:
sparse_shard_order = tuple(
ShardOrderEntry(tensor_dim=key, mesh_dims=tuple(value))
for key, value in sorted(shard_order_map.items())
if value
)
return sparse_shard_order
def _convert_shard_order_to_placements(
shard_order_map: Mapping[int, Sequence[int]],
device_mesh: DeviceMesh,
) -> tuple[Placement, ...]:
# convert from shard_order to placements
placements: list[Placement] = [Replicate() for _ in range(device_mesh.ndim)]
for tensor_dim, mesh_dims in shard_order_map.items():
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
if placements is None and shard_order is None:
placements = [Replicate() for _ in range(device_mesh.ndim)]
if placements is not None and len(placements) != device_mesh.ndim:
raise ValueError(
f"`placements` must have the same length as `device_mesh.ndim`! "
f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
)
# normalize tensor dims in shard_order and convert str mesh dim into int
normalized_shard_order: dict[int, list[int]] = {}
has_stridedshard_in_placements = False
if shard_order is not None:
for tensor_dim, mesh_dims in shard_order.items():
tensor_dim = _normalize_tensor_dim(tensor_dim)
if tensor_dim in normalized_shard_order:
raise ValueError(
f"both normalized tensor dim {tensor_dim} and un-normalized "
f"tensor dim {tensor_dim - tensor_rank}) is specified in `shard_order`!"
)
normalized_shard_order[tensor_dim] = []
for mesh_dim in mesh_dims:
if isinstance(mesh_dim, str):
mesh_dim = device_mesh._get_mesh_dim_by_name(mesh_dim)
if mesh_dim < 0 or mesh_dim >= device_mesh.ndim:
raise IndexError(
f"mesh dim {mesh_dim} specified in `shard_order` is out of range "
f"for placements of length {device_mesh.ndim}"
)
normalized_shard_order[tensor_dim].append(mesh_dim)
# set default placements to replicated if not specified
placement_tuple: tuple[Placement, ...]
if placements is None:
assert shard_order is not None
# convert from shard_order to placements
placement_tuple = _convert_shard_order_to_placements(
normalized_shard_order, device_mesh
)
shard_order_tuple = _from_dict_to_ShardOrder(normalized_shard_order)
else:
normalized_placements = list(placements) # type: ignore[assignment]
for i, placement in enumerate(placements):
if placement.is_shard():
tensor_dim = placement.dim # type: ignore[attr-defined]
tensor_dim = _normalize_tensor_dim(tensor_dim)
# reconstruct `placement` object in case it is `_StridedShard` for backward compatibility
if isinstance(placement, _StridedShard):
has_stridedshard_in_placements = True
normalized_placements[i] = _StridedShard(
tensor_dim, split_factor=placement.split_factor
)
else:
assert type(placement) is Shard, (
"Expected placement to be exactly of type Shard"
)
normalized_placements[i] = Shard(tensor_dim)
placement_tuple = tuple(normalized_placements)
if shard_order is None:
shard_order_tuple = DTensorSpec.compute_default_shard_order(placement_tuple)
else:
# both shard_order and placements are specified; need to validate their correctness
if has_stridedshard_in_placements:
# _StridedShard doesn't work with shard_order
raise ValueError(
"Cannot specify both `placements` and `shard_order` when "
"`placements` contains `_StridedShard`!"
)
# Build map from tensor_dim -> list of mesh_dims (in left-to-right order)
tensor_dim_to_mesh_dims: dict[int, list[int]] = defaultdict(list)
for mesh_dim, placement in enumerate(placement_tuple):
if placement.is_shard():
tensor_dim = placement.dim # type: ignore[attr-defined]
tensor_dim_to_mesh_dims[tensor_dim].append(mesh_dim)
# Validate: if a tensor_dim is in shard_order, all its sharding mesh dims must be listed
for tensor_dim, mesh_dims_in_order in normalized_shard_order.items():
expected = set(tensor_dim_to_mesh_dims.get(tensor_dim, []))
specified = set(mesh_dims_in_order)
if specified != expected:
raise ValueError(
f"`shard_order` for tensor dim {tensor_dim} must include ALL mesh "
f"dimensions that shard this tensor dimension in `placements`. "
f"Expected {sorted(expected)}, got {sorted(specified)}."
)
# Build complete shard_order: explicit entries + default left-to-right for unspecified
complete_shard_order: dict[int, list[int]] = dict(normalized_shard_order)
for tensor_dim, mesh_dims in tensor_dim_to_mesh_dims.items():
if tensor_dim not in complete_shard_order:
# Use default left-to-right ordering for unspecified tensor dims
complete_shard_order[tensor_dim] = mesh_dims
shard_order_tuple = _from_dict_to_ShardOrder(complete_shard_order)
return placement_tuple, shard_order_tuple
class DTensor(torch.Tensor):
"""
``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like
@ -339,6 +619,7 @@ class DTensor(torch.Tensor):
kwargs or {},
)
# TODO(zpcore): support `shard_order` argument
@staticmethod
def from_local(
local_tensor: torch.Tensor,
@ -380,6 +661,10 @@ class DTensor(torch.Tensor):
Returns:
A :class:`DTensor` object
Raises:
ValueError: If both ``shape`` and ``stride`` are not provided together,
or if the device mesh does not contain the current rank.
.. note:: When ``run_check=False``, it is the user's responsibility to ensure the
local tensor passed in is correct across ranks (i.e. the tensor is sharded for
the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement).
@ -390,7 +675,7 @@ class DTensor(torch.Tensor):
"""
# `local_tensor` argument cannot be DTensor
if isinstance(local_tensor, DTensor):
raise RuntimeError(
raise ValueError(
f"the local_tensor argument only accepts torch.Tensor but got {type(local_tensor)} value."
)
@ -430,6 +715,7 @@ class DTensor(torch.Tensor):
stride,
)
# TODO(zpcore): support `shard_order` argument for grad
def to_local(
self, *, grad_placements: Optional[Sequence[Placement]] = None
) -> torch.Tensor:
@ -455,8 +741,12 @@ class DTensor(torch.Tensor):
it means the local tensor is not ready yet (i.e. communication is not finished). In this
case, user needs to call ``wait`` to wait the local tensor to be ready.
Raises:
ValueError: If ``grad_placements`` has a different length than the device mesh dimensions.
.. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned
will depend on if the `DTensor` requires_grad or not.
"""
if not torch.is_grad_enabled():
return self._local_tensor
@ -471,6 +761,7 @@ class DTensor(torch.Tensor):
self,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
shard_order: Optional[TensorShardingDict] = None,
*,
async_op: bool = False,
forward_dtype: Optional[torch.dtype] = None,
@ -503,6 +794,7 @@ class DTensor(torch.Tensor):
describes how to place the DTensor into the DeviceMesh, must
have the same number of elements as ``device_mesh.ndim``.
default: replicate on all mesh dimensions
shard_order (dict[int, Sequence[int | str]], optional): See :ref:`shard_order`
Keyword args:
async_op (bool, optional): whether to perform the DTensor redistribute operation
@ -517,11 +809,20 @@ class DTensor(torch.Tensor):
Returns:
A :class:`DTensor` object
Raises:
RuntimeError: If attempting to redistribute from a non-Partial to a Partial placement,
or from one Partial type to a different Partial type.
ValueError: If ``placements`` has a different length than the device mesh dimensions,
if ``shard_order`` contains invalid tensor or mesh dimensions, or if both normalized and un-normalized
tensor_dim are specified in ``shard_order``, or if a tensor_dim in ``shard_order`` is out of range.
IndexError: If a mesh_dim specified in ``shard_order`` is out of range for the device mesh.
.. note:: ``redistribute`` is differentiable, which means user do not need to worry about
the backward formula of the redistribute operation.
.. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh,
Please file an issue if you need to redistribute DTensor to different DeviceMesh.
"""
# NOTE: This redistribute API currently only supports out
# of place redistribution, i.e. it always create a new
@ -529,25 +830,29 @@ class DTensor(torch.Tensor):
# if device_mesh is not specified, use the current device_mesh
device_mesh = device_mesh or self.device_mesh
# raise error if new placements not specified
if placements is None:
raise RuntimeError("placements is needed for redistribute!")
placements = list(placements)
for i, placement in enumerate(placements):
if placement.is_partial() and self.placements[i] != placement:
raise RuntimeError(
f"Can not redistribute from {self.placements[i]} to {placement}, "
"redistributing to Partial is for internal use only!"
)
elif isinstance(placement, Shard) and placement.dim < 0:
# normalize shard dim to be positive
placements[i] = Shard(placement.dim + self.ndim)
placements = tuple(placements)
# handle the special case where `Partial` is allowed if we are redistributing to
# the same type of `Partial`
if placements is not None:
for i, placement in enumerate(placements):
if placement.is_partial() and self.placements[i] != placement:
raise RuntimeError(
f"Can not redistribute from {self.placements[i]} to {placement}, "
"redistributing to Partial is for internal use only!"
)
placements_tuple, shard_order_tuple = _prepare_placements_and_shard_order(
device_mesh, self.ndim, placements, shard_order
)
# pyre-fixme[16]: `Redistribute` has no attribute `apply`.
return Redistribute.apply(
self, device_mesh, placements, async_op, forward_dtype, backward_dtype
self,
device_mesh,
placements_tuple,
shard_order_tuple,
async_op,
forward_dtype,
backward_dtype,
)
def full_tensor(
@ -579,7 +884,7 @@ class DTensor(torch.Tensor):
redist_res = self.redistribute(
placements=[Replicate()] * self.device_mesh.ndim, async_op=False
)
return _ToTorchTensor.apply(redist_res, grad_placements)
return _ToTorchTensor.apply(redist_res, grad_placements) # type: ignore[return-value]
@property
def device_mesh(self) -> DeviceMesh:
@ -613,6 +918,51 @@ class DTensor(torch.Tensor):
"DTensor with partial placements!"
)
@property
def shard_order(self) -> TensorShardingDict:
"""
The shard order of this DTensor, which specifies how tensor dimensions
are sharded across device mesh dimensions.
When a tensor dimension is sharded across multiple mesh dimensions,
``shard_order`` specifies the sequence in which these shardings are
applied. This order determines how tensor shards are distributed across devices.
Returns:
A dictionary (:class:`TensorShardingDict`) mapping tensor dimensions (int) to
sequences of mesh dimensions (list[int]). Each entry indicates that the
corresponding tensor dimension is sharded across the specified mesh dimensions
in the given order.
For example, ``{0: [1, 2], 1: [0]}`` means:
- Tensor dimension 0 is sharded first over mesh dimension 1, then mesh dimension 2
- Tensor dimension 1 is sharded over mesh dimension 0
Example:
For a tensor of shape [8, 16, 32] on a 3D device mesh with
``placements=[Shard(1), Shard(0), Shard(0)]``, the shard_order would be::
# Tensor dim 0 is sharded over mesh dims 1 and 2
# Tensor dim 1 is sharded over mesh dim 0
# Tensor dim 2 is not in the dict, meaning it's replicated over all mesh dims
dtensor.shard_order # Returns: {0: [1, 2], 1: [0]}
A tensor dimension not appearing as a key in the returned dictionary means
that dimension is replicated (not sharded) across all device mesh dimensions.
.. note:: ``shard_order`` is a read-only property derived from the DTensor's
internal :class:`ShardOrder` specification (a tuple of :class:`ShardOrderEntry`
objects). The returned dictionary provides a more user-friendly interface
to the underlying shard order information.
.. note:: If you need to redistribute a DTensor with a different shard order,
use the :meth:`redistribute` method with the ``shard_order`` parameter.
"""
tensor_mesh_dim_dict = defaultdict(list)
for entry in self._spec.shard_order:
tensor_mesh_dim_dict[entry.tensor_dim] = list(entry.mesh_dims)
return TensorShardingDict(tensor_mesh_dim_dict)
def __create_write_items__(self, fqn: str, object: Any):
self._raise_if_contains_partial_placements()
from torch.distributed.checkpoint.planner_helpers import (
@ -674,6 +1024,7 @@ def distribute_tensor(
tensor: torch.Tensor,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
shard_order: Optional[TensorShardingDict] = None,
*,
src_data_rank: Optional[int] = 0,
) -> DTensor:
@ -694,11 +1045,35 @@ def distribute_tensor(
device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
tensor, if not specified, must be called under a DeviceMesh context
manager, default: None
placements (List[:class:`Placement`], optional): the placements that
placements (Sequence[:class:`Placement`], optional): the placements that
describes how to place the tensor on DeviceMesh, must have the same
number of elements as ``device_mesh.ndim``. If not specified, we will
by default replicate the tensor across the ``device_mesh`` from the
first rank of each dimension of the `device_mesh`.
first rank of each dimension of the ``device_mesh``.
shard_order (dict[int, Sequence[int | str]], optional):
Specifies the mapping of tensor dimensions to the order of device mesh
dimensions they are sharded over. It is a dictionary where keys are tensor
dimensions and values are sequences of mesh dimensions (or mesh dimension names)
that the tensor dimension is sharded across, in execution order.
``shard_order`` can be a **sparse specification** - it does not need to
mention every tensor dimension that is sharded. For tensor dimensions not
mentioned in ``shard_order``, the default left-to-right mesh dimension order
is used.
**IMPORTANT**: When a tensor dimension IS mentioned in ``shard_order``, it must
include ALL mesh dimensions that shard that tensor dimension in ``placements``.
You cannot specify only a subset of the mesh dimensions for a given tensor
dimension.
If not specified, a default left-to-right sharding order is used.
Example: For ``placements=[Shard(0), Shard(0), Shard(1), Shard(1)]``:
- Valid: ``shard_order={0: [1, 0]}`` - customizes dim 0's order (both mesh dims
included), dim 1 defaults to [2, 3]
- Invalid: ``shard_order={0: [1]}`` - ERROR! Tensor dim 0 is sharded on mesh
dims [0, 1] but only [1] is specified
- Valid: ``shard_order={0: [1, 0], 1: [3, 2]}`` - fully customizes both dims
Keyword args:
src_data_rank (int, optional): the rank of the source data for the logical/global tensor, it is
@ -722,6 +1097,19 @@ def distribute_tensor(
# get default device mesh if there's nothing specified
device_mesh = device_mesh or _mesh_resources.get_current_mesh()
device_type = device_mesh.device_type
if placements is not None:
for placement in placements:
if placement.is_partial():
raise RuntimeError(
f"Can not distribute to {placements}, "
"redistributing to Partial is for internal use only!"
)
placements_tuple, shard_order_tuple = _prepare_placements_and_shard_order(
device_mesh, tensor.ndim, placements, shard_order
)
if device_type == "xla":
try:
# call PyTorch/XLA SPMD for `xla` backend type device mesh.
@ -730,7 +1118,14 @@ def distribute_tensor(
xla_distribute_tensor,
)
return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value]
# TODO: update the XLA API to accept shard order information
if not DTensorSpec.is_default_device_order(shard_order_tuple):
raise RuntimeError(
"The xla_distribute_tensor API currently only supports the "
"default left-to-right device ordering. "
"Support for custom shard_order is not yet implemented."
)
return xla_distribute_tensor(tensor, device_mesh, placements_tuple) # type:ignore[return-value]
except ImportError as e:
msg = "To use DTensor API with xla, you must install the torch_xla package!"
raise ImportError(msg) from e
@ -744,15 +1139,6 @@ def distribute_tensor(
if device_type != tensor.device.type and not tensor.is_meta:
tensor = tensor.to(device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(device_mesh.ndim)]
if len(placements) != device_mesh.ndim:
raise ValueError(
f"`placements` must have the same length as `device_mesh.ndim`! "
f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
)
if isinstance(tensor, DTensor):
# if the tensor is already a DTensor, we need to check:
# 1. if the we can further shard this DTensor if the two device mesh belong to
@ -763,63 +1149,89 @@ def distribute_tensor(
f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
f"to a different device mesh {device_mesh}."
)
if tensor.placements != tuple(placements):
# TODO(zpcore): make sure the shard_order also matches.
if tensor.placements != placements_tuple:
raise ValueError(
f"Cannot distribute a DTensor with placements {tensor.placements} "
f"to a different placements {placements}. do you want to call "
f"to a different placements {placements_tuple}. do you want to call "
f"`redistribute` instead?"
)
return tensor
local_tensor = tensor.detach()
# TODO(xilun): address sharding order
# distribute the tensor according to the placements.
placements = list(placements)
for idx, placement in enumerate(placements):
if isinstance(placement, Shard):
placement_dim = (
placement.dim + tensor.ndim if placement.dim < 0 else placement.dim
)
if isinstance(placement, _StridedShard):
local_tensor = _StridedShard._make_shard_tensor(
placement_dim,
local_tensor,
device_mesh,
idx,
src_data_rank,
split_factor=placement.split_factor,
)
placements[idx] = _StridedShard(
placement_dim, split_factor=placement.split_factor
use_strided_shard = placements is not None and any(
isinstance(p, _StridedShard) for p in placements
)
if use_strided_shard:
# keep original code for backward compatibility considering
# _StridedShard case
assert shard_order is None, "shard_order conflicts with _StridedShard"
for mesh_dim, placement in enumerate(placements_tuple):
if isinstance(placement, Shard):
if isinstance(placement, _StridedShard):
local_tensor = _StridedShard._make_shard_tensor(
placement.dim,
local_tensor,
device_mesh,
mesh_dim,
src_data_rank,
split_factor=placement.split_factor,
)
else:
local_tensor = Shard._make_shard_tensor(
placement.dim,
local_tensor,
device_mesh,
mesh_dim,
src_data_rank,
)
elif isinstance(placement, Replicate):
local_tensor = Replicate._make_replicate_tensor(
local_tensor, device_mesh, mesh_dim, src_data_rank
)
else:
local_tensor = Shard._make_shard_tensor(
placement_dim, local_tensor, device_mesh, idx, src_data_rank
raise RuntimeError(
f"Trying to distribute tensor with unsupported placements {placement} "
f"on device mesh dimension {mesh_dim}!"
)
else:
replicate_on_mesh_dims = set(range(device_mesh.ndim))
for entry in shard_order_tuple:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
assert isinstance(mesh_dim, int)
replicate_on_mesh_dims.remove(mesh_dim)
local_tensor = Shard._make_shard_tensor(
tensor_dim, local_tensor, device_mesh, mesh_dim, src_data_rank
)
for mesh_dim in replicate_on_mesh_dims:
if placements_tuple[mesh_dim].is_replicate():
local_tensor = Replicate._make_replicate_tensor(
local_tensor, device_mesh, mesh_dim, src_data_rank
)
else:
raise RuntimeError(
f"Trying to distribute tensor with unsupported placements "
f"{placements_tuple[mesh_dim]} on device mesh dimension {mesh_dim}!"
)
placements[idx] = Shard(placement_dim)
elif isinstance(placement, Replicate):
local_tensor = Replicate._make_replicate_tensor(
local_tensor, device_mesh, idx, src_data_rank
)
else:
raise RuntimeError(
f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
)
placements = tuple(placements)
assert local_tensor is not None, "distributing a tensor should not be None"
# detach the local tensor passed to DTensor since after the construction
# of DTensor, autograd would work on top of DTensor instead of local tensor
spec = DTensorSpec(
mesh=device_mesh,
placements=placements,
placements=placements_tuple,
shard_order=shard_order_tuple,
tensor_meta=TensorMeta(
shape=tensor.size(),
stride=tensor.stride(),
dtype=tensor.dtype,
),
)
return DTensor(
local_tensor.requires_grad_(tensor.requires_grad),
spec,
@ -858,7 +1270,7 @@ def _shard_tensor(
Examples:
>>> # xdoctest: +SKIP("need world_size and rank")
>>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
>>> device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
>>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
>>> dtensor = _shard_tensor(full_tensor, [Shard(1)], device_mesh)
"""
@ -1084,6 +1496,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
)
# TODO(zpcore): support `shard_order` argument for those factory ops
def ones( # type: ignore[no-untyped-def]
*size,
dtype: Optional[torch.dtype] = None,
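
To make the canonicalization performed by the new `_prepare_placements_and_shard_order` helper concrete, here is a minimal sketch of calling it directly. It is an internal helper; the import path `torch.distributed.tensor._api` and the mesh shape are assumptions for illustration, and the expected outputs in the comments follow from the implementation shown in the diff above.

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor._api import _prepare_placements_and_shard_order

# Assumes torch.distributed is initialized with 4 ranks.
mesh = init_device_mesh("cuda", (2, 2))

placements, shard_order = _prepare_placements_and_shard_order(
    mesh,
    tensor_rank=2,
    placements=[Shard(0), Shard(0)],
    shard_order={0: [1, 0]},  # sparse spec: customizes only tensor dim 0
)
print(placements)   # e.g. (Shard(dim=0), Shard(dim=0))
print(shard_order)  # e.g. (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),)
```

The sparse `{0: [1, 0]}` spec is expanded into a full `ShardOrder` tuple and the placements are normalized into a tuple, which is the canonical form consumed by `DTensorSpec`.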

View File

@ -845,6 +845,7 @@ class Redistribute(torch.autograd.Function):
input: "dtensor.DTensor",
device_mesh: DeviceMesh,
placements: tuple[Placement, ...],
shard_order: Optional[ShardOrder] = None,
async_op: bool = False,
forward_dtype: Optional[torch.dtype] = None,
backward_dtype: Optional[torch.dtype] = None,
@ -852,7 +853,6 @@ class Redistribute(torch.autograd.Function):
ctx.async_op = async_op
ctx.backward_dtype = backward_dtype
ctx.original_dtype = input._local_tensor.dtype
if forward_dtype is not None and forward_dtype != input._local_tensor.dtype:
local_tensor = input._local_tensor.to(dtype=forward_dtype)
current_spec = DTensorSpec(
@ -863,6 +863,7 @@ class Redistribute(torch.autograd.Function):
stride=input.stride(),
dtype=forward_dtype,
),
shard_order=input._spec.shard_order,
)
else:
local_tensor = input._local_tensor
@ -870,11 +871,22 @@ class Redistribute(torch.autograd.Function):
ctx.current_spec = current_spec
if current_spec.placements != placements:
target_spec = DTensorSpec(
device_mesh, placements, tensor_meta=current_spec.tensor_meta
)
shard_order = (
DTensorSpec.compute_default_shard_order(placements)
if shard_order is None
else shard_order
)
if (
current_spec.placements != placements
or current_spec.shard_order != shard_order
):
target_spec = DTensorSpec(
device_mesh,
placements,
tensor_meta=current_spec.tensor_meta,
shard_order=shard_order,
)
output = redistribute_local_tensor(
local_tensor, current_spec, target_spec, async_op=async_op
)
@ -926,6 +938,10 @@ class Redistribute(torch.autograd.Function):
if output.dtype != ctx.original_dtype:
output = output.to(ctx.original_dtype)
# TODO(zpcore): During backward, some Partial related transform ops got
# silently skipped. This will be an issue for the graph-based
# redistribute planner. Need fix.
# normalize the target placement to replicate if it is partial
normalized_placements: list[Placement] = []
for previous_placement in previous_spec.placements:
@ -943,6 +959,8 @@ class Redistribute(torch.autograd.Function):
stride=grad_output.stride(),
dtype=output.dtype,
),
# this is subject to be wrong if we skip Partial() transform
shard_order=previous_spec.shard_order,
)
output_dtensor = dtensor.DTensor(
output,
@ -957,4 +975,5 @@ class Redistribute(torch.autograd.Function):
None,
None,
None,
None,
)