fix wait() missing in redistribute tensor (#162749)

We notice that the wait() op is missing after collective op call: https://github.com/pytorch/pytorch/pull/162665#discussion_r2338460562.

The issue is that `_maybe_warp_tensor` calls AsyncCollectiveTensor in 3ad3bfe11d/torch/distributed/_functional_collectives.py (L829) We need to check whether the wait() is required after collective op call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162749
Approved by: https://github.com/ezyang, https://github.com/SherlockNoMad, https://github.com/wconstab
This commit is contained in:
zpcore
2025-09-17 16:24:26 +00:00
committed by PyTorch MergeBot
parent 4ca3f435fb
commit f2206b1ed8
2 changed files with 4 additions and 4 deletions

View File

@ -153,6 +153,8 @@ class TestDTensorDebugMode(TestCase):
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
aten::chunk(t: f32[1, 96, 2], 2, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(1, [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)

View File

@ -270,11 +270,9 @@ def redistribute_local_tensor(
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
local_tensor = new_local_tensor
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
return new_local_tensor