Record redistribute_local_tensor in DebugMode (#163704)
An explicit redistribute_local_tensor API call can also result in communication, so record it!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163704
Approved by: https://github.com/ezyang
committed by: PyTorch MergeBot
parent: 5d0f639234
commit: 4c2c401ccf
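For context, a minimal usage sketch (not part of the commit) of what this change makes visible: an explicit Shard -> Replicate redistribute issues an all-gather, and DebugMode now records it alongside the underlying collectives. The mesh size, device type, and tensor shape below are illustrative assumptions, and an initialized process group is assumed.

# Hypothetical sketch; assumes an initialized 2-rank process group.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torch.utils._debug_mode import DebugMode

mesh = init_device_mesh("cpu", (2,))
x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

with DebugMode() as debug_mode:
    # Explicit redistribute: Shard(0) -> Replicate triggers an all-gather,
    # which this PR records in the debug trace.
    y = x.redistribute(mesh, [Replicate()])

# Expect an entry like "redistribute_input(t: ..., [S(0)] -> [R])" followed by
# the _c10d_functional collectives, similar to the test expectations below.
print(debug_mode.debug_string())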
@@ -50,8 +50,9 @@ class TestDTensorDebugMode(TestCase):
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
redistribute_input(1, [S(0)] -> [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
redistribute_input(t: f32[1, 32], [S(0)] -> [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32][S(0)])
aten::sum(dt: f32[8, 32][S(0)])
@@ -90,7 +91,8 @@ class TestDTensorDebugMode(TestCase):
<method 'add' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
aten::add.Tensor(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
redistribute_input(1, [S(1)] -> [S(0)])
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
redistribute_input(t: f32[8, 1], [S(1)] -> [S(0)])
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
aten::add.Tensor(t: f32[1, 8], t: f32[1, 8])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)])
aten::sum(dt: f32[8, 8][S(0)])
@@ -100,12 +102,14 @@ class TestDTensorDebugMode(TestCase):
aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
aten::expand(dt: f32[][R], [8, 8])
aten::expand(t: f32[], [8, 8])
aten::split.Tensor(t: f32[8, 8], 1, 1)
aten::clone(t: f32[8, 1])
redistribute_input(t: f32[8, 8], [R] -> [S(1)])
aten::split.Tensor(t: f32[8, 8], 1, 1)
aten::clone(t: f32[8, 1])
aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[8, 1])
aten::split.Tensor(t: f32[8, 8], 1)
aten::clone(t: f32[1, 8])
redistribute_input(t: f32[8, 8], [R] -> [S(0)])
aten::detach(t: f32[8, 1])
aten::split.Tensor(t: f32[8, 8], 1)
aten::clone(t: f32[1, 8])
aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[1, 8])""",
)
@@ -150,19 +154,21 @@ class TestDTensorDebugMode(TestCase):
aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
redistribute_input(0, [P, R] -> [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
aten::chunk(t: f32[1, 96, 2], 2, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(t: f32[1, 96, 8], [P, R] -> [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
aten::chunk(t: f32[1, 96, 2], 2, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(1, [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
redistribute_input(t: f32[1, 8, 16], [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])

@@ -670,7 +670,7 @@ class TestDTensorOps(DTensorOpTestBase):
.to(DEVICE_TYPE)
)

for is_evenly_shardable in [True]:
for is_evenly_shardable in [True, False]:
if is_evenly_shardable:
placement = [Shard(1)]
reduce_dim = 1
@@ -686,9 +686,9 @@ class TestDTensorOps(DTensorOpTestBase):
self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))

if is_evenly_shardable:
self.assertFalse("redistribute_input" in debug_mode.debug_string())
self.assertTrue("[P] -> [R]" in debug_mode.debug_string())
else:
self.assertTrue("redistribute_input" in debug_mode.debug_string())
self.assertTrue("[S(0)] -> [R])" in debug_mode.debug_string())

# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
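The assertions above inspect DebugMode's trace for redistribute records around a sharded mean. A rough standalone sketch of the same pattern follows; it is not the test's exact setup (mesh size, shapes, and the use of full_tensor() are assumptions):

# Hypothetical sketch; assumes an initialized 2-rank process group.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard
from torch.utils._debug_mode import DebugMode

mesh = init_device_mesh("cpu", (2,))
tensor = torch.randn(8, 8)
dtensor = distribute_tensor(tensor, mesh, [Shard(1)])  # evenly shardable on dim 1

with DebugMode() as debug_mode:
    # Reducing over the sharded dim produces a Partial result, which is then
    # materialized on every rank.
    full = dtensor.mean(dim=1).full_tensor()

debug_string = debug_mode.debug_string()
# The test above checks these substrings: with even sharding it expects no
# "redistribute_input" record but a "[P] -> [R]" transition; with uneven
# sharding it expects a "redistribute_input" record instead.
print(debug_string)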
@@ -23,11 +23,8 @@ from torch.distributed.tensor._tp_conv import (
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._debug_mode import DebugMode
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
return_and_correct_aliasing,
)
from torch.utils._debug_mode import get_active_debug_mode
from torch.utils._python_dispatch import return_and_correct_aliasing


try:
@@ -338,8 +335,7 @@ class OpDispatcher:
suggested_input_schema: OpSchema,
use_val_from_redistribute_schema: bool,
) -> None:
debug_mode = _get_current_dispatch_mode()
in_debug_mode = isinstance(debug_mode, DebugMode)
debug_mode = get_active_debug_mode()

# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
if op_info.args_tree_spec is not None:
@@ -359,7 +355,7 @@ class OpDispatcher:
debug_mode.record_redistribute_calls( # type: ignore[union-attr]
i, arg_spec, reshard_arg_spec
)
if in_debug_mode
if debug_mode is not None
else contextlib.nullcontext()
)

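To summarize the dispatcher change above: instead of grabbing the innermost dispatch mode and checking isinstance, the code now asks get_active_debug_mode() and only enters the recording context when a DebugMode is active. A condensed sketch of that pattern; the standalone helper name _maybe_record_redistribute is illustrative, while in the real code this sits inline in OpDispatcher:

# Illustrative helper (not in the PR); mirrors the inline pattern in OpDispatcher.
import contextlib

from torch.utils._debug_mode import get_active_debug_mode

def _maybe_record_redistribute(i, arg_spec, reshard_arg_spec):
    # Only record when a DebugMode is active anywhere on the dispatch-mode
    # stack; otherwise return a no-op context manager.
    debug_mode = get_active_debug_mode()
    return (
        debug_mode.record_redistribute_calls(i, arg_spec, reshard_arg_spec)
        if debug_mode is not None
        else contextlib.nullcontext()
    )

# Usage (sketch):
#     with _maybe_record_redistribute(i, arg_spec, reshard_arg_spec):
#         ...perform the actual per-argument redistribution...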
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import logging
from functools import cache
from typing import cast, NamedTuple, Optional
@@ -16,6 +17,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
from torch.utils._debug_mode import get_active_debug_mode


logger = logging.getLogger(__name__)
@@ -187,92 +189,106 @@ def redistribute_local_tensor(
else:
transform_infos = _gen_transform_infos(current_spec, target_spec)

for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
device_mesh.size(mesh_dim=i)
debug_mode = get_active_debug_mode()
redistribute_context = (
debug_mode.record_redistribute_calls( # type: ignore[union-attr]
local_tensor, current_spec, target_spec
)
if debug_mode is not None
else contextlib.nullcontext()
)

if current == target:
# short cut, just use the original local tensor
new_local_tensor = local_tensor
continue
with redistribute_context:
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
device_mesh.size(mesh_dim=i)

logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)
if current == target:
# short cut, just use the original local tensor
new_local_tensor = local_tensor
continue

if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
logger.debug(
"redistribute from %s to %s on mesh dim %s", current, target, i
)

if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
# skip the replicate to partial transformation when we are in backward pass
# In this case we keep the grad as replicate, this is because we don't
# want to convert the replicated gradients back to partial, although
# that's logically conform with the same layout, converting the gradients
# back to partial is actually useless as you would have to do reduce later
# which would be more expensive than keeping it replicate! For this reason,
# we keep the replicate grad here.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert current.is_shard(), (
f"Current placement should be shard but found {current}"
)
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
# skip the replicate to partial transformation when we are in backward pass
# In this case we keep the grad as replicate, this is because we don't
# want to convert the replicated gradients back to partial, although
# that's logically conform with the same layout, converting the gradients
# back to partial is actually useless as you would have to do reduce later
# which would be more expensive than keeping it replicate! For this reason,
# we keep the replicate grad here.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor

if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
if not async_op and isinstance(
new_local_tensor, funcol.AsyncCollectiveTensor
):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
return new_local_tensor

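The net effect of the hunk above: redistribute_local_tensor now fetches the active DebugMode once, builds a recording context (or a null context), and runs the entire per-mesh-dim transform loop inside it, so explicit calls surface as a redistribute_input record. A stripped-down sketch of that control flow; the names apply_transform and _redistribute_with_recording are placeholders, not APIs from the PR:

# Illustrative sketch only; `apply_transform` stands in for the
# Replicate/Shard/Partial branches shown in the diff.
import contextlib

from torch.utils._debug_mode import get_active_debug_mode

def _redistribute_with_recording(local_tensor, current_spec, target_spec,
                                 transform_infos, apply_transform):
    debug_mode = get_active_debug_mode()
    redistribute_context = (
        debug_mode.record_redistribute_calls(local_tensor, current_spec, target_spec)
        if debug_mode is not None
        else contextlib.nullcontext()
    )

    new_local_tensor = local_tensor
    with redistribute_context:
        for transform_info in transform_infos:
            current, target = transform_info.src_dst_placements
            if current == target:
                continue  # nothing to do on this mesh dim
            # placeholder for the per-placement conversion logic
            new_local_tensor = apply_transform(new_local_tensor, transform_info)
    return new_local_tensor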
@@ -1,14 +1,19 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Optional

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_get_current_dispatch_mode_stack,
TorchDispatchMode,
)
from torch.utils._pytree import tree_map


__all__ = ["DebugMode"]
__all__ = ["DebugMode", "get_active_debug_mode"]

REDISTRIBUTE_FUNC = "redistribute_input"
@@ -168,3 +173,12 @@ class DebugMode(TorchDispatchMode):
for op, args, kwargs, depth in self.operators
)
return result


def get_active_debug_mode() -> Optional[DebugMode]:
debug_mode = None
for mode in _get_current_dispatch_mode_stack():
if isinstance(mode, DebugMode):
debug_mode = mode
break
return debug_mode
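One property of the new helper worth noting: because it walks the full dispatch-mode stack rather than looking only at the innermost mode, a DebugMode is still found when another dispatch mode is nested inside it. A small sketch, using FakeTensorMode purely as an example of another mode (an assumption, not something the PR touches):

# Illustrative sketch of get_active_debug_mode semantics.
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._debug_mode import DebugMode, get_active_debug_mode

assert get_active_debug_mode() is None  # no DebugMode active yet

with DebugMode() as debug_mode:
    assert get_active_debug_mode() is debug_mode

    with FakeTensorMode():
        # FakeTensorMode is now the innermost dispatch mode, but the helper
        # still finds the DebugMode further down the stack.
        assert get_active_debug_mode() is debug_mode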