Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)

- During op dispatch the local tensor is supposed to collect the RNG state from the CPU and
CUDA devices so that it can be reset before the op executes for each rank, ensuring that ops
with randomness produce the same result on all ranks (note that we are planning a separate
change to add support for per-rank RNG state). Previously we relied on the op's input
arguments to deduce which devices to collect the RNG state from, which doesn't work for
factory functions such as torch.randn. Hence this change switches to unconditionally
collecting the RNG state from all devices.
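
  A minimal sketch of the intended behavior (illustrative only, not the actual dispatch code;
  the helper names below are hypothetical): the RNG state is captured once from the CPU and from
  every visible CUDA device, then restored before the op runs for each simulated rank, so a
  factory op like torch.randn, which has no tensor inputs to deduce devices from, still yields
  identical values on every rank.

  ```python
  import torch

  def collect_rng_states():  # hypothetical helper name
      states = {"cpu": torch.get_rng_state()}
      if torch.cuda.is_available():
          states["cuda"] = torch.cuda.get_rng_state_all()  # one state per CUDA device
      return states

  def restore_rng_states(states):  # hypothetical helper name
      torch.set_rng_state(states["cpu"])
      for idx, s in enumerate(states.get("cuda", [])):
          torch.cuda.set_rng_state(s, idx)

  saved = collect_rng_states()      # collect unconditionally, before dispatching the op
  for rank in range(4):             # simulated ranks
      restore_rng_states(saved)     # reset RNG state for each rank
      out = torch.randn(2, 2)       # factory op: no tensor inputs, same result on every rank
  ```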

- Fix per-rank-specific computations in the _MaskedPartial and Shard placements that were
discovered during test enablement.
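
  For context, a small illustration (plain PyTorch, not the placement internals) of why this
  math is rank-dependent: with a Shard(0) placement each rank owns a different slice of dim 0,
  so any offset or size computed inside the placement must use that rank's index rather than a
  single value shared by all ranks.

  ```python
  import torch

  world_size = 4
  global_tensor = torch.arange(8 * 3).reshape(8, 3)

  # Shard(0)-style split: each rank gets its own chunk along dim 0.
  chunks = torch.chunk(global_tensor, world_size, dim=0)
  for rank, local in enumerate(chunks):
      row_offset = rank * local.shape[0]  # rank-specific offset into the global dim 0
      assert torch.equal(local, global_tensor[row_offset : row_offset + local.shape[0]])
  ```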

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
Author: Dzmitry Huba
Date: 2025-10-17 09:01:44 -07:00
Committed by: PyTorch MergeBot
Parent: fe80f03726
Commit: 1b397420f2
8 changed files with 155 additions and 30 deletions

@@ -17,6 +17,7 @@ from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorConverter,
     DTensorTestBase,
     with_comms,
@@ -704,6 +705,12 @@ class DistTensorOpsTest(DTensorTestBase):
 
     @with_comms
     def test_dtensor_dtype_conversion(self):
+        from torch.distributed.tensor.debug import (
+            _clear_sharding_prop_cache,
+            _get_sharding_prop_cache_info,
+        )
+
+        _clear_sharding_prop_cache()
         device_mesh = self.build_device_mesh()
         shard_spec = [Shard(0)]
         # by default we start from bf16 dtype
@@ -722,8 +729,6 @@ class DistTensorOpsTest(DTensorTestBase):
         self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16)
         self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)
 
-        from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
-
         # by this point we only have cache misses
         hits, misses, _, _ = _get_sharding_prop_cache_info()
         self.assertEqual(hits, 0)
@@ -775,7 +780,7 @@
         )
 
     def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int):
-        torch.manual_seed(self.rank)
+        self.init_manual_seed_for_rank()
         mesh = self.build_device_mesh()
 
         partial_tensor = torch.randn(8, 8, device=self.device_type)
@@ -822,5 +827,9 @@ class DistTensorOpsTest(DTensorTestBase):
         self.assertEqual(x.full_tensor(), y)
 
 
+DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class(
+    DistTensorOpsTest,
+)
+
 if __name__ == "__main__":
     run_tests()