Address feedback

This commit is contained in:
Rohit Singh Rathaur
2025-10-09 05:57:14 +00:00
committed by PyTorch MergeBot
parent 7e7712cf8d
commit 2fa6f50f8b

View File

@ -336,12 +336,12 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
# from Partial to Replicate placement (issue #163374)
device_mesh = self.build_device_mesh()
tensor = torch.ones(8, 8, device=self.device_type)
in_dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)])
partial_dt = in_dtensor.sum()
input_tensor = torch.tensor(64.0, device=self.device_type)
partial_dt = DTensor.from_local(
input_tensor, device_mesh, placements=(Partial(),)
)
self.assertTrue(partial_dt.placements[0].is_partial())
self.assertTrue(partial_dt.placements[0].is_partial("sum"))
out = partial_dt.clamp_(max=10)
self.assertEqual(out.placements, (Replicate(),))
self.assertEqual(partial_dt.placements, (Replicate(),))
@ -350,5 +350,6 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
self.assertEqual(full.item(), 10.0)
self.assertEqual(out.to_local().item(), 10.0)
if __name__ == "__main__":
run_tests()