Compare commits

...

3 Commits

Author SHA1 Message Date
61a2aab4bd fix bug 2025-11-13 17:40:48 -08:00
d7992eb16e why we need this fix? Doesn't make sense 2025-11-13 14:24:36 -08:00
9dd965c1ec [DTensor] Fix silently skipping Partial related redistribution in backward 2025-11-13 14:24:35 -08:00
6 changed files with 326 additions and 117 deletions

View File

@@ -4,6 +4,7 @@
import contextlib
import itertools
import unittest
from unittest.mock import patch
import torch
from torch.distributed._local_tensor import (
@@ -21,6 +22,7 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import DTensorRedistributePlanner
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
from torch.testing._internal.common_utils import (
@@ -1063,6 +1065,110 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
)
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
@with_comms
@patch.object(DTensorRedistributePlanner, "find_min_cost_path")
def test_with_partial_in_backward(self, mock_find_path):
"""
Test with backward redistribution path contains Partial().
"""
# generate ground truth
input_tensor = torch.randn(8, 8, requires_grad=True)
input_tensor_orig = input_tensor.clone().detach().requires_grad_(True)
loss = input_tensor_orig.sum()
loss.backward()
forward_path = [
DTensorRedistributePlanner.DistState(
placements=(Replicate(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
DTensorRedistributePlanner.DistState(
placements=(Shard(0), Replicate()),
tensor_dim_to_mesh_dim=(),
),
]
backward_path = [
DTensorRedistributePlanner.DistState(
placements=(Replicate(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
# Note: The current DTensor implementation silently skips the
# transition from Replicate to Partial (R->P) during the backward
# pass. This can lead to numerical correctness issues if any
# operations are performed on that mesh dimension after the
# Partial() conversion. If we change the Partial() to something like
# Shard(0), the issue will be resolved. This is not an issue for the
# greedy solution, because it won't generate a path like P->R, but
# it may happen in graph-based redistribution.
DTensorRedistributePlanner.DistState(
placements=(Partial(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
DTensorRedistributePlanner.DistState(
placements=(Replicate(), Replicate()),
tensor_dim_to_mesh_dim=(),
),
]
# set side_effect to a list: the first call returns the first item, the second call the second
mock_find_path.side_effect = [forward_path, backward_path]
import torch.distributed.tensor._redistribute as redistribute_module
original_redistribute = redistribute_module.redistribute_local_tensor
def force_graph_based(*args, **kwargs):
kwargs["use_graph_based_transform"] = True
return original_redistribute(*args, **kwargs)
def disable_graph_based(*args, **kwargs):
kwargs["use_graph_based_transform"] = False
return original_redistribute(*args, **kwargs)
device_mesh = init_device_mesh(self.device_type, (4, 2))
# disable the graph based path finding in `distribute_tensor`
with patch.object(
redistribute_module,
"redistribute_local_tensor",
side_effect=disable_graph_based,
):
# with disable_graph_based patched in, mock_find_path.side_effect is not consumed
dtensor = distribute_tensor(
input_tensor, device_mesh, [Replicate(), Replicate()]
)
assert mock_find_path.call_count == 0
# enable the graph based path finding in `distribute_tensor`, so that
# mock_find_path.side_effect will be used
with patch.object(
redistribute_module,
"redistribute_local_tensor",
side_effect=force_graph_based,
):
with DebugMode() as debug_mode:
dtensor_sharded = dtensor.redistribute(
device_mesh,
[Shard(0), Replicate()],
)
loss = dtensor_sharded.sum()
assert type(loss) is DTensor
# loss.placements is supposed to be [Partial(), Replicate()], but at
# some point it silently gets updated to [Replicate(), Replicate()].
loss.backward()
# verify forward and backward paths in mock_find_path.side_effect were used
assert mock_find_path.call_count == 2, (
f"Run {mock_find_path.call_count} calls to find_min_cost_path"
)
trace_str = self._extract_redistribute_trace_from_debug_mode(
debug_mode.debug_string()
)
# Partial should be removed after adjustment
self.assertTrue("P" not in trace_str)
self.assertEqual(input_tensor_orig.grad, make_full_tensor(dtensor.grad))
RedistributeTestWithLocalTensor = create_local_tensor_test_class(
RedistributeTest,
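
Note on the mocking used above: unittest.mock consumes a list-valued side_effect one element per call, which is how the test hands the planner the forward path on its first invocation and the backward path on its second. A minimal standalone sketch of that mechanism (the Planner class below is a toy stand-in, not the DTensor planner):

from unittest.mock import patch

class Planner:
    def find_min_cost_path(self, src, dst):
        raise NotImplementedError  # real search elided

with patch.object(Planner, "find_min_cost_path") as mock_find_path:
    # each call pops the next element from side_effect
    mock_find_path.side_effect = ["forward_path", "backward_path"]
    planner = Planner()
    assert planner.find_min_cost_path(None, None) == "forward_path"
    assert planner.find_min_cost_path(None, None) == "backward_path"
    assert mock_find_path.call_count == 2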

View File

@@ -206,7 +206,7 @@ class _FromTorchTensor(torch.autograd.Function):
tensor_meta=grad_output._spec.tensor_meta,
)
local_tensor = grad_output._local_tensor
output = redistribute_local_tensor(
output, _ = redistribute_local_tensor(
local_tensor, current_spec, target_spec, is_backward=True
)
# TODO: return the redistributed local tensor directly without
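
Since redistribute_local_tensor now returns a (tensor, end_placements) tuple (see the _redistribute.py hunk further down), call sites that only need the tensor unpack and discard the second element, as above. A small sketch of the two calling patterns, assuming local_tensor, current_spec, and target_spec are already in scope:

# Caller that only needs the redistributed local tensor:
new_local, _ = redistribute_local_tensor(
    local_tensor, current_spec, target_spec, is_backward=True
)

# Caller that also wants the placements actually produced: they can
# differ from target_spec.placements when transforms were adjusted,
# and are None when this rank is not part of the mesh.
new_local, end_placements = redistribute_local_tensor(
    local_tensor, current_spec, target_spec
)
if end_placements is not None and tuple(end_placements) != tuple(target_spec.placements):
    ...  # e.g. build the output spec from end_placements instead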

View File

@@ -417,7 +417,7 @@ class OpDispatcher:
f"Implicit redistribution occurred for {op_info.schema} while ExplicitRedistributionContext was active"
)
with redistribute_context:
resharded_local_tensor = redistribute_local_tensor(
resharded_local_tensor, _ = redistribute_local_tensor(
local_tensor,
arg_spec,
# pyrefly: ignore [bad-argument-type]

View File

@@ -685,6 +685,166 @@ def _gen_transform_infos(
)
def _maybe_adjust_transform_info_inplace(
src_placement: Sequence[Placement],
device_mesh: DeviceMesh,
transform_infos: list[_TransformInfo],
ignore_partial_to_replicate: bool = False,
) -> list[Placement]:
"""
Post-process and optimize transformation infos by applying adjustment rules.
This function modifies the transform_infos list in-place to optimize the redistribution
plan by skipping unnecessary transformations and preventing invalid operations. It applies
several adjustment rules to filter out redundant or problematic transformations while
tracking the updated placements after each transformation step.
Args:
src_placement: Initial sequence of placements before any transformations.
device_mesh: The device mesh used for the redistribution.
transform_infos: List of transformation steps to be adjusted in-place. After this
function returns, this list will contain only the necessary transformations.
ignore_partial_to_replicate: If True (typically in backward pass), prevents
Replicate -> Partial transformations to keep gradients replicated rather than
converting to partial. This optimization avoids unnecessary conversions since
gradients would need to be reduced later anyway. Default is False.
Returns:
A list of final placements after applying all transformations and adjustments.
This represents the actual ending state, which may differ from the original
target placements due to optimizations.
"""
new_transforms_infos = []
updated_placements: list[Placement] = list(src_placement[:])
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
# placements[i] may have been adjusted earlier; use the updated value instead of the original current placement
current = updated_placements[i]
assert current is not None
num_chunks = device_mesh.size(mesh_dim=i)
# adjust rule 1: skip when current and target placement are the same
if current == target:
continue
# adjust rule 2: skip when not sharded in this dim
if num_chunks == 1:
# short cut, if there's only one shard, we don't need to do any collective
# comm, just use the original local tensor
updated_placements[i] = target
continue
if ignore_partial_to_replicate:
# Skip the Replicate -> Partial transformation when we are in the
# backward pass, and keep the grad replicated instead. Although
# converting the replicated gradients back to partial logically
# conforms to the same layout, it is useless: you would have to
# reduce them again later, which is more expensive than keeping
# them replicated.
# adjust rule 3: no Replicate -> Partial
if target.is_partial() and current.is_replicate():
updated_placements[i] = Replicate()
continue
# adjust rule 4: no Shard -> Partial, redistribute with Replicate()
if target.is_partial() and current.is_shard():
updated_placements[i] = Replicate()
new_transforms_infos.append(
_TransformInfo(
mesh_dim=i,
src_dst_placements=(current, Replicate()),
logical_shape=transform_info.logical_shape,
)
)
continue
else:
# adjust rule 5: safe check, we should not see Shard -> Partial in forward
if target.is_partial() and current.is_shard():
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
updated_placements[i] = target
new_transforms_infos.append(transform_info)
transform_infos[:] = new_transforms_infos
return updated_placements
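
To make the adjustment rules concrete, here is a simplified, self-contained rendering of rules 1 and 3 using plain string placements; the real function works on Placement and _TransformInfo objects and additionally handles the single-chunk shortcut and the Shard -> Partial rewrites (rules 2, 4, and 5):

# Toy model: placements are strings ("R", "P", "S0"); a transform is
# a (mesh_dim, src, dst) tuple, adjusted in place like transform_infos.
def adjust(src_placements, transforms, ignore_partial_to_replicate=False):
    placements = list(src_placements)
    kept = []
    for dim, _, dst in transforms:
        cur = placements[dim]          # use the possibly adjusted source
        if cur == dst:                 # rule 1: nothing to do
            continue
        if ignore_partial_to_replicate and dst == "P" and cur == "R":
            placements[dim] = "R"      # rule 3: keep gradients replicated
            continue
        placements[dim] = dst
        kept.append((dim, cur, dst))
    transforms[:] = kept
    return placements

# Backward-style call mirroring the test above: the R -> P step on mesh
# dim 0 is dropped, the following P -> R step then becomes a no-op, and
# the end placements stay ["R", "R"].
steps = [(0, "R", "P"), (0, "P", "R")]
print(adjust(["R", "R"], steps, ignore_partial_to_replicate=True))  # ['R', 'R']
print(steps)  # []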
def _execute_transform(
local_tensor: torch.Tensor,
device_mesh: DeviceMesh,
transform_infos: list[_TransformInfo],
async_op: bool,
) -> torch.Tensor:
my_coordinate = device_mesh.get_coordinate()
assert my_coordinate is not None
new_local_tensor = local_tensor
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
new_local_tensor = partial_spec._partition_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
return new_local_tensor
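
For orientation, the case analysis in _execute_transform maps each (current, target) pair on a mesh dimension to one local or collective operation; roughly (the collective named in parentheses is the usual implementation behind each helper):

# (current -> target) dispatch per mesh dimension, as implemented above:
#   Partial   -> Replicate : _reduce_value          (all-reduce)
#   Shard     -> Replicate : _to_replicate_tensor   (all-gather)
#   Partial   -> Shard     : _reduce_shard_value    (reduce-scatter)
#   Replicate -> Shard     : _replicate_to_shard    (local chunk, no comm)
#   Shard(i)  -> Shard(j)  : _to_new_shard_dim      (all-to-all)
#   Replicate -> Partial   : _partition_value       (local)
#   Shard     -> Partial   : unsupported, raises RuntimeError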
def redistribute_local_tensor(
local_tensor: torch.Tensor,
current_spec: DTensorSpec,
@@ -693,11 +853,36 @@ def redistribute_local_tensor(
async_op: bool = False,
is_backward: bool = False,
use_graph_based_transform: Optional[bool] = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[list[Placement]]]:
"""
This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to
the target DTensorSpec, which involves the necessary collective calls to transform
the local shard of the DTensor from its current spec to the target spec.
Redistribute a local tensor from current DTensorSpec to target DTensorSpec.
This function transforms the local shard of a DTensor by performing the necessary
collective communications to change from current placements to target placements.
The redistribution may involve multiple transformation steps such as all-gather,
reduce-scatter, all-to-all, and all-reduce operations.
Args:
local_tensor: The local tensor shard to redistribute.
current_spec: The current DTensorSpec describing the local tensor's distribution.
target_spec: The target DTensorSpec describing the desired distribution.
async_op: If True, collective operations will be asynchronous. Default is False.
is_backward: If True, indicates this redistribution is happening in the backward
pass. This affects how certain transformations are handled (e.g., skipping
Replicate -> Partial transformations). Default is False.
use_graph_based_transform: If True, uses graph-based search algorithm to find
optimal redistribution path. If False, uses greedy algorithm. If None,
automatically determined based on shard order. Default is None.
Returns:
A tuple containing:
- The redistributed local tensor after applying transformations.
- A list of final placements after redistribution (may differ from
target_spec.placements due to optimization adjustments), or None
if the current rank is not part of the device mesh.
Raises:
NotImplementedError: If current_spec and target_spec have different device meshes.
"""
if current_spec.mesh != target_spec.mesh:
@@ -712,7 +897,7 @@ def redistribute_local_tensor(
if my_coordinate is None:
# if rank is not part of mesh, we skip redistribute and simply return local_tensor,
# which should be an empty tensor
return local_tensor
return local_tensor, None
if _are_we_tracing():
transform_infos = _gen_transform_infos_non_cached(
@@ -723,12 +908,19 @@
current_spec, target_spec, use_graph_based_transform
)
end_placements = _maybe_adjust_transform_info_inplace(
current_spec.placements,
device_mesh,
transform_infos,
ignore_partial_to_replicate=is_backward,
)
debug_mode = get_active_debug_mode()
redistribute_context = (
debug_mode.record_redistribute_calls( # type: ignore[union-attr]
local_tensor,
current_spec.placements,
target_spec.placements,
end_placements,  # not target_spec.placements, because the placements may have been adjusted
DTensorRedistributePlanner.stringify_transform_infos(
device_mesh,
transform_infos,
@@ -741,100 +933,10 @@
)
with redistribute_context:
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
num_chunks = device_mesh.size(mesh_dim=i)
if current == target:
# short cut, just use the original local tensor
new_local_tensor = local_tensor
continue
if num_chunks == 1:
# short cut, if there's only one shard, we don't need to do any collective
# comm, just use the original local tensor
new_local_tensor = local_tensor
continue
if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
# skip the replicate to partial transformation when we are in backward pass
# In this case we keep the grad as replicate, this is because we don't
# want to convert the replicated gradients back to partial, although
# that's logically conform with the same layout, converting the gradients
# back to partial is actually useless as you would have to do reduce later
# which would be more expensive than keeping it replicate! For this reason,
# we keep the replicate grad here.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
if not async_op and isinstance(
new_local_tensor, funcol.AsyncCollectiveTensor
):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
return new_local_tensor
new_local_tensor = _execute_transform(
local_tensor, device_mesh, transform_infos, async_op
)
return new_local_tensor, end_placements
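
Taken together, the rewritten function now runs in three phases: plan the transforms, adjust them in place (which yields the true end placements), and execute them. A condensed sketch of that control flow, omitting tracing, DebugMode recording, and error handling:

def redistribute_local_tensor_sketch(local_tensor, current_spec, target_spec,
                                     async_op=False, is_backward=False,
                                     use_graph_based_transform=None):
    device_mesh = current_spec.mesh
    if device_mesh.get_coordinate() is None:
        # rank not in mesh: nothing to do, no placements to report
        return local_tensor, None
    transform_infos = _gen_transform_infos(
        current_spec, target_spec, use_graph_based_transform
    )
    end_placements = _maybe_adjust_transform_info_inplace(
        current_spec.placements, device_mesh, transform_infos,
        ignore_partial_to_replicate=is_backward,
    )
    new_local_tensor = _execute_transform(
        local_tensor, device_mesh, transform_infos, async_op
    )
    return new_local_tensor, end_placements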
class Redistribute(torch.autograd.Function):
@@ -875,9 +977,18 @@ class Redistribute(torch.autograd.Function):
device_mesh, placements, tensor_meta=current_spec.tensor_meta
)
output = redistribute_local_tensor(
output, end_placements = redistribute_local_tensor(
local_tensor, current_spec, target_spec, async_op=async_op
)
if end_placements is not None and tuple(end_placements) != tuple(
target_spec.placements
):
raise RuntimeError(
f"End placements {end_placements} do not match target placements "
f"{target_spec.placements} in Redistribute forward. "
"This can be caused by _TransformInfo used in redistribution "
"be incorrectly adjusted."
)
else:
# use the same local tensor if placements are the same.
output = local_tensor
@@ -918,29 +1029,20 @@ class Redistribute(torch.autograd.Function):
local_tensor = grad_output._local_tensor
current_spec = grad_output._spec
output = redistribute_local_tensor(
output, end_placements = redistribute_local_tensor(
local_tensor,
current_spec,
previous_spec,
async_op=async_op,
is_backward=True,
)
if output.dtype != ctx.original_dtype:
output = output.to(ctx.original_dtype)
# normalize the target placement to replicate if it is partial
normalized_placements: list[Placement] = []
for previous_placement in previous_spec.placements:
if previous_placement.is_partial():
# keep target placement to replicate instead of partial in this case
normalized_placements.append(Replicate())
else:
normalized_placements.append(previous_placement)
assert end_placements is not None
spec = DTensorSpec(
previous_spec.device_mesh,
tuple(normalized_placements),
tuple(end_placements),
tensor_meta=TensorMeta(
shape=grad_output.shape,
stride=grad_output.stride(),

View File

@@ -452,11 +452,12 @@ def _insert_reshard_gm(
# insert reshard operation
def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
return redistribute_local_tensor(
new_tensor, _ = redistribute_local_tensor(
local_tensor,
input_arg_spec,
desired_spec,
)
return new_tensor
reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
reshard_gm_nodes = list(reshard_gm.graph.nodes)

View File

@@ -861,7 +861,7 @@ def redistribute(
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
)[0],
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)