Compare commits

...

5 Commits

Author SHA1 Message Date
9179b8a722 Update on "[DTensor][Partial] fixing .item() when dtensor is Partial "
**Summary:** Currently, when users call .item() on a Partial DTensor, each rank returns its local tensor value instead of the global tensor value. The root cause is that for aten._local_scalar_dense.default we return an OutputSharding with the original schema, which is correct for Shard(n) and Replicate() but not for Partial. For Partial placements we should instead return a suggestion schema that replaces Partial with Replicate and sets needs_redistribute to True, which forces a redistribution before .item() executes.

**Test Cases**
1. pytest test/distributed/tensor/test_tensor_ops.py -k test_item
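
For illustration only (not part of this PR), a minimal repro sketch of the behavior described in the summary, assuming a 4-rank NCCL setup launched with `torchrun` and using only public DTensor APIs:

```python
# Hypothetical repro sketch (assumption: 4 ranks, one GPU per rank, torchrun launch).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# 1-D mesh over all ranks; each rank contributes its own rank id,
# so with Partial("sum") the global value is 0 + 1 + 2 + 3 = 6
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
local = torch.tensor([float(rank)], device="cuda")
dt = DTensor.from_local(local, device_mesh=mesh, placements=[Partial("sum")])

# before the fix: each rank prints its local value (0.0, 1.0, 2.0, 3.0)
# after the fix: every rank prints the reduced global value, 6.0
print(f"rank {rank}: dt.item() = {dt.item()}")

dist.destroy_process_group()
```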

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci tianyu-l XilunWu SherlockNoMad

[ghstack-poisoned]
2025-11-13 13:45:20 -08:00
2b686a84fe Update on "[DTensor][Partial] fixing .item() when dtensor is Partial "
**Summary:** Currently, when users call .item() on a Partial DTensor, each rank returns its local tensor value instead of the global tensor value. The root cause is that for aten._local_scalar_dense.default we return an OutputSharding with the original schema, which is correct for Shard(n) and Replicate() but not for Partial. For Partial placements we should instead return a suggestion schema that replaces Partial with Replicate and sets needs_redistribute to True, which forces a redistribution before .item() executes.

**Test Cases**
1. pytest test/distributed/tensor/test_tensor_ops.py -k test_item

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci tianyu-l XilunWu SherlockNoMad

[ghstack-poisoned]
2025-11-13 13:38:47 -08:00
a263ae2026 Update on "[DTensor][Partial] fixing .item() when dtensor is Partial #167595"
**Summary:** Currently, when users call .item() on a Partial DTensor, each rank returns its local tensor value instead of the global tensor value. The root cause is that for aten._local_scalar_dense.default we return an OutputSharding with the original schema, which is correct for Shard(n) and Replicate() but not for Partial. For Partial placements we should instead return a suggestion schema that replaces Partial with Replicate and sets needs_redistribute to True, which forces a redistribution before .item() executes.

**Test Cases**
1. pytest test/distributed/tensor/test_tensor_ops.py -k test_item

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci tianyu-l XilunWu SherlockNoMad

[ghstack-poisoned]
2025-11-11 20:18:22 -08:00
b474a61a98 Update on "[DTensor][Partial] fixing .item() when dtensor is Partial #167595"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci tianyu-l XilunWu SherlockNoMad

[ghstack-poisoned]
2025-11-11 17:31:13 -08:00
5e86bafc9e [DTensor][Partial] fixing .item() when dtensor is Partial #167595
[ghstack-poisoned]
2025-11-11 17:25:57 -08:00
3 changed files with 59 additions and 0 deletions

View File

@@ -260,6 +260,11 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
            self.assertEqual(k1, k2)
            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), DTensor)
            # in order to do ops on DTensor that require redistribution, we need to
            # move the DTensors back to the GPU
            v1 = v1.to("cuda")
            v2 = v2.to("cuda")
            # check whether the DTensors are the same
            self.assertEqual(v1, v2)

View File

@@ -20,6 +20,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
    create_local_tensor_test_class,
    DTensorConverter,
    DTensorTestBase,
    map_local_for_rank,
    with_comms,
)
@@ -826,6 +827,29 @@ class DistTensorOpsTest(DTensorTestBase):
        ):
            self.assertEqual(x.full_tensor(), y)

    @with_comms
    def test_item(self):
        mesh = self.build_device_mesh()
        rank = self.rank
        local_tensor = map_local_for_rank(rank, lambda rank: torch.tensor([rank]))
        dt = DTensor.from_local(
            local_tensor, device_mesh=mesh, placements=[Partial("sum")]
        )
        item_without_redistribute = dt.item()
        self.assertEqual(item_without_redistribute, 6)

        mesh_2d = DeviceMesh(self.device_type, torch.arange(4).reshape(2, 2))
        dt = DTensor.from_local(
            local_tensor,
            device_mesh=mesh_2d,
            placements=[Partial("sum"), Partial("sum")],
        )
        item_without_redistribute = dt.item()
        self.assertEqual(item_without_redistribute, 6)


DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class(
    DistTensorOpsTest,

View File

@@ -27,6 +27,7 @@ from torch.distributed.tensor._utils import (
    compute_local_shape_and_global_offset,
    compute_local_stride,
)
from torch.distributed.tensor.placement_types import Partial


aten = torch.ops.aten
@@ -366,6 +367,35 @@ class ShardingPropagator:
        # special case op, we don't need to propagate for local
        # scalar. TODO: figure out a better way to handle this
        if op_schema.op is aten._local_scalar_dense.default:
            # there should only be one argument for _local_scalar_dense
            arg = op_schema.args_schema[0]
            if isinstance(arg, DTensorSpec) and any(
                p.is_partial() for p in arg.placements
            ):
                # Need to reduce partial values to replicate before calling item()
                # Replace Partial placements with Replicate, keep others unchanged
                from torch.distributed.tensor.placement_types import Replicate

                new_placements = tuple(
                    Replicate() if isinstance(p, Partial) else p for p in arg.placements
                )
                replicate_spec = DTensorSpec(
                    arg.mesh,
                    new_placements,
                    arg.tensor_meta,
                )
                suggestion_schema = OpSchema(
                    op_schema.op,
                    (replicate_spec,),
                    {},
                )
                suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
                return OutputSharding(
                    None,
                    suggestion_schema,
                    needs_redistribute=True,
                )
            return OutputSharding(None, op_schema)

        out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
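
For context, an editorial sketch (not code from this PR) of the user-visible effect of the suggestion schema above: `.item()` on a Partial DTensor now agrees with an explicit reduction to Replicate. Continuing the hypothetical `dt` from the repro sketch in the commit message:

```python
from torch.distributed.tensor import Replicate

# with needs_redistribute=True, dispatch reduces the Partial input to Replicate
# before running aten._local_scalar_dense, so all three calls agree on every rank
v1 = dt.item()
v2 = dt.full_tensor().item()
v3 = dt.redistribute(placements=[Replicate()]).to_local().item()
assert v1 == v2 == v3 == 6.0
```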