Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"
This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb. Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert: reverted automatically by pytorch's autorevert; to avoid this behaviour, add the tag `autorevert: disable` ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
@@ -104,62 +104,6 @@ def _map_to_rank_local_val(val: Any, rank: int) -> Any:
    return val


def collect_cuda_rng_states() -> list[torch.Tensor]:
    """
    Collects RNG state from all available CUDA devices.

    Returns:
        List of RNG state tensors, one for each CUDA device.
        Returns empty list if CUDA is not available.
    """
    if not torch.cuda.is_available():
        return []

    num_devices = torch.cuda.device_count()
    rng_states = []

    for device_idx in range(num_devices):
        with torch.cuda.device(device_idx):
            rng_state = torch.cuda.get_rng_state()
            rng_states.append(rng_state)

    return rng_states


def set_cuda_rng_states(rng_states: list[torch.Tensor]) -> None:
    """
    Sets RNG state for all CUDA devices from a list of states.

    Args:
        rng_states: List of RNG state tensors to restore.
    """
    if not torch.cuda.is_available():
        return

    num_devices = min(len(rng_states), torch.cuda.device_count())

    for device_idx in range(num_devices):
        with torch.cuda.device(device_idx):
            torch.cuda.set_rng_state(rng_states[device_idx])


def _get_rng_state() -> tuple[torch.Tensor, list[torch.Tensor]]:
    """
    Gets CPU and CUDA rng states from all devices.
    """
    return (torch.get_rng_state(), collect_cuda_rng_states())


def _set_rng_state(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None:
    """
    Sets CPU and CUDA rng states for all devices. If the list of cuda states
    is shorter than the number of devices only the first len(cuda_states) devices
    will get their rng state set.
    """
    torch.set_rng_state(cpu_state)
    set_cuda_rng_states(cuda_states)


def _for_each_rank_run_func(
    func: Callable[..., Any],
    ranks: frozenset[int],
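
The helpers removed above snapshot and restore RNG state so that every simulated rank sees the same randomness. A minimal standalone sketch of that round-trip, outside DTensor (the `simulate_ranks` and `per_rank_noise` names are illustrative, not from the diff):

```python
# Sketch: snapshot RNG state once, restore it before each per-rank call.
import torch


def per_rank_noise(rank: int) -> torch.Tensor:
    # Any op that consumes RNG state; runs on CPU so the example needs no GPU.
    return torch.rand(2)


def simulate_ranks(ranks: list[int]) -> dict[int, torch.Tensor]:
    # Snapshot CPU (and, when available, CUDA) RNG state once...
    cpu_state = torch.get_rng_state()
    cuda_states = (
        [torch.cuda.get_rng_state(d) for d in range(torch.cuda.device_count())]
        if torch.cuda.is_available()
        else []
    )
    results = {}
    for r in ranks:
        # ...and restore it before each per-rank invocation, so the only
        # source of divergence between ranks is their rank-local inputs.
        torch.set_rng_state(cpu_state)
        for d, s in enumerate(cuda_states):
            torch.cuda.set_rng_state(s, d)
        results[r] = per_rank_noise(r)
    return results


out = simulate_ranks([0, 1, 2])
assert all(torch.equal(out[0], v) for v in out.values())
```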
@@ -173,15 +117,14 @@ def _for_each_rank_run_func(
        a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args
    ]

    # NB: Before invoking an op we are collecting rng states from CPU and
    # CUDA devices such that we can reset to the same before invoking op
    # for each rank. This is not very efficient and will likely be revisited
    # to support per rank rng state.
    rng_state = _get_rng_state()
    cpu_state = torch.get_rng_state()
    devices, states = get_device_states((args, kwargs))

    flat_rank_rets = {}

    for r in sorted(ranks):
        _set_rng_state(*rng_state)
        torch.set_rng_state(cpu_state)
        set_device_states(devices, states)
        rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
        rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
        rank_ret = func(*rank_args, **rank_kwargs)
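
The loop above relies on pytree flattening to swap each argument for its rank-local value before calling the op. A simplified, standalone illustration (the `map_to_rank` mapping below is hypothetical; the real `_map_to_rank_local_val` extracts per-rank local tensors):

```python
# Sketch: flatten args, substitute each rank's local value, then unflatten.
import torch
import torch.utils._pytree as pytree

args = ({"x": torch.arange(4), "scale": 2.0},)
flat_args, args_spec = pytree.tree_flatten(args)


def map_to_rank(val, rank):
    # Hypothetical rank-local mapping: tensors get a per-rank variant,
    # everything else is shared unchanged.
    return val + rank if isinstance(val, torch.Tensor) else val


for r in range(2):
    rank_flat = [map_to_rank(a, r) for a in flat_args]
    rank_args = pytree.tree_unflatten(rank_flat, args_spec)
    print(r, rank_args)
```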
@@ -761,11 +704,6 @@ class _LocalDeviceMesh:

    @staticmethod
    def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
        # NB: In order to support submeshes the code below recreates for each
        # rank submesh with the same mesh dimensions as current mesh. We are
        # doing this because when submesh is created it is created for a particular
        # rank (therefore below we are patching get_rank method). We are trying to
        # limit the invasiveness of local tensor.
        lm = local_tensor_mode()
        assert lm is not None, "Unexpectedly not in LocalTensorMode"

@@ -778,9 +716,7 @@ class _LocalDeviceMesh:
            coords[d][r] = c

        out = [torch.SymInt(LocalIntNode(c)) for c in coords]
        # The output contains coordinates for each of the ranks with respect to
        # their meshes formed from root mesh and selecting the same dimensions
        # as the current mesh.

        return out  # type: ignore[return-value]
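
For context on `get_coordinate` above: a rank's coordinate is its index unraveled over the mesh dimensions. A quick illustration with a plain tensor standing in for the mesh (not DTensor code):

```python
# Rank 5 in a 2 x 4 mesh sits at row 1, column 1.
import torch

mesh = torch.arange(8).reshape(2, 4)  # global ranks laid out as a 2 x 4 mesh
rank = 5
coord = [int(c) for c in (mesh == rank).nonzero()[0]]
print(coord)  # [1, 1]
```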
@@ -858,6 +794,8 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
        with lm.disable():
            ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)

        lm = local_tensor_mode()
        assert lm is not None
        return ret

    return wrapper
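
A hypothetical, heavily simplified sketch of the shape of `maybe_run_for_local_tensor`: a decorator that runs the wrapped function once per simulated rank and collects the results. The real wrapper additionally maps arguments to rank-local values via `_for_each_rank_run_func` and merges results back into local tensors; the names below are illustrative only.

```python
from functools import wraps
from typing import Any, Callable


def run_per_rank(ranks: frozenset[int]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> dict[int, Any]:
            # Invoke func once per rank; a full implementation would also map
            # each argument to its rank-local value and merge the results.
            return {r: func(*args, rank=r, **kwargs) for r in sorted(ranks)}

        return wrapper

    return decorator


@run_per_rank(frozenset({0, 1}))
def shard_size(total: int, *, rank: int) -> int:
    return total // 2 + (rank < total % 2)


print(shard_size(5))  # {0: 3, 1: 2}
```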
@@ -6,7 +6,6 @@ from typing import cast, Optional

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._local_tensor import maybe_run_for_local_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
    OpSchema,
@@ -84,27 +83,9 @@ class _MaskPartial(Partial):
    offset_shape: Optional[torch.Size] = None
    offset_dim: int = 0

    @staticmethod
    @maybe_run_for_local_tensor
    def _mask_tensor(
        tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Build the input mask and save it for the current partial placement
        # this is so that the output of embedding op can reuse the same partial
        # placement saved mask to perform mask + reduction
        mask = (tensor < local_offset_on_dim) | (
            tensor >= local_offset_on_dim + local_shard_size
        )
        # mask the input tensor
        masked_tensor = tensor.clone() - local_offset_on_dim
        masked_tensor[mask] = 0
        return mask, masked_tensor

    def _partition_value(
        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
    ) -> torch.Tensor:
        my_coordinate = mesh.get_coordinate()
        assert my_coordinate is not None, "my_coordinate should not be None"
        # override parent logic to perform partial mask for embedding
        num_chunks = mesh.size(mesh_dim)
        # get local shard size and offset on the embedding_dim
@@ -114,11 +95,17 @@ class _MaskPartial(Partial):
        local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
            self.offset_shape[self.offset_dim],
            num_chunks,
            my_coordinate[mesh_dim],
            mesh.get_local_rank(mesh_dim),
        )
        mask, masked_tensor = _MaskPartial._mask_tensor(
            tensor, local_offset_on_dim, local_shard_size
        # Build the input mask and save it for the current partial placement
        # this is so that the output of embedding op can reuse the same partial
        # placement saved mask to perform mask + reduction
        mask = (tensor < local_offset_on_dim) | (
            tensor >= local_offset_on_dim + local_shard_size
        )
        # mask the input tensor
        masked_tensor = tensor.clone() - local_offset_on_dim
        masked_tensor[mask] = 0
        # materialize the mask buffer to be used for reduction
        self.mask_buffer.materialize_mask(mask)
        return masked_tensor
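
A standalone sketch of the masking trick in `_MaskPartial` above (illustrative only; the tensor values and shard sizes are made up): token ids outside this rank's row range are zeroed so the local lookup stays valid, and the mask is kept so those rows can be zeroed before the cross-rank reduction.

```python
import torch

tokens = torch.tensor([1, 5, 9, 14])
local_offset_on_dim, local_shard_size = 8, 8  # this rank owns embedding rows 8..15

# Mark token ids that fall outside this rank's shard of the embedding table.
mask = (tokens < local_offset_on_dim) | (
    tokens >= local_offset_on_dim + local_shard_size
)
# Shift in-range ids into local coordinates; out-of-range ids become row 0.
masked_tokens = tokens.clone() - local_offset_on_dim
masked_tokens[mask] = 0

local_weight = torch.randn(local_shard_size, 4)
out = local_weight[masked_tokens]
out[mask] = 0.0  # rows for tokens owned by other ranks contribute nothing
# Summing `out` across ranks would reconstruct the full embedding lookup.
```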
@@ -48,9 +48,6 @@ class LocalLRUCache(threading.local):
    def cache_info(self):
        return self.cache.cache_info()

    def cache_clear(self):
        return self.cache.cache_clear()


class ShardingPropagator:
    def __init__(self) -> None:
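
For context on `LocalLRUCache` above, a minimal sketch of the thread-local LRU pattern (assuming a simplified constructor; this is not the actual class): every thread that touches the object gets its own `functools.lru_cache`, so `cache_info()` and `cache_clear()` only see the calling thread's entries.

```python
import threading
from functools import lru_cache


class LocalLRUCacheSketch(threading.local):
    def __init__(self, func, maxsize=None):
        # threading.local re-runs __init__ per thread, so each thread
        # builds and owns its own cached wrapper around func.
        self.cache = lru_cache(maxsize=maxsize)(func)

    def __call__(self, *args):
        return self.cache(*args)

    def cache_info(self):
        return self.cache.cache_info()

    def cache_clear(self):
        return self.cache.cache_clear()


square = LocalLRUCacheSketch(lambda x: x * x, maxsize=128)
square(3)
square(3)
print(square.cache_info().hits)  # 1 hit, counted on this thread only
```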
@@ -19,17 +19,6 @@ def _get_sharding_prop_cache_info():
    )


def _clear_sharding_prop_cache():
    """
    Clears the cache for the sharding propagation cache, used for debugging purpose only.
    """
    from torch.distributed.tensor._api import DTensor

    return (
        DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_clear()  # type:ignore[attr-defined]
    )


# Set namespace for exposed private names
CommDebugMode.__module__ = "torch.distributed.tensor.debug"
visualize_sharding.__module__ = "torch.distributed.tensor.debug"
@@ -359,16 +359,6 @@ class Shard(Placement):

        return Shard._select_shard(shards, shard_index)

    @staticmethod
    @maybe_run_for_local_tensor
    def _get_shard_pad_size(
        full_size: int, local_tensor: torch.Tensor, dim: int
    ) -> int:
        """
        Get the padding size of the local tensor on the shard dimension.
        """
        return full_size - local_tensor.size(dim)

    def _to_new_shard_dim(
        self,
        local_tensor: torch.Tensor,
@@ -397,16 +387,14 @@ class Shard(Placement):
            old_dim_full_chunk_size = (
                old_dim_logical_size + num_chunks - 1
            ) // num_chunks
            old_dim_pad_size = Shard._get_shard_pad_size(
                old_dim_full_chunk_size, local_tensor, self.dim
            )
            old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
            local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
        if new_dim_padding:
            new_dim_full_chunk_size = (
                new_dim_logical_size + num_chunks - 1
            ) // num_chunks
            new_dim_pad_size = Shard._get_shard_pad_size(
                new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim
            new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
                new_shard_dim
            )
            local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)
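
The `(size + num_chunks - 1) // num_chunks` expressions above are ceiling division: each chunk is sized so that `num_chunks` equal chunks cover the logical dimension, and short local shards are padded up to that full chunk size. A quick worked example (the numbers are made up):

```python
logical_size, num_chunks = 10, 4
full_chunk_size = (logical_size + num_chunks - 1) // num_chunks  # ceil(10 / 4) = 3
last_shard_size = logical_size - full_chunk_size * (num_chunks - 1)  # 1
pad_size = full_chunk_size - last_shard_size  # 2 rows of padding on the last shard
print(full_chunk_size, last_shard_size, pad_size)  # 3 1 2
```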
@@ -211,14 +211,6 @@ def at_least_x_gpu(x):
    return False


def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
    _handle_test_skip = getattr(args[0], "_handle_test_skip", None)
    if len(args) == 0 or _handle_test_skip is None:
        return False
    _handle_test_skip(msg)
    return True


def skip_if_lt_x_gpu(x):
    def decorator(func):
        @wraps(func)
@@ -229,9 +221,7 @@ def skip_if_lt_x_gpu(x):
                return func(*args, **kwargs)
            if TEST_XPU and torch.xpu.device_count() >= x:
                return func(*args, **kwargs)
            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
            if _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
                sys.exit(test_skip.exit_code)
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

        return wrapper
@@ -247,9 +237,7 @@ def nccl_skip_if_lt_x_gpu(backend, x):
                return func(*args, **kwargs)
            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                return func(*args, **kwargs)
            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
            if _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
                sys.exit(test_skip.exit_code)
            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

        return wrapper
@@ -701,9 +701,6 @@ class DTensorConverter:


class LocalDTensorTestBase(DTensorTestBase):
    def _handle_test_skip(self, msg: str) -> None:
        self.skipTest(msg)

    def _get_local_tensor_mode(self):
        return LocalTensorMode(frozenset(range(self.world_size)))
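
A hedged sketch of why the `_handle_test_skip` hook appears in the test hunks above (names other than `_handle_test_skip` are hypothetical): multi-process distributed tests signal a skip by exiting the child process, whereas a test base that runs everything in-process, like `LocalDTensorTestBase`, presumably needs the skip routed through unittest's `skipTest` instead of `sys.exit`.

```python
import sys
import unittest


class InProcessTestBase(unittest.TestCase):
    def _handle_test_skip(self, msg: str) -> None:
        self.skipTest(msg)


def skip_without_gpus(test_self, required: int, available: int) -> None:
    if available >= required:
        return
    msg = f"Need at least {required} GPUs, found {available}"
    handler = getattr(test_self, "_handle_test_skip", None)
    if handler is not None:
        handler(msg)  # in-process: raises unittest.SkipTest
    else:
        sys.exit(1)  # multi-process: child exits with a skip code
```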