Compare commits

...

2 Commits

Author SHA1 Message Date
a603f1c118 Update
[ghstack-poisoned]
2025-11-18 17:08:26 -08:00
817efeccf5 Update (base update)
[ghstack-poisoned]
2025-11-18 17:08:26 -08:00
2 changed files with 35 additions and 7 deletions

View File

@ -148,6 +148,30 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
d_3 = d_1 + d_2
self.assertTrue(d_3._spec.placements[0].is_partial())
def test_partial_replicate_add(self):
    """Add a Partial DTensor to a Replicate DTensor for several reduce ops.

    For each of the "sum" and "avg" reduce ops this asserts that:
      * CommDebugMode reports a total count of zero for the addition,
      * the result carries a single Partial placement with the same
        reduce op as the Partial operand, and
      * the materialized result equals the sum of the materialized
        operands.
    """
    mesh = self.build_device_mesh()
    comm_tracker = CommDebugMode()

    for reduce_op in ("sum", "avg"):
        # Left operand: local (2, 2) values wrapped as Partial(reduce_op).
        partial_operand = DTensor.from_local(
            torch.rand(2, 2),
            mesh,
            [Partial(reduce_op=reduce_op)],
        )
        # Right operand: local (2, 1) values wrapped as Replicate; the
        # trailing dimension of size 1 broadcasts against the left operand.
        replicated_operand = DTensor.from_local(
            torch.rand(2, 1),
            mesh,
            [Replicate()],
            run_check=True,
        )

        # Only the addition itself runs under the tracker.
        with comm_tracker:
            result = partial_operand + replicated_operand

        self.assertEqual(comm_tracker.get_total_counts(), 0)

        expected_placements = (Partial(reduce_op=reduce_op),)
        self.assertEqual(result.placements, expected_placements)

        # Materialize in the same order as before: result first, then operands.
        actual_full = result.full_tensor()
        expected_full = (
            partial_operand.full_tensor() + replicated_operand.full_tensor()
        )
        self.assertEqual(actual_full, expected_full)
def test_activations(self):
device_mesh = self.build_device_mesh()
self._run_sharded_elementwise_ops(

View File

@ -689,14 +689,18 @@ class Partial(Placement):
# Partial placement contract #3:
# _partition_value: partition the value of a replicated tensor on the mesh dimension
# _partition_value is the conjugate operation of _reduce_value
# - i.e. _partition_value on a sum reduce op is just a division operation
# - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation
# TODO: if the reduce_op is min/max, etc. the _partition_value should be a
# different operation
assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!"
# _partition_value is the conjugate operation of _reduce_value, e.g.
# - _partition_value on a sum reduce op is just a division operation
# - _reduce_value on a sum reduce op would just be a sum(allreduce) operation
num_chunks = mesh.size(mesh_dim=mesh_dim)
return tensor / num_chunks
if self.reduce_op == "sum":
return tensor / num_chunks
elif self.reduce_op in ("avg", "min", "max"):
return tensor
else:
raise ValueError(
f"Replicate to Partial({self.reduce_op}) conversion is not supported."
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Partial):