Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)

- During op dispatch, local tensor mode is supposed to collect the RNG state from the CPU
and CUDA devices so that it can be reset before the op executes for each rank, ensuring that
ops with randomness produce the same result on all ranks (note that we are planning a
separate change to add support for per-rank RNG state). Previously we relied on the op's
input arguments to deduce which devices to collect RNG state from, which doesn't work for
factory functions such as torch.randn. Hence this change switches to unconditionally
collecting RNG state from all devices (see the first sketch below).

- Fixing per-rank computations in the _MaskPartial and Shard placements that were found to be
incorrect during test enablement (see the second sketch below).
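
As a rough illustration of the first point, the sketch below mirrors the new _get_rng_state /
_set_rng_state helpers from the diff: snapshot the CPU RNG state plus the state of every visible
CUDA device once, then restore that same snapshot before running the op for each simulated rank,
so a factory op such as torch.randn yields identical values on every rank even though it has no
tensor inputs from which to deduce devices. The names snapshot_rng_states, restore_rng_states and
run_for_each_rank are simplified stand-ins for this commit's helpers, not the actual implementation.

import torch

def snapshot_rng_states() -> tuple[torch.Tensor, list[torch.Tensor]]:
    # CPU state plus one state per visible CUDA device (empty list when CUDA is absent).
    cuda_states = []
    if torch.cuda.is_available():
        for idx in range(torch.cuda.device_count()):
            with torch.cuda.device(idx):
                cuda_states.append(torch.cuda.get_rng_state())
    return torch.get_rng_state(), cuda_states

def restore_rng_states(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None:
    torch.set_rng_state(cpu_state)
    for idx, state in enumerate(cuda_states):
        with torch.cuda.device(idx):
            torch.cuda.set_rng_state(state)

def run_for_each_rank(func, ranks):
    # Simplified stand-in for the per-rank dispatch loop: every rank starts from
    # the same RNG snapshot, so randomness is reproduced identically across ranks.
    snapshot = snapshot_rng_states()
    results = {}
    for rank in sorted(ranks):
        restore_rng_states(*snapshot)
        results[rank] = func(rank)
    return results

out = run_for_each_rank(lambda rank: torch.randn(2, 2), ranks={0, 1, 2})
assert torch.equal(out[0], out[1]) and torch.equal(out[1], out[2])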
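
For the second point, the pattern used in the diff is to route rank-dependent computations (the
embedding offset mask in _MaskPartial, the pad size in Shard) through small static helpers
decorated with maybe_run_for_local_tensor, so that under local tensor mode each helper runs once
per rank with that rank's local shard instead of a single global rank. The toy decorator and the
FakeLocalValue / pad_size names below merely model that behavior under simplified assumptions;
they are not the real implementation.

from typing import Any, Callable

class FakeLocalValue:
    # Hypothetical stand-in for a local tensor: one value per simulated rank.
    def __init__(self, per_rank: dict[int, Any]):
        self.per_rank = per_rank

def run_per_rank(func: Callable[..., Any]) -> Callable[..., Any]:
    # Toy model of maybe_run_for_local_tensor: when any argument is per-rank,
    # invoke func once per rank with that rank's value and collect the results.
    def wrapper(*args: Any) -> Any:
        ranks: set[int] = set()
        for a in args:
            if isinstance(a, FakeLocalValue):
                ranks |= set(a.per_rank)
        if not ranks:
            return func(*args)
        return FakeLocalValue({
            r: func(*(a.per_rank[r] if isinstance(a, FakeLocalValue) else a for a in args))
            for r in sorted(ranks)
        })
    return wrapper

@run_per_rank
def pad_size(full_chunk_size: int, local_size: int) -> int:
    # Rank-dependent: the last rank's shard may be smaller than a full chunk.
    return full_chunk_size - local_size

# Splitting 10 rows over 4 ranks gives local sizes 3, 3, 3, 1.
sizes = FakeLocalValue({0: 3, 1: 3, 2: 3, 3: 1})
print(pad_size(3, sizes).per_rank)  # {0: 0, 1: 0, 2: 0, 3: 2}

Decorating the helper (rather than branching on local tensor mode inside each placement method)
keeps the per-rank logic in one place and limits how much local tensor mode leaks into DTensor
internals.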

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
Author: Dzmitry Huba
Date: 2025-10-17 09:01:44 -07:00
Committed by: PyTorch MergeBot
Parent: fe80f03726
Commit: 1b397420f2
8 changed files with 155 additions and 30 deletions


@@ -104,6 +104,62 @@ def _map_to_rank_local_val(val: Any, rank: int) -> Any:
return val
def collect_cuda_rng_states() -> list[torch.Tensor]:
"""
Collects RNG state from all available CUDA devices.
Returns:
List of RNG state tensors, one for each CUDA device.
Returns empty list if CUDA is not available.
"""
if not torch.cuda.is_available():
return []
num_devices = torch.cuda.device_count()
rng_states = []
for device_idx in range(num_devices):
with torch.cuda.device(device_idx):
rng_state = torch.cuda.get_rng_state()
rng_states.append(rng_state)
return rng_states
def set_cuda_rng_states(rng_states: list[torch.Tensor]) -> None:
"""
Sets RNG state for all CUDA devices from a list of states.
Args:
rng_states: List of RNG state tensors to restore.
"""
if not torch.cuda.is_available():
return
num_devices = min(len(rng_states), torch.cuda.device_count())
for device_idx in range(num_devices):
with torch.cuda.device(device_idx):
torch.cuda.set_rng_state(rng_states[device_idx])
def _get_rng_state() -> tuple[torch.Tensor, list[torch.Tensor]]:
"""
Gets CPU and CUDA rng states from all devices.
"""
return (torch.get_rng_state(), collect_cuda_rng_states())
def _set_rng_state(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None:
"""
Sets CPU and CUDA rng states for all devices. If the list of CUDA states
is shorter than the number of devices, only the first len(cuda_states)
devices will have their rng state set.
"""
torch.set_rng_state(cpu_state)
set_cuda_rng_states(cuda_states)
def _for_each_rank_run_func(
func: Callable[..., Any],
ranks: frozenset[int],
@@ -117,14 +173,15 @@ def _for_each_rank_run_func(
a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args
]
cpu_state = torch.get_rng_state()
devices, states = get_device_states((args, kwargs))
# NB: Before invoking an op we collect the rng states from the CPU and
# CUDA devices so that we can reset them to the same values before invoking
# the op for each rank. This is not very efficient and will likely be
# revisited to support per-rank rng state.
rng_state = _get_rng_state()
flat_rank_rets = {}
for r in sorted(ranks):
torch.set_rng_state(cpu_state)
set_device_states(devices, states)
_set_rng_state(*rng_state)
rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
rank_ret = func(*rank_args, **rank_kwargs)
@@ -704,6 +761,11 @@ class _LocalDeviceMesh:
@staticmethod
def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
# NB: In order to support submeshes, the code below recreates, for each rank,
# a submesh with the same mesh dimensions as the current mesh. We do this
# because a submesh is created for a particular rank (hence the get_rank
# method is patched below). We are trying to limit the invasiveness of
# local tensor mode.
lm = local_tensor_mode()
assert lm is not None, "Unexpectedly not in LocalTensorMode"
@@ -716,7 +778,9 @@
coords[d][r] = c
out = [torch.SymInt(LocalIntNode(c)) for c in coords]
# The output contains coordinates for each rank with respect to its mesh,
# which is formed from the root mesh by selecting the same dimensions as
# the current mesh.
return out # type: ignore[return-value]
@@ -794,8 +858,6 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
with lm.disable():
ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)
lm = local_tensor_mode()
assert lm is not None
return ret
return wrapper


@@ -6,6 +6,7 @@ from typing import cast, Optional
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._local_tensor import maybe_run_for_local_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema,
@@ -83,20 +84,11 @@ class _MaskPartial(Partial):
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert self.offset_shape is not None, (
"offset_shape needs to be set for _MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
self.offset_shape[self.offset_dim],
num_chunks,
mesh.get_local_rank(mesh_dim),
)
@staticmethod
@maybe_run_for_local_tensor
def _mask_tensor(
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
# Build the input mask and save it for the current partial placement
# this is so that the output of embedding op can reuse the same partial
# placement saved mask to perform mask + reduction
@@ -106,6 +98,27 @@ class _MaskPartial(Partial):
# mask the input tensor
masked_tensor = tensor.clone() - local_offset_on_dim
masked_tensor[mask] = 0
return mask, masked_tensor
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
my_coordinate = mesh.get_coordinate()
assert my_coordinate is not None, "my_coordinate should not be None"
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert self.offset_shape is not None, (
"offset_shape needs to be set for _MaskPartial"
)
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
self.offset_shape[self.offset_dim],
num_chunks,
my_coordinate[mesh_dim],
)
mask, masked_tensor = _MaskPartial._mask_tensor(
tensor, local_offset_on_dim, local_shard_size
)
# materialize the mask buffer to be used for reduction
self.mask_buffer.materialize_mask(mask)
return masked_tensor


@@ -48,6 +48,9 @@ class LocalLRUCache(threading.local):
def cache_info(self):
return self.cache.cache_info()
def cache_clear(self):
return self.cache.cache_clear()
class ShardingPropagator:
def __init__(self) -> None:


@@ -19,6 +19,17 @@ def _get_sharding_prop_cache_info():
)
def _clear_sharding_prop_cache():
"""
Clears the sharding propagation cache; used for debugging purposes only.
"""
from torch.distributed.tensor._api import DTensor
return (
DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_clear() # type:ignore[attr-defined]
)
# Set namespace for exposed private names
CommDebugMode.__module__ = "torch.distributed.tensor.debug"
visualize_sharding.__module__ = "torch.distributed.tensor.debug"


@@ -359,6 +359,16 @@ class Shard(Placement):
return Shard._select_shard(shards, shard_index)
@staticmethod
@maybe_run_for_local_tensor
def _get_shard_pad_size(
full_size: int, local_tensor: torch.Tensor, dim: int
) -> int:
"""
Get the padding size of the local tensor on the shard dimension.
"""
return full_size - local_tensor.size(dim)
def _to_new_shard_dim(
self,
local_tensor: torch.Tensor,
@@ -387,14 +397,16 @@
old_dim_full_chunk_size = (
old_dim_logical_size + num_chunks - 1
) // num_chunks
old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
old_dim_pad_size = Shard._get_shard_pad_size(
old_dim_full_chunk_size, local_tensor, self.dim
)
local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
if new_dim_padding:
new_dim_full_chunk_size = (
new_dim_logical_size + num_chunks - 1
) // num_chunks
new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
new_shard_dim
new_dim_pad_size = Shard._get_shard_pad_size(
new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim
)
local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)


@@ -211,6 +211,14 @@ def at_least_x_gpu(x):
return False
def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
_handle_test_skip = getattr(args[0], "_handle_test_skip", None) if len(args) > 0 else None
if _handle_test_skip is None:
return False
_handle_test_skip(msg)
return True
def skip_if_lt_x_gpu(x):
def decorator(func):
@wraps(func)
@@ -221,7 +229,9 @@ def skip_if_lt_x_gpu(x):
return func(*args, **kwargs)
if TEST_XPU and torch.xpu.device_count() >= x:
return func(*args, **kwargs)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
sys.exit(test_skip.exit_code)
return wrapper
@@ -237,7 +247,9 @@ def nccl_skip_if_lt_x_gpu(backend, x):
return func(*args, **kwargs)
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
return func(*args, **kwargs)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
sys.exit(test_skip.exit_code)
return wrapper


@@ -701,6 +701,9 @@ class DTensorConverter:
class LocalDTensorTestBase(DTensorTestBase):
def _handle_test_skip(self, msg: str) -> None:
self.skipTest(msg)
def _get_local_tensor_mode(self):
return LocalTensorMode(frozenset(range(0, self.world_size)))