Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"

This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb.

Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert due to: reverted automatically by pytorch's autorevert; to avoid this behaviour, add the tag `autorevert: disable` ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
Author: PyTorch MergeBot
Date: 2025-10-18 09:15:49 +00:00
Parent: 4740ce7787
Commit: beb6b62e8c
8 changed files with 25 additions and 150 deletions

View File

@@ -104,62 +104,6 @@ def _map_to_rank_local_val(val: Any, rank: int) -> Any:
     return val
 
 
-def collect_cuda_rng_states() -> list[torch.Tensor]:
-    """
-    Collects RNG state from all available CUDA devices.
-
-    Returns:
-        List of RNG state tensors, one for each CUDA device.
-        Returns empty list if CUDA is not available.
-    """
-    if not torch.cuda.is_available():
-        return []
-    num_devices = torch.cuda.device_count()
-    rng_states = []
-    for device_idx in range(num_devices):
-        with torch.cuda.device(device_idx):
-            rng_state = torch.cuda.get_rng_state()
-            rng_states.append(rng_state)
-    return rng_states
-
-
-def set_cuda_rng_states(rng_states: list[torch.Tensor]) -> None:
-    """
-    Sets RNG state for all CUDA devices from a list of states.
-
-    Args:
-        rng_states: List of RNG state tensors to restore.
-    """
-    if not torch.cuda.is_available():
-        return
-    num_devices = min(len(rng_states), torch.cuda.device_count())
-    for device_idx in range(num_devices):
-        with torch.cuda.device(device_idx):
-            torch.cuda.set_rng_state(rng_states[device_idx])
-
-
-def _get_rng_state() -> tuple[torch.Tensor, list[torch.Tensor]]:
-    """
-    Gets CPU and CUDA rng states from all devices.
-    """
-    return (torch.get_rng_state(), collect_cuda_rng_states())
-
-
-def _set_rng_state(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None:
-    """
-    Sets CPU and CUDA rng states for all devices. If the list of cuda states
-    is shorter than the number of devices only the first len(cuda_states) devices
-    will get their rng state set.
-    """
-    torch.set_rng_state(cpu_state)
-    set_cuda_rng_states(cuda_states)
-
-
 def _for_each_rank_run_func(
     func: Callable[..., Any],
     ranks: frozenset[int],
@@ -173,15 +117,14 @@ def _for_each_rank_run_func(
         a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args
     ]
 
-    # NB: Before invoking an op we are collecting rng states from CPU and
-    # CUDA devices such that we can reset to the same before invoking op
-    # for each rank. This is not very efficient and will likely be revisited
-    # to support per rank rng state.
-    rng_state = _get_rng_state()
+    cpu_state = torch.get_rng_state()
+    devices, states = get_device_states((args, kwargs))
     flat_rank_rets = {}
     for r in sorted(ranks):
-        _set_rng_state(*rng_state)
+        torch.set_rng_state(cpu_state)
+        set_device_states(devices, states)
         rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
         rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
         rank_ret = func(*rank_args, **rank_kwargs)
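
Note on the two variants above: both snapshot the RNG state once and restore it before invoking the op for each rank, so every simulated rank sees identical randomness. A minimal standalone sketch of that replay pattern (illustrative only, not the PyTorch code; run_op and simulated_ranks are made-up names):

import torch

def snapshot_rng() -> tuple[torch.Tensor, list[torch.Tensor]]:
    # CPU generator state plus one state tensor per visible CUDA device.
    cuda_states = []
    if torch.cuda.is_available():
        cuda_states = [
            torch.cuda.get_rng_state(i) for i in range(torch.cuda.device_count())
        ]
    return torch.get_rng_state(), cuda_states

def restore_rng(cpu_state: torch.Tensor, cuda_states: list[torch.Tensor]) -> None:
    torch.set_rng_state(cpu_state)
    for i, state in enumerate(cuda_states):
        torch.cuda.set_rng_state(state, i)

def run_op_per_rank(run_op, simulated_ranks):
    cpu_state, cuda_states = snapshot_rng()
    results = {}
    for r in sorted(simulated_ranks):
        restore_rng(cpu_state, cuda_states)  # every rank replays from the same state
        results[r] = run_op(r)
    return results

# All ranks draw the same random tensor.
out = run_op_per_rank(lambda r: torch.randn(2), {0, 1, 2})
assert all(torch.equal(out[0], out[r]) for r in out)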
@@ -761,11 +704,6 @@ class _LocalDeviceMesh:
 
     @staticmethod
     def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
-        # NB: In order to support submeshes the code below recreates for each
-        # rank submesh with the same mesh dimensions as current mesh. We are
-        # doing this because when submesh is created it is created for a particular
-        # rank (therefore below we are patching get_rank method). We are trying to
-        # limit the invasiveness of local tensor.
         lm = local_tensor_mode()
         assert lm is not None, "Unexpectedly not in LocalTensorMode"
@@ -778,9 +716,7 @@ class _LocalDeviceMesh:
                 coords[d][r] = c
 
         out = [torch.SymInt(LocalIntNode(c)) for c in coords]
-        # The output contains coordinates for each of the ranks with respect to
-        # their meshes formed from root mesh and selecting the same dimensions
-        # as the current mesh.
         return out  # type: ignore[return-value]
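
The deleted comment explained what the returned values are: one coordinate per rank, giving that rank's position along every dimension of the mesh. As a rough illustration of the convention (a plain row-major unravel; not the actual DeviceMesh implementation):

def coordinate_of(rank: int, mesh_shape: tuple[int, ...]) -> list[int]:
    # Row-major unravel: the last mesh dimension varies fastest.
    coord = []
    remaining = rank
    for size in reversed(mesh_shape):
        coord.append(remaining % size)
        remaining //= size
    return list(reversed(coord))

# On a 2x4 mesh, rank 6 sits at row 1, column 2.
assert coordinate_of(6, (2, 4)) == [1, 2]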
@@ -858,6 +794,8 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
         with lm.disable():
             ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)
 
+        lm = local_tensor_mode()
+        assert lm is not None
         return ret
 
     return wrapper
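
The wrapper in this hunk follows a common shape: if the local-tensor mode is active, temporarily disable it and evaluate the wrapped function once per rank; otherwise call through. A self-contained sketch of that shape (the mode object here is a stand-in, not the real LocalTensorMode API):

import contextlib
import functools

class _StubMode:
    # Minimal stand-in exposing the two attributes the wrapper relies on.
    def __init__(self, ranks):
        self.ranks = frozenset(ranks)

    def disable(self):
        return contextlib.nullcontext()

_MODE = _StubMode({0, 1})  # pretend the mode is currently active

def maybe_run_per_rank(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        lm = _MODE
        if lm is None:
            return func(*args, **kwargs)  # no mode active: plain call
        with lm.disable():  # evaluate the underlying op outside the mode
            return {r: func(*args, **kwargs) for r in sorted(lm.ranks)}

    return wrapper

@maybe_run_per_rank
def double(x):
    return 2 * x

assert double(3) == {0: 6, 1: 6}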

View File

@@ -6,7 +6,6 @@ from typing import cast, Optional
 
 import torch
 import torch.distributed._functional_collectives as funcol
-from torch.distributed._local_tensor import maybe_run_for_local_tensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._op_schema import (
     OpSchema,
@@ -84,27 +83,9 @@ class _MaskPartial(Partial):
     offset_shape: Optional[torch.Size] = None
     offset_dim: int = 0
 
-    @staticmethod
-    @maybe_run_for_local_tensor
-    def _mask_tensor(
-        tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # Build the input mask and save it for the current partial placement
-        # this is so that the output of embedding op can reuse the same partial
-        # placement saved mask to perform mask + reduction
-        mask = (tensor < local_offset_on_dim) | (
-            tensor >= local_offset_on_dim + local_shard_size
-        )
-        # mask the input tensor
-        masked_tensor = tensor.clone() - local_offset_on_dim
-        masked_tensor[mask] = 0
-        return mask, masked_tensor
-
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
-        my_coordinate = mesh.get_coordinate()
-        assert my_coordinate is not None, "my_coordinate should not be None"
         # override parent logic to perform partial mask for embedding
         num_chunks = mesh.size(mesh_dim)
         # get local shard size and offset on the embedding_dim
@@ -114,11 +95,17 @@ class _MaskPartial(Partial):
         local_shard_size, local_offset_on_dim = Shard.local_shard_size_and_offset(
             self.offset_shape[self.offset_dim],
             num_chunks,
-            my_coordinate[mesh_dim],
+            mesh.get_local_rank(mesh_dim),
         )
-        mask, masked_tensor = _MaskPartial._mask_tensor(
-            tensor, local_offset_on_dim, local_shard_size
+        # Build the input mask and save it for the current partial placement
+        # this is so that the output of embedding op can reuse the same partial
+        # placement saved mask to perform mask + reduction
+        mask = (tensor < local_offset_on_dim) | (
+            tensor >= local_offset_on_dim + local_shard_size
         )
+        # mask the input tensor
+        masked_tensor = tensor.clone() - local_offset_on_dim
+        masked_tensor[mask] = 0
         # materialize the mask buffer to be used for reduction
         self.mask_buffer.materialize_mask(mask)
         return masked_tensor
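
To make the restored inline logic concrete: each shard owns a contiguous range of embedding rows, ids outside that range are recorded in the mask and remapped to a harmless local index, and the saved mask later zeroes out those rows before the partial reduction. A toy example with made-up values:

import torch

# This shard owns embedding rows [4, 8).
local_offset_on_dim, local_shard_size = 4, 4
token_ids = torch.tensor([1, 5, 7, 9])

mask = (token_ids < local_offset_on_dim) | (
    token_ids >= local_offset_on_dim + local_shard_size
)
masked = token_ids.clone() - local_offset_on_dim
masked[mask] = 0

print(mask)    # tensor([ True, False, False,  True])  -> ids 1 and 9 are not local
print(masked)  # tensor([0, 1, 3, 0])                  -> local row indices, 0 where masked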

View File

@@ -48,9 +48,6 @@ class LocalLRUCache(threading.local):
     def cache_info(self):
         return self.cache.cache_info()
 
-    def cache_clear(self):
-        return self.cache.cache_clear()
-
 
 class ShardingPropagator:
     def __init__(self) -> None:
View File

@@ -19,17 +19,6 @@ def _get_sharding_prop_cache_info():
     )
 
 
-def _clear_sharding_prop_cache():
-    """
-    Clears the cache for the sharding propagation cache, used for debugging purpose only.
-    """
-    from torch.distributed.tensor._api import DTensor
-
-    return (
-        DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_clear()  # type:ignore[attr-defined]
-    )
-
-
 # Set namespace for exposed private names
 CommDebugMode.__module__ = "torch.distributed.tensor.debug"
 visualize_sharding.__module__ = "torch.distributed.tensor.debug"
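
The deletions in this file and the previous one remove a debugging hook around the same idea: sharding propagation results are memoized behind an LRU cache, and the hook exposed that cache's clear operation. The underlying functools pattern looks like this (generic sketch, not the DTensor code):

import functools

@functools.lru_cache(maxsize=None)
def propagate(op_name: str) -> str:
    # Stand-in for an expensive sharding-propagation computation.
    return f"sharding for {op_name}"

propagate("aten.mm")
propagate("aten.mm")
print(propagate.cache_info())  # hits=1, misses=1, ...
propagate.cache_clear()        # what cache_clear()/_clear_sharding_prop_cache exposed
print(propagate.cache_info())  # back to hits=0, misses=0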

View File

@@ -359,16 +359,6 @@ class Shard(Placement):
         return Shard._select_shard(shards, shard_index)
 
-    @staticmethod
-    @maybe_run_for_local_tensor
-    def _get_shard_pad_size(
-        full_size: int, local_tensor: torch.Tensor, dim: int
-    ) -> int:
-        """
-        Get the padding size of the local tensor on the shard dimension.
-        """
-        return full_size - local_tensor.size(dim)
-
     def _to_new_shard_dim(
         self,
         local_tensor: torch.Tensor,
@@ -397,16 +387,14 @@
             old_dim_full_chunk_size = (
                 old_dim_logical_size + num_chunks - 1
             ) // num_chunks
-            old_dim_pad_size = Shard._get_shard_pad_size(
-                old_dim_full_chunk_size, local_tensor, self.dim
-            )
+            old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
             local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
         if new_dim_padding:
             new_dim_full_chunk_size = (
                 new_dim_logical_size + num_chunks - 1
             ) // num_chunks
-            new_dim_pad_size = Shard._get_shard_pad_size(
-                new_dim_full_chunk_size * num_chunks, local_tensor, new_shard_dim
+            new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
+                new_shard_dim
             )
             local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)
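
The padding arithmetic restored here is ceil-division based: every chunk is padded up to the size of the largest chunk so all ranks exchange equally sized blocks. A small worked example with made-up sizes:

# Shard a dimension of logical size 10 across 4 chunks.
old_dim_logical_size, num_chunks = 10, 4
old_dim_full_chunk_size = (old_dim_logical_size + num_chunks - 1) // num_chunks  # 3

# The last rank holds only 10 - 3 * 3 = 1 element on that dimension ...
last_rank_local_size = old_dim_logical_size - old_dim_full_chunk_size * (num_chunks - 1)

# ... so it pads by full_chunk_size - local_size = 2 before the collective.
old_dim_pad_size = old_dim_full_chunk_size - last_rank_local_size
assert (old_dim_full_chunk_size, old_dim_pad_size) == (3, 2)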

View File

@@ -211,14 +211,6 @@ def at_least_x_gpu(x):
     return False
 
 
-def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
-    _handle_test_skip = getattr(args[0], "_handle_test_skip", None)
-    if len(args) == 0 or _handle_test_skip is None:
-        return False
-    _handle_test_skip(msg)
-    return True
-
-
 def skip_if_lt_x_gpu(x):
     def decorator(func):
         @wraps(func)
@@ -229,9 +221,7 @@ def skip_if_lt_x_gpu(x):
                 return func(*args, **kwargs)
             if TEST_XPU and torch.xpu.device_count() >= x:
                 return func(*args, **kwargs)
-            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
-            if _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
-                sys.exit(test_skip.exit_code)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
 
         return wrapper
@@ -247,9 +237,7 @@ def nccl_skip_if_lt_x_gpu(backend, x):
                 return func(*args, **kwargs)
             if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                 return func(*args, **kwargs)
-            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
-            if _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
-                sys.exit(test_skip.exit_code)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
 
         return wrapper
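
The deleted helper changed how a missing-GPU condition is reported: rather than sys.exit-ing the worker process, it lets the test object skip itself when it opts in via a _handle_test_skip method, which is what an in-process test base needs. A simplified sketch of that delegation (the helper name mirrors the diff; the test class is made up):

import sys
import unittest

def _skip_or_exit(test_args, msg: str, exit_code: int) -> None:
    # If the first positional arg (the test instance) knows how to skip itself,
    # delegate to it; otherwise fall back to exiting the process as the reverted code does.
    handler = getattr(test_args[0], "_handle_test_skip", None) if test_args else None
    if handler is not None:
        handler(msg)  # e.g. raises unittest.SkipTest via self.skipTest(msg)
        return
    sys.exit(exit_code)

class InProcessTest(unittest.TestCase):
    def _handle_test_skip(self, msg: str) -> None:
        self.skipTest(msg)

    def test_needs_many_gpus(self):
        _skip_or_exit((self,), "need 8 GPUs", exit_code=1)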

View File

@@ -701,9 +701,6 @@ class DTensorConverter:
 
 
 class LocalDTensorTestBase(DTensorTestBase):
-    def _handle_test_skip(self, msg: str) -> None:
-        self.skipTest(msg)
-
     def _get_local_tensor_mode(self):
         return LocalTensorMode(frozenset(range(self.world_size)))