Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)

- During op dispatch, local tensor is supposed to collect RNG state from the CPU and CUDA
devices so that it can be reset before the op executes for each rank; this way, ops
with randomness produce the same result on all ranks (note that we are planning a
separate change to add support for per-rank RNG state). Previously we relied on the
op's input arguments to deduce which devices to collect RNG state from, which doesn't
work for factory functions such as torch.randn. Hence this change switches to
unconditionally collecting RNG state from all devices (see the sketch after this list).

- Fixes per-rank-specific computations in the _MaskedPartial and Shard placements that
were discovered during test enablement.

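Below is a minimal, hedged sketch of the unconditional RNG-state collection described in the first bullet. It is not the actual LocalTensor dispatch code; the helper names (`_collect_rng_states`, `_restore_rng_states`, `run_op_per_rank`) are illustrative, though the `torch.get_rng_state` / `torch.cuda.get_rng_state_all` calls are real PyTorch APIs.

```python
# A minimal sketch (not the actual LocalTensor implementation) of snapshotting
# RNG state from all devices up front and restoring it before each rank's
# execution, so ops with randomness produce identical results on every rank.
import torch


def _collect_rng_states() -> dict:
    # Unconditionally snapshot the CPU RNG state, and the RNG state of every
    # CUDA device if CUDA is available, instead of deducing devices from the
    # op's input arguments (which fails for factory ops like torch.randn).
    states = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        states["cuda"] = torch.cuda.get_rng_state_all()
    return states


def _restore_rng_states(states: dict) -> None:
    torch.set_rng_state(states["cpu"])
    if "cuda" in states:
        torch.cuda.set_rng_state_all(states["cuda"])


def run_op_per_rank(op, per_rank_args):
    # Snapshot once before dispatch, then restore before executing the op for
    # each simulated rank; per-rank RNG state is left to the planned follow-up.
    states = _collect_rng_states()
    results = []
    for args in per_rank_args:
        _restore_rng_states(states)
        results.append(op(*args))
    return results


# Example: every "rank" gets the same random tensor from a factory op,
# even though no input tensor tells us which device's RNG is involved.
outs = run_op_per_rank(torch.randn, [((4, 4),), ((4, 4),)])
assert torch.equal(outs[0], outs[1])
```
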
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
Author: Dzmitry Huba
Date: 2025-10-18 12:54:20 -07:00
Committed by: PyTorch MergeBot
Parent: f18041cca8
Commit: c4f6619330
9 changed files with 169 additions and 34 deletions


@@ -17,6 +17,7 @@ from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorConverter,
DTensorTestBase,
with_comms,
@@ -704,6 +705,12 @@ class DistTensorOpsTest(DTensorTestBase):
@with_comms
def test_dtensor_dtype_conversion(self):
from torch.distributed.tensor.debug import (
_clear_sharding_prop_cache,
_get_sharding_prop_cache_info,
)
_clear_sharding_prop_cache()
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
# by default we start from bf16 dtype
@@ -722,8 +729,6 @@ class DistTensorOpsTest(DTensorTestBase):
self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
# by this point we only have cache misses
hits, misses, _, _ = _get_sharding_prop_cache_info()
self.assertEqual(hits, 0)
@@ -775,7 +780,7 @@
)
def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
torch.manual_seed(self.rank)
self.init_manual_seed_for_rank()
mesh = self.build_device_mesh()
partial_tensor = torch.randn(8, 8, device=self.device_type)
@@ -822,5 +827,9 @@
self.assertEqual(x.full_tensor(), y)
DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class(
DistTensorOpsTest,
)
if __name__ == "__main__":
run_tests()


@@ -53,7 +53,13 @@ class ProcessGroupTest(TestCase):
class Dist2MultiProcessTestCase(MultiProcessTestCase):
device: torch.device
@property
def device(self) -> torch.device:
raise NotImplementedError
# @device.setter
# def device(self, value: torch.device) -> None:
# self._device = value
@property
def world_size(self) -> int:
@@ -257,7 +263,9 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase):
class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
device = torch.device("cpu")
@property
def device(self) -> torch.device:
return torch.device("cpu")
@requires_gloo()
def new_group(self) -> torch.distributed.ProcessGroup:
@@ -274,6 +282,10 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase):
class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
@property
def device(self) -> torch.device:
return torch.device("cuda", self.rank)
@requires_nccl()
@skip_if_lt_x_gpu(2)
def new_group(self) -> torch.distributed.ProcessGroup:
@@ -282,8 +294,6 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29501"
self.device = torch.device("cuda", self.rank)
return dist2.new_group(
backend="nccl",
timeout=timedelta(seconds=60),