AsyncCollectiveTensor: don't sync on view ops (#105240)
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives APIs. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and calling `.wait()` to synchronize the data the next time it is used.
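To make the idea concrete, here is a minimal sketch of the delayed-sync pattern, under my own assumptions; the class and method names are hypothetical, and this is not the actual `AsyncCollectiveTensor` implementation:

```python
import torch

class LazyCollectiveResult:
    # Hypothetical stand-in for the delayed-sync idea: hold the output
    # buffer of an in-flight collective plus its async work handle, and
    # only block on the handle the first time the data is actually read.
    def __init__(self, work, out_buffer: torch.Tensor):
        # `work` is assumed to behave like the Work handle returned by
        # torch.distributed collectives launched with async_op=True.
        self._work = work
        self._out = out_buffer
        self._synced = False

    def data(self) -> torch.Tensor:
        # First real read triggers the synchronization; later reads are free.
        if not self._synced:
            self._work.wait()
            self._synced = True
        return self._out
```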
Previously, these `wait()` calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor and thereby triggers the sync right away (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207)).
`AsyncCollectiveTensor` shouldn't need to synchronize when you `detach()` it, though. In fact, it should be fine to avoid synchronizing on any view op, since view ops only need the tensor's metadata, not its actual data. This PR updates `AsyncCollectiveTensor` to delay the `wait()` call whenever the subclass encounters a view op.
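As a rough sketch of the decision this implies at dispatch time (a hypothetical helper, not the PR's code): view ops only read sizes, strides, and dtype, so they can run against the unsynchronized inner tensor, while anything that touches the data must `wait()` first. The `VIEW_OPS` set below is an illustrative subset; the real set of view ops is larger.

```python
import torch

# Illustrative subset of metadata-only ops; the real list is much larger.
VIEW_OPS = {
    torch.ops.aten.view.default,
    torch.ops.aten.detach.default,
    torch.ops.aten.t.default,
}

def must_sync_before(func) -> bool:
    # View ops can be forwarded to the unsynchronized tensor and the result
    # rewrapped as "still async"; every other op reads real data and has to
    # wait for the collective to finish first.
    return func not in VIEW_OPS
```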
Added some light testing that runs some DTensor compute followed by view ops and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.
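For reference, the user-visible behavior the test exercises looks roughly like this; a hedged sketch that assumes `torch.distributed` is already initialized:

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

t = torch.ones(4, 2)
# Functional collectives return the async wrapper instead of blocking.
out = funcol.all_reduce(t, "sum", group=list(range(dist.get_world_size())))
view = out.view(-1)  # view ops no longer force a sync after this PR
data = view + 1      # first compute op reads the data and triggers wait()
```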
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
Committed by: PyTorch MergeBot
Parent: e165938853
Commit: 5c48ff20b5
@@ -5,6 +5,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from numpy.testing import assert_array_equal
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
 
 from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
@@ -245,6 +246,47 @@ class DTensorTest(DTensorTestBase):
         except RuntimeError:
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
+    @with_comms
+    def test_dtensor_async_output(self):
+        # Tests that if the output of some dtensor operations isn't used in any compute,
+        # the output should be an AsyncCollectiveTensor (representing the fact that
+        # we haven't synced the collective yet).
+        from torch.distributed._functional_collectives_impl import _tensor_needs_wait
+
+        mesh = DeviceMesh(
+            self.device_type, torch.arange(self.world_size), _validate_mesh=False
+        )
+
+        def fn(dt):
+            dt_out_redistribute = dt.redistribute(mesh, [Replicate()])
+            # Make sure we haven't synced yet
+            # TODO: figure out why this is returning None
+            # self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
+            dt_out_redistribute_view = dt_out_redistribute.view(
+                dt_out_redistribute.shape
+            )
+            local_tensor = dt_out_redistribute_view.to_local()
+            return local_tensor
+
+        x = torch.ones((4, 2), device=self.device_type)
+        dt = distribute_tensor(x, mesh, [Shard(0)])
+        out = fn(dt)
+        # Make sure we haven't synced yet
+        self.assertEqual(type(out), AsyncCollectiveTensor)
+        self.assertTrue(_tensor_needs_wait(out.elem))
+        out_view = out.view(-1)
+
+        # Assert that the output is an `AsyncCollectiveTensor`
+        self.assertEqual(type(out_view), AsyncCollectiveTensor)
+        self.assertTrue(_tensor_needs_wait(out_view.elem))
+
+        # Use the data, requiring a sync
+        ref = torch.ones((4, 2), device=self.device_type) + 1
+        ref = ref.view(-1)
+        out_data = out_view + 1
+        self.assertEqual(type(out_data), torch.Tensor)
+        self.assertEqual(out_data, ref)
+
     @with_comms
     def test_from_local_then_to_local(self):
         # this test ensures end-to-end from torch.Tensor -> dist tensor -> torch.Tensor works