Record redistribute_local_tensor in DebugMode (#163704)

Explicit redistribute_local_tensor API call could also results in communication, record it!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163704
Approved by: https://github.com/ezyang
This commit is contained in:
Sherlock Huang
2025-09-24 16:11:22 +00:00
committed by PyTorch MergeBot
parent 5d0f639234
commit 4c2c401ccf
5 changed files with 143 additions and 111 deletions

View File

@ -50,8 +50,9 @@ class TestDTensorDebugMode(TestCase):
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
redistribute_input(1, [S(0)] -> [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
redistribute_input(t: f32[1, 32], [S(0)] -> [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32][S(0)])
aten::sum(dt: f32[8, 32][S(0)])
@ -90,7 +91,8 @@ class TestDTensorDebugMode(TestCase):
<method 'add' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
aten::add.Tensor(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
redistribute_input(1, [S(1)] -> [S(0)])
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
redistribute_input(t: f32[8, 1], [S(1)] -> [S(0)])
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
aten::add.Tensor(t: f32[1, 8], t: f32[1, 8])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)])
aten::sum(dt: f32[8, 8][S(0)])
@ -100,12 +102,14 @@ class TestDTensorDebugMode(TestCase):
aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
aten::expand(dt: f32[][R], [8, 8])
aten::expand(t: f32[], [8, 8])
aten::split.Tensor(t: f32[8, 8], 1, 1)
aten::clone(t: f32[8, 1])
redistribute_input(t: f32[8, 8], [R] -> [S(1)])
aten::split.Tensor(t: f32[8, 8], 1, 1)
aten::clone(t: f32[8, 1])
aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[8, 1])
aten::split.Tensor(t: f32[8, 8], 1)
aten::clone(t: f32[1, 8])
redistribute_input(t: f32[8, 8], [R] -> [S(0)])
aten::detach(t: f32[8, 1])
aten::split.Tensor(t: f32[8, 8], 1)
aten::clone(t: f32[1, 8])
aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[1, 8])""",
)
@ -150,19 +154,21 @@ class TestDTensorDebugMode(TestCase):
aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
redistribute_input(0, [P, R] -> [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
aten::chunk(t: f32[1, 96, 2], 2, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(t: f32[1, 96, 8], [P, R] -> [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
aten::chunk(t: f32[1, 96, 2], 2, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(1, [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
redistribute_input(t: f32[1, 8, 16], [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])

View File

@ -670,7 +670,7 @@ class TestDTensorOps(DTensorOpTestBase):
.to(DEVICE_TYPE)
)
for is_evenly_shardable in [True]:
for is_evenly_shardable in [True, False]:
if is_evenly_shardable:
placement = [Shard(1)]
reduce_dim = 1
@ -686,9 +686,9 @@ class TestDTensorOps(DTensorOpTestBase):
self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))
if is_evenly_shardable:
self.assertFalse("redistribute_input" in debug_mode.debug_string())
self.assertTrue("[P] -> [R]" in debug_mode.debug_string())
else:
self.assertTrue("redistribute_input" in debug_mode.debug_string())
self.assertTrue("[S(0)] -> [R])" in debug_mode.debug_string())
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)

View File

@ -23,11 +23,8 @@ from torch.distributed.tensor._tp_conv import (
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._debug_mode import DebugMode
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
return_and_correct_aliasing,
)
from torch.utils._debug_mode import get_active_debug_mode
from torch.utils._python_dispatch import return_and_correct_aliasing
try:
@ -338,8 +335,7 @@ class OpDispatcher:
suggested_input_schema: OpSchema,
use_val_from_redistribute_schema: bool,
) -> None:
debug_mode = _get_current_dispatch_mode()
in_debug_mode = isinstance(debug_mode, DebugMode)
debug_mode = get_active_debug_mode()
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
if op_info.args_tree_spec is not None:
@ -359,7 +355,7 @@ class OpDispatcher:
debug_mode.record_redistribute_calls( # type: ignore[union-attr]
i, arg_spec, reshard_arg_spec
)
if in_debug_mode
if debug_mode is not None
else contextlib.nullcontext()
)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import logging
from functools import cache
from typing import cast, NamedTuple, Optional
@ -16,6 +17,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
from torch.utils._debug_mode import get_active_debug_mode
logger = logging.getLogger(__name__)
@ -187,92 +189,106 @@ def redistribute_local_tensor(
else:
transform_infos = _gen_transform_infos(current_spec, target_spec)
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
device_mesh.size(mesh_dim=i)
debug_mode = get_active_debug_mode()
redistribute_context = (
debug_mode.record_redistribute_calls( # type: ignore[union-attr]
local_tensor, current_spec, target_spec
)
if debug_mode is not None
else contextlib.nullcontext()
)
if current == target:
# short cut, just use the original local tensor
new_local_tensor = local_tensor
continue
with redistribute_context:
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
device_mesh.size(mesh_dim=i)
logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)
if current == target:
# short cut, just use the original local tensor
new_local_tensor = local_tensor
continue
if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
logger.debug(
"redistribute from %s to %s on mesh dim %s", current, target, i
)
if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
# skip the replicate to partial transformation when we are in backward pass
# In this case we keep the grad as replicate, this is because we don't
# want to convert the replicated gradients back to partial, although
# that's logically conform with the same layout, converting the gradients
# back to partial is actually useless as you would have to do reduce later
# which would be more expensive than keeping it replicate! For this reason,
# we keep the replicate grad here.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
# skip the replicate to partial transformation when we are in backward pass
# In this case we keep the grad as replicate, this is because we don't
# want to convert the replicated gradients back to partial, although
# that's logically conform with the same layout, converting the gradients
# back to partial is actually useless as you would have to do reduce later
# which would be more expensive than keeping it replicate! For this reason,
# we keep the replicate grad here.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
if not async_op and isinstance(
new_local_tensor, funcol.AsyncCollectiveTensor
):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
return new_local_tensor

View File

@ -1,14 +1,19 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Optional
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_get_current_dispatch_mode_stack,
TorchDispatchMode,
)
from torch.utils._pytree import tree_map
__all__ = ["DebugMode"]
__all__ = ["DebugMode", "get_active_debug_mode"]
REDISTRIBUTE_FUNC = "redistribute_input"
@ -168,3 +173,12 @@ class DebugMode(TorchDispatchMode):
for op, args, kwargs, depth in self.operators
)
return result
def get_active_debug_mode() -> Optional[DebugMode]:
debug_mode = None
for mode in _get_current_dispatch_mode_stack():
if isinstance(mode, DebugMode):
debug_mode = mode
break
return debug_mode