mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-16 07:24:54 +08:00
Address feedback
This commit is contained in:
committed by
PyTorch MergeBot
parent
7e7712cf8d
commit
2fa6f50f8b
@ -336,12 +336,12 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
|
||||
# from Partial to Replicate placement (issue #163374)
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
tensor = torch.ones(8, 8, device=self.device_type)
|
||||
in_dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)])
|
||||
partial_dt = in_dtensor.sum()
|
||||
input_tensor = torch.tensor(64.0, device=self.device_type)
|
||||
partial_dt = DTensor.from_local(
|
||||
input_tensor, device_mesh, placements=(Partial(),)
|
||||
)
|
||||
|
||||
self.assertTrue(partial_dt.placements[0].is_partial())
|
||||
self.assertTrue(partial_dt.placements[0].is_partial("sum"))
|
||||
out = partial_dt.clamp_(max=10)
|
||||
self.assertEqual(out.placements, (Replicate(),))
|
||||
self.assertEqual(partial_dt.placements, (Replicate(),))
|
||||
@ -350,5 +350,6 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
|
||||
self.assertEqual(full.item(), 10.0)
|
||||
self.assertEqual(out.to_local().item(), 10.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
Reference in New Issue
Block a user