AsyncCollectiveTensor: dont sync on view ops (#105240)

AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives API's. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.

Previously, these wait() calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207))

AsyncCollectiveTensor shouldn't need to do a synchronization if you try to detach() it though - in fact, it should be fine to avoid synchronizing if you perform any view ops on it (which just require viewing metadata, but not actual data). This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.

Added some light testing, that just runs some DTensor compute followed by view ops, and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
This commit is contained in:
Wanchao Liang
2023-08-11 02:15:56 +00:00
committed by PyTorch MergeBot
parent e165938853
commit 5c48ff20b5
3 changed files with 80 additions and 5 deletions

View File

@ -5,6 +5,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from numpy.testing import assert_array_equal
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
@ -245,6 +246,47 @@ class DTensorTest(DTensorTestBase):
except RuntimeError:
self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
@with_comms
def test_dtensor_async_output(self):
# Tests that if the output of some dtensor operations isn't used in any compute,
# the output should be an AsyncCollectiveTensor (representing the fact that
# we haven't synced the collective yet).
from torch.distributed._functional_collectives_impl import _tensor_needs_wait
mesh = DeviceMesh(
self.device_type, torch.arange(self.world_size), _validate_mesh=False
)
def fn(dt):
dt_out_redistribute = dt.redistribute(mesh, [Replicate()])
# Make sure we haven't synced yet
# TODO: figure out why this is returning None
# self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
dt_out_redistribute_view = dt_out_redistribute.view(
dt_out_redistribute.shape
)
local_tensor = dt_out_redistribute_view.to_local()
return local_tensor
x = torch.ones((4, 2), device=self.device_type)
dt = distribute_tensor(x, mesh, [Shard(0)])
out = fn(dt)
# Make sure we haven't synced yet
self.assertEqual(type(out), AsyncCollectiveTensor)
self.assertTrue(_tensor_needs_wait(out.elem))
out_view = out.view(-1)
# Assert that output is a `AsyncCollectiveTensor`
self.assertEqual(type(out_view), AsyncCollectiveTensor)
self.assertTrue(_tensor_needs_wait(out_view.elem))
# Use the daa, requiring a sync
ref = torch.ones((4, 2), device=self.device_type) + 1
ref = ref.view(-1)
out_data = out_view + 1
self.assertEqual(type(out_data), torch.Tensor)
self.assertEqual(out_data, ref)
@with_comms
def test_from_local_then_to_local(self):
# this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works