From c4f6619330bdac5bf4addb9070ecb42994202e1f Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Sat, 18 Oct 2025 12:54:20 -0700 Subject: [PATCH] Enable more DTensor tests in local tensor mode and fix more integration issues (#165716) - During op dispatch local tensor is supposed to collect rng state from CPU and CUDA devices so that it can be reset before execution of the op for each such that ops with randomness produces the same result for all ranks (note that we are planning a separate change to add support of per rank rng state). Previously we relied on op input arguments to deduce which devices to get rng state from. Which doesn't work for factory functions such torch.randn. Hence this changes switches to uncondionally collecting rng state from all devices. - Fixing per rank specific computations in _MaskedPartial and Shard placements discovered during test enablement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716 Approved by: https://github.com/ezyang --- test/distributed/tensor/test_tensor_ops.py | 15 +++- test/distributed/test_dist2.py | 18 ++++- torch/distributed/_local_tensor/__init__.py | 78 +++++++++++++++++-- .../distributed/tensor/_ops/_embedding_ops.py | 41 ++++++---- torch/distributed/tensor/_sharding_prop.py | 3 + torch/distributed/tensor/debug/__init__.py | 11 +++ torch/distributed/tensor/placement_types.py | 18 ++++- torch/testing/_internal/common_distributed.py | 16 +++- .../distributed/_tensor/common_dtensor.py | 3 + 9 files changed, 169 insertions(+), 34 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index eaa1969068c1..8368befabfec 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -17,6 +17,7 @@ from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests, skipIfRocm from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorConverter, DTensorTestBase, with_comms, @@ -704,6 +705,12 @@ class DistTensorOpsTest(DTensorTestBase): @with_comms def test_dtensor_dtype_conversion(self): + from torch.distributed.tensor.debug import ( + _clear_sharding_prop_cache, + _get_sharding_prop_cache_info, + ) + + _clear_sharding_prop_cache() device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] # by default we start from bf16 dtype @@ -722,8 +729,6 @@ class DistTensorOpsTest(DTensorTestBase): self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16) self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16) - from torch.distributed.tensor.debug import _get_sharding_prop_cache_info - # by this point we only have cache misses hits, misses, _, _ = _get_sharding_prop_cache_info() self.assertEqual(hits, 0) @@ -775,7 +780,7 @@ class DistTensorOpsTest(DTensorTestBase): ) def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int): - torch.manual_seed(self.rank) + self.init_manual_seed_for_rank() mesh = self.build_device_mesh() partial_tensor = torch.randn(8, 8, device=self.device_type) @@ -822,5 +827,9 @@ class DistTensorOpsTest(DTensorTestBase): self.assertEqual(x.full_tensor(), y) +DistTensorOpsTestWithLocalTensor = create_local_tensor_test_class( + DistTensorOpsTest, +) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_dist2.py b/test/distributed/test_dist2.py index b335eff1c216..2c444fbfe567 100644 --- a/test/distributed/test_dist2.py +++ b/test/distributed/test_dist2.py @@ -53,7 +53,13 @@ class ProcessGroupTest(TestCase): class Dist2MultiProcessTestCase(MultiProcessTestCase): - device: torch.device + @property + def device(self) -> torch.device: + raise NotImplementedError + + # @device.setter + # def device(self, value: torch.device) -> None: + # self._device = value @property def world_size(self) -> int: @@ -257,7 +263,9 @@ class Dist2MultiProcessTestCase(MultiProcessTestCase): class ProcessGroupGlooTest(Dist2MultiProcessTestCase): - device = torch.device("cpu") + @property + def device(self) -> torch.device: + return torch.device("cpu") @requires_gloo() def new_group(self) -> torch.distributed.ProcessGroup: @@ -274,6 +282,10 @@ class ProcessGroupGlooTest(Dist2MultiProcessTestCase): class ProcessGroupNCCLTest(Dist2MultiProcessTestCase): + @property + def device(self) -> torch.device: + return torch.device("cuda", self.rank) + @requires_nccl() @skip_if_lt_x_gpu(2) def new_group(self) -> torch.distributed.ProcessGroup: @@ -282,8 +294,6 @@ class ProcessGroupNCCLTest(Dist2MultiProcessTestCase): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29501" - self.device = torch.device("cuda", self.rank) - return dist2.new_group( backend="nccl", timeout=timedelta(seconds=60), diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index d9eb7b47e9a3..8121b367790a 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -104,6 +104,62 @@ def _map_to_rank_local_val(val: Any, rank: int) -> Any: return val +def collect_cuda_rng_states() -> list[torch.Tensor]: + """ + Collects RNG state from all available CUDA devices. + + Returns: + List of RNG state tensors, one for each CUDA device. + Returns empty list if CUDA is not available. + """ + if not torch.cuda.is_available(): + return [] + + num_devices = torch.cuda.device_count() + rng_states = [] + + for device_idx in range(num_devices): + with torch.cuda.device(device_idx): + rng_state = torch.cuda.get_rng_state() + rng_states.append(rng_state) + + return rng_states + + +def set_cuda_rng_states(rng_states: list[torch.Tensor]) -> None: + """ + Sets RNG state for all CUDA devices from a list of states. + + Args: + rng_states: List of RNG state tensors to restore. + """ + if not torch.cuda.is_available(): + return + + num_devices = min(len(rng_states), torch.cuda.device_count()) + + for device_idx in range(num_devices): + with torch.cuda.device(device_idx): + torch.cuda.set_rng_state(rng_states[device_idx]) + + +def _get_rng_state() -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + Gets CPU and CUDA rng states from all devices. + """ + return (torch.get_rng_state(), collect_cuda_rng_states()) + + +def _set_rng_state(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None: + """ + Sets CPU and CUDA rng states for all devices. If the list of cuda states + is shorter than the number of devices only the first len(cuda_states) devices + will get their rng state set. + """ + torch.set_rng_state(cpu_state) + set_cuda_rng_states(cuda_states) + + def _for_each_rank_run_func( func: Callable[..., Any], ranks: frozenset[int], @@ -117,14 +173,15 @@ def _for_each_rank_run_func( a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args ] - cpu_state = torch.get_rng_state() - devices, states = get_device_states((args, kwargs)) - + # NB: Before invoking an op we are collecting rng states from CPU and + # CUDA devices such that we can reset to the same before invoking op + # for each rank. This is not very efficient and will likely be revisited + # to support per rank rng state. + rng_state = _get_rng_state() flat_rank_rets = {} for r in sorted(ranks): - torch.set_rng_state(cpu_state) - set_device_states(devices, states) + _set_rng_state(*rng_state) rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) rank_ret = func(*rank_args, **rank_kwargs) @@ -704,6 +761,11 @@ class _LocalDeviceMesh: @staticmethod def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]: + # NB: In order to support submeshes the code below recreates for each + # rank submesh with the same mesh dimensions as current mesh. We are + # doing this because when submesh is created it is created for a particular + # rank (therefore below we are patching get_rank method). We are trying to + # limit the invasiveness of local tensor. lm = local_tensor_mode() assert lm is not None, "Unexpectedly not in LocalTensorMode" @@ -716,7 +778,9 @@ class _LocalDeviceMesh: coords[d][r] = c out = [torch.SymInt(LocalIntNode(c)) for c in coords] - + # The output contains coordinates for each of the ranks with respect to + # their meshes formed from root mesh and selecting the same dimensions + # as the current mesh. return out # type: ignore[return-value] @@ -794,8 +858,6 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]: with lm.disable(): ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False) - lm = local_tensor_mode() - assert lm is not None return ret return wrapper diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 445b1830defe..283cffb78efd 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -6,6 +6,7 @@ from typing import cast, Optional import torch import torch.distributed._functional_collectives as funcol +from torch.distributed._local_tensor import maybe_run_for_local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._op_schema import ( OpSchema, @@ -83,20 +84,11 @@ class _MaskPartial(Partial): offset_shape: Optional[torch.Size] = None offset_dim: int = 0 - def _partition_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # override parent logic to perform partial mask for embedding - num_chunks = mesh.size(mesh_dim) - # get local shard size and offset on the embedding_dim - assert self.offset_shape is not None, ( - "offset_shape needs to be set for _MaskPartial" - ) - local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset( - self.offset_shape[self.offset_dim], - num_chunks, - mesh.get_local_rank(mesh_dim), - ) + @staticmethod + @maybe_run_for_local_tensor + def _mask_tensor( + tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int + ) -> tuple[torch.Tensor, torch.Tensor]: # Build the input mask and save it for the current partial placement # this is so that the output of embedding op can reuse the same partial # placement saved mask to perform mask + reduction @@ -106,6 +98,27 @@ class _MaskPartial(Partial): # mask the input tensor masked_tensor = tensor.clone() - local_offset_on_dim masked_tensor[mask] = 0 + return mask, masked_tensor + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + my_coordinate = mesh.get_coordinate() + assert my_coordinate is not None, "my_coordinate should not be None" + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + assert self.offset_shape is not None, ( + "offset_shape needs to be set for _MaskPartial" + ) + local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset( + self.offset_shape[self.offset_dim], + num_chunks, + my_coordinate[mesh_dim], + ) + mask, masked_tensor = _MaskPartial._mask_tensor( + tensor, local_offset_on_dim, local_shard_size + ) # materialize the mask buffer to be used for reduction self.mask_buffer.materialize_mask(mask) return masked_tensor diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 4af72b4d3d8f..c1af2c131717 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -48,6 +48,9 @@ class LocalLRUCache(threading.local): def cache_info(self): return self.cache.cache_info() + def cache_clear(self): + return self.cache.cache_clear() + class ShardingPropagator: def __init__(self) -> None: diff --git a/torch/distributed/tensor/debug/__init__.py b/torch/distributed/tensor/debug/__init__.py index e5bf3b833fe4..a74f1449ad12 100644 --- a/torch/distributed/tensor/debug/__init__.py +++ b/torch/distributed/tensor/debug/__init__.py @@ -19,6 +19,17 @@ def _get_sharding_prop_cache_info(): ) +def _clear_sharding_prop_cache(): + """ + Clears the cache for the sharding propagation cache, used for debugging purpose only. + """ + from torch.distributed.tensor._api import DTensor + + return ( + DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_clear() # type:ignore[attr-defined] + ) + + # Set namespace for exposed private names CommDebugMode.__module__ = "torch.distributed.tensor.debug" visualize_sharding.__module__ = "torch.distributed.tensor.debug" diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 5f68ff03ee22..8930d3b1b29c 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -359,6 +359,16 @@ class Shard(Placement): return Shard._select_shard(shards, shard_index) + @staticmethod + @maybe_run_for_local_tensor + def _get_shard_pad_size( + full_size: int, local_tensor: torch.Tensor, dim: int + ) -> int: + """ + Get the padding size of the local tensor on the shard dimension. + """ + return full_size - local_tensor.size(dim) + def _to_new_shard_dim( self, local_tensor: torch.Tensor, @@ -387,14 +397,16 @@ class Shard(Placement): old_dim_full_chunk_size = ( old_dim_logical_size + num_chunks - 1 ) // num_chunks - old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) + old_dim_pad_size = Shard._get_shard_pad_size( + old_dim_full_chunk_size, local_tensor, self.dim + ) local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) if new_dim_padding: new_dim_full_chunk_size = ( new_dim_logical_size + num_chunks - 1 ) // num_chunks - new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( - new_shard_dim + new_dim_pad_size = Shard._get_shard_pad_size( + new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim ) local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 89408b62c9aa..6cd372a8596c 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -211,6 +211,14 @@ def at_least_x_gpu(x): return False +def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool: + _handle_test_skip = getattr(args[0], "_handle_test_skip", None) + if len(args) == 0 or _handle_test_skip is None: + return False + _handle_test_skip(msg) + return True + + def skip_if_lt_x_gpu(x): def decorator(func): @wraps(func) @@ -221,7 +229,9 @@ def skip_if_lt_x_gpu(x): return func(*args, **kwargs) if TEST_XPU and torch.xpu.device_count() >= x: return func(*args, **kwargs) - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + test_skip = TEST_SKIPS[f"multi-gpu-{x}"] + if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message): + sys.exit(test_skip.exit_code) return wrapper @@ -237,7 +247,9 @@ def nccl_skip_if_lt_x_gpu(backend, x): return func(*args, **kwargs) if torch.cuda.is_available() and torch.cuda.device_count() >= x: return func(*args, **kwargs) - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + test_skip = TEST_SKIPS[f"multi-gpu-{x}"] + if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message): + sys.exit(test_skip.exit_code) return wrapper diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 1f982aa42074..22d6d8e7dede 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -701,6 +701,9 @@ class DTensorConverter: class LocalDTensorTestBase(DTensorTestBase): + def _handle_test_skip(self, msg: str) -> None: + self.skipTest(msg) + def _get_local_tensor_mode(self): return LocalTensorMode(frozenset(range(self.world_size)))