mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix wait() missing in redistribute tensor (#162749)
We notice that the wait() op is missing after collective op call: https://github.com/pytorch/pytorch/pull/162665#discussion_r2338460562.
The issue is that `_maybe_warp_tensor` calls AsyncCollectiveTensor in 3ad3bfe11d/torch/distributed/_functional_collectives.py (L829)
We need to check whether the wait() is required after collective op call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162749
Approved by: https://github.com/ezyang, https://github.com/SherlockNoMad, https://github.com/wconstab
This commit is contained in:
@ -153,6 +153,8 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::chunk(t: f32[1, 96, 8], 4, 2)
|
||||
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
|
||||
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
|
||||
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
|
||||
aten::chunk(t: f32[1, 96, 2], 2, 2)
|
||||
aten::clone(t: f32[1, 96, 1])
|
||||
redistribute_input(1, [R, P] -> [S(1), S(1)])
|
||||
aten::chunk(t: f32[1, 8, 16], 4, 1)
|
||||
|
@ -270,11 +270,9 @@ def redistribute_local_tensor(
|
||||
# partial -> partial no op, should never hit
|
||||
new_local_tensor = local_tensor
|
||||
|
||||
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
|
||||
new_local_tensor = new_local_tensor.wait()
|
||||
local_tensor = new_local_tensor
|
||||
|
||||
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
|
||||
new_local_tensor = new_local_tensor.wait()
|
||||
|
||||
return new_local_tensor
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user