mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA devices so that it can be reset before execution of the op for each such that ops with randomness produces the same result for all ranks (note that we are planning a separate change to add support of per rank rng state). Previously we relied on op input arguments to deduce which devices to get rng state from. Which doesn't work for factory functions such torch.randn. Hence this changes switches to uncondionally collecting rng state from all devices. - Fixing per rank specific computations in _MaskedPartial and Shard placements discovered during test enablement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
f18041cca8
commit
c4f6619330
@ -104,6 +104,62 @@ def _map_to_rank_local_val(val: Any, rank: int) -> Any:
|
||||
return val
|
||||
|
||||
|
||||
def collect_cuda_rng_states() -> list[torch.Tensor]:
|
||||
"""
|
||||
Collects RNG state from all available CUDA devices.
|
||||
|
||||
Returns:
|
||||
List of RNG state tensors, one for each CUDA device.
|
||||
Returns empty list if CUDA is not available.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
return []
|
||||
|
||||
num_devices = torch.cuda.device_count()
|
||||
rng_states = []
|
||||
|
||||
for device_idx in range(num_devices):
|
||||
with torch.cuda.device(device_idx):
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
rng_states.append(rng_state)
|
||||
|
||||
return rng_states
|
||||
|
||||
|
||||
def set_cuda_rng_states(rng_states: list[torch.Tensor]) -> None:
|
||||
"""
|
||||
Sets RNG state for all CUDA devices from a list of states.
|
||||
|
||||
Args:
|
||||
rng_states: List of RNG state tensors to restore.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
num_devices = min(len(rng_states), torch.cuda.device_count())
|
||||
|
||||
for device_idx in range(num_devices):
|
||||
with torch.cuda.device(device_idx):
|
||||
torch.cuda.set_rng_state(rng_states[device_idx])
|
||||
|
||||
|
||||
def _get_rng_state() -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""
|
||||
Gets CPU and CUDA rng states from all devices.
|
||||
"""
|
||||
return (torch.get_rng_state(), collect_cuda_rng_states())
|
||||
|
||||
|
||||
def _set_rng_state(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None:
|
||||
"""
|
||||
Sets CPU and CUDA rng states for all devices. If the list of cuda states
|
||||
is shorter than the number of devices only the first len(cuda_states) devices
|
||||
will get their rng state set.
|
||||
"""
|
||||
torch.set_rng_state(cpu_state)
|
||||
set_cuda_rng_states(cuda_states)
|
||||
|
||||
|
||||
def _for_each_rank_run_func(
|
||||
func: Callable[..., Any],
|
||||
ranks: frozenset[int],
|
||||
@ -117,14 +173,15 @@ def _for_each_rank_run_func(
|
||||
a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args
|
||||
]
|
||||
|
||||
cpu_state = torch.get_rng_state()
|
||||
devices, states = get_device_states((args, kwargs))
|
||||
|
||||
# NB: Before invoking an op we are collecting rng states from CPU and
|
||||
# CUDA devices such that we can reset to the same before invoking op
|
||||
# for each rank. This is not very efficient and will likely be revisited
|
||||
# to support per rank rng state.
|
||||
rng_state = _get_rng_state()
|
||||
flat_rank_rets = {}
|
||||
|
||||
for r in sorted(ranks):
|
||||
torch.set_rng_state(cpu_state)
|
||||
set_device_states(devices, states)
|
||||
_set_rng_state(*rng_state)
|
||||
rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
|
||||
rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
|
||||
rank_ret = func(*rank_args, **rank_kwargs)
|
||||
@ -704,6 +761,11 @@ class _LocalDeviceMesh:
|
||||
|
||||
@staticmethod
|
||||
def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
|
||||
# NB: In order to support submeshes the code below recreates for each
|
||||
# rank submesh with the same mesh dimensions as current mesh. We are
|
||||
# doing this because when submesh is created it is created for a particular
|
||||
# rank (therefore below we are patching get_rank method). We are trying to
|
||||
# limit the invasiveness of local tensor.
|
||||
lm = local_tensor_mode()
|
||||
assert lm is not None, "Unexpectedly not in LocalTensorMode"
|
||||
|
||||
@ -716,7 +778,9 @@ class _LocalDeviceMesh:
|
||||
coords[d][r] = c
|
||||
|
||||
out = [torch.SymInt(LocalIntNode(c)) for c in coords]
|
||||
|
||||
# The output contains coordinates for each of the ranks with respect to
|
||||
# their meshes formed from root mesh and selecting the same dimensions
|
||||
# as the current mesh.
|
||||
return out # type: ignore[return-value]
|
||||
|
||||
|
||||
@ -794,8 +858,6 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
with lm.disable():
|
||||
ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)
|
||||
|
||||
lm = local_tensor_mode()
|
||||
assert lm is not None
|
||||
return ret
|
||||
|
||||
return wrapper
|
||||
|
@ -6,6 +6,7 @@ from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
from torch.distributed._local_tensor import maybe_run_for_local_tensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor._op_schema import (
|
||||
OpSchema,
|
||||
@ -83,20 +84,11 @@ class _MaskPartial(Partial):
|
||||
offset_shape: Optional[torch.Size] = None
|
||||
offset_dim: int = 0
|
||||
|
||||
def _partition_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
# override parent logic to perform partial mask for embedding
|
||||
num_chunks = mesh.size(mesh_dim)
|
||||
# get local shard size and offset on the embedding_dim
|
||||
assert self.offset_shape is not None, (
|
||||
"offset_shape needs to be set for _MaskPartial"
|
||||
)
|
||||
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
|
||||
self.offset_shape[self.offset_dim],
|
||||
num_chunks,
|
||||
mesh.get_local_rank(mesh_dim),
|
||||
)
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _mask_tensor(
|
||||
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Build the input mask and save it for the current partial placement
|
||||
# this is so that the output of embedding op can reuse the same partial
|
||||
# placement saved mask to perform mask + reduction
|
||||
@ -106,6 +98,27 @@ class _MaskPartial(Partial):
|
||||
# mask the input tensor
|
||||
masked_tensor = tensor.clone() - local_offset_on_dim
|
||||
masked_tensor[mask] = 0
|
||||
return mask, masked_tensor
|
||||
|
||||
def _partition_value(
|
||||
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
||||
) -> torch.Tensor:
|
||||
my_coordinate = mesh.get_coordinate()
|
||||
assert my_coordinate is not None, "my_coordinate should not be None"
|
||||
# override parent logic to perform partial mask for embedding
|
||||
num_chunks = mesh.size(mesh_dim)
|
||||
# get local shard size and offset on the embedding_dim
|
||||
assert self.offset_shape is not None, (
|
||||
"offset_shape needs to be set for _MaskPartial"
|
||||
)
|
||||
local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
|
||||
self.offset_shape[self.offset_dim],
|
||||
num_chunks,
|
||||
my_coordinate[mesh_dim],
|
||||
)
|
||||
mask, masked_tensor = _MaskPartial._mask_tensor(
|
||||
tensor, local_offset_on_dim, local_shard_size
|
||||
)
|
||||
# materialize the mask buffer to be used for reduction
|
||||
self.mask_buffer.materialize_mask(mask)
|
||||
return masked_tensor
|
||||
|
@ -48,6 +48,9 @@ class LocalLRUCache(threading.local):
|
||||
def cache_info(self):
|
||||
return self.cache.cache_info()
|
||||
|
||||
def cache_clear(self):
|
||||
return self.cache.cache_clear()
|
||||
|
||||
|
||||
class ShardingPropagator:
|
||||
def __init__(self) -> None:
|
||||
|
@ -19,6 +19,17 @@ def _get_sharding_prop_cache_info():
|
||||
)
|
||||
|
||||
|
||||
def _clear_sharding_prop_cache():
|
||||
"""
|
||||
Clears the cache for the sharding propagation cache, used for debugging purpose only.
|
||||
"""
|
||||
from torch.distributed.tensor._api import DTensor
|
||||
|
||||
return (
|
||||
DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_clear() # type:ignore[attr-defined]
|
||||
)
|
||||
|
||||
|
||||
# Set namespace for exposed private names
|
||||
CommDebugMode.__module__ = "torch.distributed.tensor.debug"
|
||||
visualize_sharding.__module__ = "torch.distributed.tensor.debug"
|
||||
|
@ -359,6 +359,16 @@ class Shard(Placement):
|
||||
|
||||
return Shard._select_shard(shards, shard_index)
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _get_shard_pad_size(
|
||||
full_size: int, local_tensor: torch.Tensor, dim: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the padding size of the local tensor on the shard dimension.
|
||||
"""
|
||||
return full_size - local_tensor.size(dim)
|
||||
|
||||
def _to_new_shard_dim(
|
||||
self,
|
||||
local_tensor: torch.Tensor,
|
||||
@ -387,14 +397,16 @@ class Shard(Placement):
|
||||
old_dim_full_chunk_size = (
|
||||
old_dim_logical_size + num_chunks - 1
|
||||
) // num_chunks
|
||||
old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
|
||||
old_dim_pad_size = Shard._get_shard_pad_size(
|
||||
old_dim_full_chunk_size, local_tensor, self.dim
|
||||
)
|
||||
local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
|
||||
if new_dim_padding:
|
||||
new_dim_full_chunk_size = (
|
||||
new_dim_logical_size + num_chunks - 1
|
||||
) // num_chunks
|
||||
new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
|
||||
new_shard_dim
|
||||
new_dim_pad_size = Shard._get_shard_pad_size(
|
||||
new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim
|
||||
)
|
||||
local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)
|
||||
|
||||
|
@ -211,6 +211,14 @@ def at_least_x_gpu(x):
|
||||
return False
|
||||
|
||||
|
||||
def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
|
||||
_handle_test_skip = getattr(args[0], "_handle_test_skip", None)
|
||||
if len(args) == 0 or _handle_test_skip is None:
|
||||
return False
|
||||
_handle_test_skip(msg)
|
||||
return True
|
||||
|
||||
|
||||
def skip_if_lt_x_gpu(x):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
@ -221,7 +229,9 @@ def skip_if_lt_x_gpu(x):
|
||||
return func(*args, **kwargs)
|
||||
if TEST_XPU and torch.xpu.device_count() >= x:
|
||||
return func(*args, **kwargs)
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
|
||||
if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
|
||||
sys.exit(test_skip.exit_code)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -237,7 +247,9 @@ def nccl_skip_if_lt_x_gpu(backend, x):
|
||||
return func(*args, **kwargs)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
|
||||
return func(*args, **kwargs)
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
|
||||
if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
|
||||
sys.exit(test_skip.exit_code)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
@ -701,6 +701,9 @@ class DTensorConverter:
|
||||
|
||||
|
||||
class LocalDTensorTestBase(DTensorTestBase):
|
||||
def _handle_test_skip(self, msg: str) -> None:
|
||||
self.skipTest(msg)
|
||||
|
||||
def _get_local_tensor_mode(self):
|
||||
return LocalTensorMode(frozenset(range(self.world_size)))
|
||||
|
||||
|
Reference in New Issue
Block a user