Compare commits


1 Commit

Author SHA1 Message Date
e16085a217 improve sharding propagation error msg in dtensor dispatch 2025-10-09 14:44:32 -07:00
4 changed files with 46 additions and 11 deletions

View File

@@ -1,13 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

+import re
import unittest
import warnings

import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
-from torch.distributed.tensor import distribute_tensor, DTensor, init_device_mesh, Shard
+from torch.distributed.tensor import (
+    distribute_tensor,
+    DTensor,
+    init_device_mesh,
+    Replicate,
+    Shard,
+)
from torch.overrides import resolve_name
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
@@ -690,6 +697,28 @@ class TestDTensorOps(DTensorOpTestBase):
            else:
                self.assertTrue("[S(0)] -> [R])" in debug_mode.debug_string())

+    def test_embedding_error_msg(self):
+        self.mesh_2d = init_device_mesh(
+            DEVICE_TYPE, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+        )
+        self.mesh_1d = self.mesh_2d["tp"]
+
+        weight_global = torch.randn(2048, 256, device=DEVICE_TYPE)
+        weight_dtensor = distribute_tensor(weight_global, self.mesh_1d, [Shard(0)])
+
+        input_global = torch.randint(0, 2048, (16, 2048), device=DEVICE_TYPE)
+        input_dtensor = distribute_tensor(
+            input_global, self.mesh_2d, [Shard(0), Replicate()]
+        )
+
+        expected_error_msg = (
+            "Sharding propagation failed for aten.embedding.default"
+            "(Spec(f32[2048, 256](S(0))), Spec(i64[16, 2048](S(0)R))) "
+            "on DeviceMesh((dp=2, tp=2), "
+        )
+        with self.assertRaisesRegex(RuntimeError, re.escape(expected_error_msg)):
+            _ = torch.ops.aten.embedding.default(weight_dtensor, input_dtensor)
+

# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))
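For context, here is a standalone sketch of the mismatch the new test exercises. The file name, launch command, and 4-rank CUDA setup are assumptions for illustration only; the API calls mirror the test above.

# repro_embedding_error.py -- illustrative sketch, not part of the diff.
# Assumed launch: torchrun --nproc-per-node=4 repro_embedding_error.py
import torch
from torch.distributed.tensor import distribute_tensor, init_device_mesh, Replicate, Shard

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
mesh_1d = mesh_2d["tp"]

# The weight lives on the 1-D "tp" sub-mesh while the indices live on the full
# 2-D mesh, so sharding propagation for aten.embedding.default cannot reconcile
# them; the RuntimeError now spells out both input specs and the device mesh.
weight = distribute_tensor(torch.randn(2048, 256, device="cuda"), mesh_1d, [Shard(0)])
indices = distribute_tensor(
    torch.randint(0, 2048, (16, 2048), device="cuda"), mesh_2d, [Shard(0), Replicate()]
)

try:
    torch.ops.aten.embedding.default(weight, indices)
except RuntimeError as err:
    # Expected prefix (per the test above):
    #   Sharding propagation failed for aten.embedding.default(
    #       Spec(f32[2048, 256](S(0))), Spec(i64[16, 2048](S(0)R))) on DeviceMesh((dp=2, tp=2), ...
    print(err)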

View File

@@ -686,7 +686,7 @@ else:
                if self._mesh_dim_names
                else f"{tuple(self._mesh.shape)}"
            )
-            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, device: '{self._device_type}', stride: {self._mesh.stride()}"
+            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
            # We only print the mesh tensor if the debug mode is turned on.
            if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
                device_mesh_repr += f", Mesh: {self._mesh.tolist()}"

View File

@@ -9,6 +9,8 @@ from torch.distributed.tensor.placement_types import (
    Replicate,
    Shard,
)
+from torch.utils._debug_mode import _stringify_shape
+from torch.utils._dtype_abbrs import dtype_abbrs


class TensorMeta(NamedTuple):
@@ -106,14 +108,16 @@ class DTensorSpec:
        if len(self.placements) == 1:
            placement_str = str(self.placements[0])
        else:
-            placement_str = str(self.placements)
+            placement_str = f"{''.join(str(p) for p in self.placements)}"

        if self.tensor_meta is not None:
-            tensor_shape = str(tuple(self.tensor_meta.shape))
+            tensor_shape = _stringify_shape(self.tensor_meta.shape)
+            tensor_dtype = dtype_abbrs[self.tensor_meta.dtype]
        else:
            tensor_shape = "unknown shape"
+            tensor_dtype = "unknown dtype"

-        return f"Spec({placement_str} on {tensor_shape})"
+        return f"Spec({tensor_dtype}{tensor_shape}({placement_str}))"

    @property
    def shape(self) -> torch.Size:
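To see the new Spec string without going through dispatch, a minimal single-process sketch; it assumes DTensorSpec and TensorMeta are importable from the private module torch.distributed.tensor._dtensor_spec (an internal path that may change) and uses a 1-rank CPU mesh.

# Illustrative only, not part of the diff.
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta  # private API

# Single-rank env:// setup so init_device_mesh can create the default process group.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

mesh = init_device_mesh("cpu", (1,))
meta = TensorMeta(shape=torch.Size([2048, 256]), stride=(256, 1), dtype=torch.float32)
spec = DTensorSpec(mesh=mesh, placements=(Shard(0),), tensor_meta=meta)

# Before this change: "Spec(S(0) on (2048, 256))"
# After this change:  "Spec(f32[2048, 256](S(0)))"
print(str(spec))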

View File

@@ -205,7 +205,7 @@ class OpStrategy(StrategyType):
    def __str__(self) -> str:
        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
        mesh_shape = self.mesh_shape
-        return f"[{strategy_list_str}] @ mesh: {mesh_shape}"
+        return f"OpStrategy[{strategy_list_str}] @ mesh: {mesh_shape}"

    def max_num_shards(self) -> int:
        """
@@ -373,23 +373,25 @@ class OpSchema:
    def __str__(self) -> str:
        args_schema: list[str] = []
-        mesh_shape = None
+        device_mesh = None
        for arg in self.args_schema:
            if isinstance(arg, DTensorSpec):
                args_schema.append(str(arg))
-                mesh_shape = arg.mesh.shape
+                device_mesh = arg.mesh
            elif isinstance(arg, OpStrategy):
                assert len(arg.strategies) == 1
                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
-                mesh_shape = arg.mesh_shape
+                device_mesh = arg.mesh
            elif isinstance(arg, TupleStrategy):
                first_op_strategy = arg.children[0]
                assert isinstance(first_op_strategy, OpStrategy)
-                mesh_shape = first_op_strategy.mesh_shape
+                device_mesh = first_op_strategy.mesh
                args_schema.append(str(arg))
            else:
                args_schema.append(str(arg))
-        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"
+        return f"{self.op}({', '.join(args_schema)}) on {device_mesh})"

    def __post_init__(self) -> None:
        _DTensor_OpSchema_post_init(self)
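Putting the pieces together, a small sketch built from plain strings (the values mirror the new test; no DTensor objects are constructed) of how the dispatch error message now reads.

# Illustrative only: composes the new message shape by hand.
op = "aten.embedding.default"
arg_specs = ["Spec(f32[2048, 256](S(0)))", "Spec(i64[16, 2048](S(0)R))"]
mesh = "DeviceMesh((dp=2, tp=2), 'cuda', stride=(2, 1))"

# New OpSchema.__str__ shape: "{op}({args}) on {mesh})"
msg = f"Sharding propagation failed for {op}({', '.join(arg_specs)}) on {mesh})"
print(msg)

# Old shape, for comparison (roughly):
#   Op(op=aten.embedding.default, args_schema=..., ... @ mesh: torch.Size([2, 2]))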