LocalTensor (#164537)

A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks.  A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally.  When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards.  A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.
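
As a quick illustration, here is a minimal sketch modeled on the unit tests
added in this PR (the fake process group and LocalTensorMode usage mirror
test_collectives_within_local_tensor_mode; it is illustrative, not the
canonical API):

    import torch
    import torch.distributed as dist
    from torch.distributed._local_tensor import LocalTensor, LocalTensorMode

    # Simulate world_size=2 in a single process using the fake backend.
    dist.init_process_group("fake", rank=0, world_size=2)
    fake_pg = torch.distributed.distributed_c10d._get_default_group()

    # One logical tensor backed by a shard for rank 0 and a shard for rank 1.
    lt = LocalTensor({
        0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
        1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
    })

    with LocalTensorMode(lt._ranks):
        out = lt + 1.0                       # plain op: applied to every shard
        dist.all_reduce(out, group=fake_pg)  # collective: shards are summed
    # After the all_reduce every rank's shard holds the same summed tensor.

    dist.destroy_process_group()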

NB: this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the
C++ autograd needs multithreading to keep up!). (It might potentially be
possible to trace through this with torch.compile and then compile it with
CUDA graphs, but this is currently a non-goal.)

In order to handle MPMD, we provide a helper decorator that allows you to
run a side-effect-free function once for each LocalTensor shard and combine
the results back into a LocalTensor or LocalIntNode.
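
The test conversion in this PR uses the decorator for small per-rank helpers,
e.g. the map_tensor_for_rank helper added to test_dtensor.py (reproduced here
only to illustrate the pattern):

    from torch.distributed._local_tensor import maybe_run_for_local_tensor

    @maybe_run_for_local_tensor
    def map_tensor_for_rank(tensor, rank, func):
        # Under LocalTensorMode this body runs once per rank: LocalTensor
        # arguments arrive as that rank's shard and LocalIntNode-backed SymInt
        # arguments as that rank's plain int; outside the mode it runs as-is.
        return func(tensor, rank)

In the converted tests self.rank is a torch.SymInt backed by a LocalIntNode, so
the same test body works for both the multi-process / multi-threaded and the
local tensor flavors.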

Note: This PR converts all DTensor ops and some DTensor tests to illustrate
intended usage and ensure correctness. In subsequent PRs more tests will be
converted. During test conversion we aim to share as much of the test logic
as possible between the multi-process / multi-threaded and local tensor tests.
We would like developers to be able to run both flavors of the tests.

Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang
Dzmitry Huba
2025-10-12 20:06:41 +00:00
committed by PyTorch MergeBot
parent a2601630cd
commit 5e58420dff
16 changed files with 2212 additions and 70 deletions

View File

@ -3,13 +3,23 @@
import pathlib
import tempfile
import types
import unittest
from functools import wraps
from typing import Optional
from numpy.testing import assert_array_equal
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._local_tensor import (
LocalIntNode,
LocalTensorMode,
maybe_run_for_local_tensor,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
DeviceMesh,
@ -44,6 +54,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
c10d_functional = torch.ops.c10d_functional
@maybe_run_for_local_tensor
def map_tensor_for_rank(tensor, rank, func):
return func(tensor, rank)
class DummyMLP(torch.nn.Module):
def __init__(self, device):
super().__init__()
@ -592,7 +607,12 @@ class DTensorTest(DTensorTestBase):
self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
self.assertEqual(sharded_tensor.placements, placements)
local_tensor = sharded_tensor.to_local()
self.assertEqual(local_tensor, full_tensor[range(self.rank, self.rank + 1), :])
self.assertEqual(
local_tensor,
map_tensor_for_rank(
full_tensor, self.rank, lambda ft, r: ft[range(r, r + 1), :]
),
)
# Shard by column
placements = [Shard(1)]
@ -600,7 +620,12 @@ class DTensorTest(DTensorTestBase):
self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
self.assertEqual(sharded_tensor.placements, placements)
local_tensor = sharded_tensor.to_local()
self.assertEqual(local_tensor, full_tensor[:, range(self.rank, self.rank + 1)])
self.assertEqual(
local_tensor,
map_tensor_for_rank(
full_tensor, self.rank, lambda ft, r: ft[:, range(r, r + 1)]
),
)
# assert full tensor is not changed
self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
@ -620,6 +645,105 @@ class DTensorTest(DTensorTestBase):
self.assertEqual(local_tensor.item(), self.rank)
class LocalDTensorTest(DTensorTest):
def get_local_tensor_mode(self):
return LocalTensorMode(frozenset(range(0, self.world_size)))
@property
def rank(self):
return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
@rank.setter
def rank(self, rank):
pass
def join_or_run(self, fn):
@wraps(fn)
def wrapper(self):
fn()
return types.MethodType(wrapper, self)
def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
dist.init_process_group("fake", rank=0, world_size=self.world_size)
self._pg = c10d._get_default_group()
def destroy_pg(self, device_id: Optional[int] = None) -> None:
dist.destroy_process_group(self._pg)
self._pg = None
def _spawn_processes(self) -> None:
pass
def test_dtensor_constructor(self):
pass
def test_meta_dtensor(self):
pass
def test_modules_w_meta_dtensor(self):
pass
def test_dtensor_stride(self):
pass
def test_from_local(self):
pass
def test_from_local_uneven_sharding(self):
pass
def test_from_local_uneven_sharding_raise_error(self):
pass
def test_from_local_negative_dim(self):
pass
def test_to_local(self):
pass
def test_to_local_grad_hint(self):
pass
def test_full_tensor_sync(self):
pass
def test_full_tensor_grad_hint(self):
pass
def test_dtensor_new_empty_strided(self):
pass
def test_dtensor_async_output(self):
pass
def test_from_local_then_to_local(self):
pass
def test_dtensor_spec_read_only_after_set(self):
pass
def test_dtensor_spec_hash(self):
pass
def test_dtensor_properties(self):
pass
def test_dtensor_save_load(self):
pass
def test_dtensor_save_load_import(self):
pass
def test_shard_tensor_2d(self):
with self.get_local_tensor_mode():
super().test_shard_tensor_2d()
def test_shard_tensor(self):
with self.get_local_tensor_mode():
super().test_shard_tensor()
class DTensorMeshTest(DTensorTestBase):
@property
def world_size(self):

View File

@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy
import re
import unittest
import warnings
@ -8,6 +9,7 @@ import warnings
import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
from torch.distributed._local_tensor import LocalTensorMode, reconcile_args
from torch.distributed.tensor import (
distribute_tensor,
DTensor,
@ -21,7 +23,7 @@ from torch.testing._internal.common_device_type import (
ops,
)
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_utils import run_tests, suppress_warnings
from torch.testing._internal.common_utils import run_tests, suppress_warnings, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorConverter,
DTensorOpTestBase,
@ -49,7 +51,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
return (op_name, variant_name, device_type, dtypes, False)
def skipOps(test_case_name, base_test_name, to_skip):
def skipOps(op_db, test_case_name, base_test_name, to_skip):
all_opinfos = op_db
for xfail in to_skip:
op_name, variant_name, device_type, dtypes, expected_failure = xfail
@ -88,6 +90,34 @@ def skipOps(test_case_name, base_test_name, to_skip):
return wrapped
def repurpose_ops(op_db, base_test_name, derived_test_name):
"""
Copies the op info database and, for decorators that apply to the base test class, updates
them to apply to the derived test class. The update is required because decorators are applied
only if the class name matches (base classes are not considered).
Specifically, we use this function to create two test classes (one for the multi-threaded and one for
the local tensor flavor) that share a common test body but have different skip/xfail rules.
Args:
op_db: List of OpInfo objects to be repurposed.
base_test_name: The original test class name to be replaced.
derived_test_name: The new test class name to set in decorators.
Returns:
list: A new list of OpInfo objects with updated target class names for the
decorator.
"""
repurposed_ops = []
for opinfo in op_db:
opinfo_copy = copy.deepcopy(opinfo)
for decorator in list(opinfo_copy.decorators):
if hasattr(decorator, "cls_name") and decorator.cls_name == base_test_name:
decorator.cls_name = derived_test_name
repurposed_ops.append(opinfo_copy)
return repurposed_ops
# Re-generate this failed list, turn on dry_run of the below func
# check_dtensor_func(self, test, op, dry_run=True), then run sth
# like python test/distributed/tensor/test_dtensor_ops.py > failed.expect
@ -162,7 +192,6 @@ dtensor_fails = {
xfail("fmin"),
xfail("frexp"),
xfail("full"),
xfail("full_like"),
xfail("geometric"),
xfail("geqrf"),
xfail("grid_sampler_2d"),
@ -226,7 +255,6 @@ dtensor_fails = {
xfail("masked_select"),
xfail("masked.argmax"),
xfail("masked.argmin"),
xfail("masked.cumprod"),
xfail("masked.logsumexp"),
xfail("masked.median"),
xfail("matrix_exp"),
@ -244,8 +272,6 @@ dtensor_fails = {
xfail("native_batch_norm"),
xfail("narrow_copy"),
xfail("ne"),
xfail("new_empty"),
xfail("new_empty_strided"),
xfail("transpose"),
xfail("nn.functional.adaptive_avg_pool1d"),
xfail("nn.functional.adaptive_avg_pool2d"),
@ -272,8 +298,6 @@ dtensor_fails = {
xfail("nn.functional.cosine_similarity"),
xfail("nn.functional.ctc_loss"),
xfail("nn.functional.dropout"),
xfail("nn.functional.dropout2d"),
xfail("nn.functional.dropout3d"),
xfail("nn.functional.elu"),
xfail("nn.functional.fractional_max_pool2d"),
xfail("nn.functional.fractional_max_pool3d"),
@ -307,7 +331,6 @@ dtensor_fails = {
xfail("nn.functional.multi_margin_loss"),
xfail("nn.functional.multilabel_margin_loss"),
xfail("nn.functional.multilabel_soft_margin_loss"),
xfail("nn.functional.multi_head_attention_forward"),
xfail("nn.functional.pad", "reflect"),
xfail("nn.functional.pad", "replicate"),
xfail("nn.functional.pad", "replicate_negative"),
@ -482,13 +505,21 @@ dtensor_fails = {
skip("_segment_reduce", "offsets"),
# TODO: fix the following ops
skip("squeeze"),
# These must be skipped as their contents are nondeterministic
skip("empty"),
skip("empty_strided"),
skip("empty_like"),
skip("empty_permuted"),
skip("new_empty"),
skip("new_empty_strided"),
}
dtensor_multi_threaded_fails = {
xfail("full_like"),
xfail("nn.functional.dropout2d"),
xfail("nn.functional.dropout3d"),
xfail("masked.cumprod"),
skip("nn.functional.multi_head_attention_forward"),
}
# Add a list of ops that are currently failing BW pass
skip_bw = [
@ -507,7 +538,13 @@ OP_DB_WORLD_SIZE = 4
DEVICE_TYPE = "cpu"
class TestDTensorOps(DTensorOpTestBase):
class TestDTensorOps(TestCase):
__test__ = False
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.__test__ = True
@property
def world_size(self) -> int:
return OP_DB_WORLD_SIZE
@ -535,14 +572,6 @@ class TestDTensorOps(DTensorOpTestBase):
self.check_dtensor_func(test, op)
# only allow float dytpe for now, we can relax this constraint
# when feel necessary later (i.e when adding quantization support).
@suppress_warnings
@ops(op_db, allowed_dtypes=(torch.float,))
@skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
def test_dtensor_op_db(self, dtype, op):
self.run_opinfo_test(dtype, op)
def assert_ref_dtensor_equal(self, dtensor_rs, rs):
flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
flat_rs = pytree.tree_leaves(rs)
@ -567,6 +596,9 @@ class TestDTensorOps(DTensorOpTestBase):
self.assertEqualOnRank(dtensor_r, r)
def assertEqualOnRank(self, x, y, msg=None, *, rank=0) -> None:
raise NotImplementedError
def run_dtensor_crossref(self, func, args, kwargs):
to_dtensor = DTensorConverter(self.mesh, args, kwargs)
@ -580,7 +612,8 @@ class TestDTensorOps(DTensorOpTestBase):
return res
# TODO: also handle cases where func raise an exception
rs = func(*args, **kwargs)
op_args, op_kwargs = reconcile_args(args, kwargs)
rs = func(*op_args, **op_kwargs)
rs = concat_res_if_necessary(func, rs)
def to_replicate(e: object) -> object:
@ -635,12 +668,12 @@ class TestDTensorOps(DTensorOpTestBase):
self.assert_ref_dtensor_equal(dtensor_rs, rs)
else:
raise RuntimeError(
f"failed to convert args to DTensor; "
f"Failed to convert args to DTensor; "
f"originally (*{args}, **{kwargs})"
)
except Exception as e:
raise RuntimeError(
f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
f"{str(e)}\n\nFailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
) from e
return rs
@ -656,7 +689,7 @@ class TestDTensorOps(DTensorOpTestBase):
else:
print(f"xfail('{opinfo.name}'),")
def test_one_hot(self):
def run_one_hot(self):
ops = [op for op in op_db if op.name == "nn.functional.one_hot"]
assert len(ops) == 1
op = ops[0]
@ -668,7 +701,7 @@ class TestDTensorOps(DTensorOpTestBase):
sample_inputs_filter=lambda s: s.kwargs["num_classes"] != -1,
)
def test_mean(self):
def run_mean(self):
self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
shape = [2 * self.world_size + 1, 2 * self.world_size]
@ -692,6 +725,7 @@ class TestDTensorOps(DTensorOpTestBase):
full_tensor = mean.full_tensor()
self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))
if is_evenly_shardable:
self.assertTrue("P->R" in debug_mode.debug_string())
else:
@ -720,9 +754,76 @@ class TestDTensorOps(DTensorOpTestBase):
_ = torch.ops.aten.embedding.default(weight_dtensor, input_dtensor)
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))
class TestMultiThreadedDTensorOps(DTensorOpTestBase, TestDTensorOps):
_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestMultiThreadedDTensorOps")
@suppress_warnings
@ops(_op_db, allowed_dtypes=(torch.float,))
@skipOps(
_op_db,
"TestMultiThreadedDTensorOps",
"test_dtensor_op_db",
dtensor_fails | dtensor_multi_threaded_fails,
)
def test_dtensor_op_db(self, dtype, op):
self.run_opinfo_test(dtype, op)
def test_mean(self):
self.run_mean()
def test_one_hot(self):
self.run_one_hot()
class TestLocalDTensorOps(TestDTensorOps):
_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestLocalDTensorOps")
def setUp(self) -> None:
super().setUp()
torch.distributed.init_process_group("fake", rank=0, world_size=self.world_size)
self.fake_pg = torch.distributed.distributed_c10d._get_default_group()
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
@suppress_warnings
@ops(_op_db, allowed_dtypes=(torch.float,))
@skipOps(
_op_db,
"TestLocalDTensorOps",
"test_dtensor_op_db",
dtensor_fails,
)
def test_dtensor_op_db(self, dtype, op):
self.run_opinfo_test(dtype, op)
def test_mean(self):
with LocalTensorMode(frozenset(range(0, self.world_size))):
self.run_mean()
def test_one_hot(self):
self.run_one_hot()
def run_opinfo_test(
self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True
):
with LocalTensorMode(frozenset(range(0, self.world_size))):
super().run_opinfo_test(dtype, op, requires_grad, sample_inputs_filter)
def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
self.assertEqual(x, y, msg)
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
instantiate_device_type_tests(
TestMultiThreadedDTensorOps, globals(), only_for=(DEVICE_TYPE,)
)
instantiate_device_type_tests(TestLocalDTensorOps, globals(), only_for=(DEVICE_TYPE,))
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,415 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalTensor,
LocalTensorMode,
)
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
init_device_mesh,
Partial,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class LocalTensorTestBase(TestCase):
def assertEqual(self, lhs, rhs, **kwargs):
mode = local_tensor_mode()
with nullcontext() if mode is None else mode.disable():
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
super().assertEqual(lhs._ranks, rhs._ranks)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r],
rhs._local_tensors[r],
lambda m: f"rank {r}: {m}",
)
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
)
else:
return super().assertEqual(lhs, rhs, **kwargs)
@property
def world_size(self):
raise NotImplementedError("override world-size in your subclass")
def build_device_mesh(self) -> DeviceMesh:
return init_device_mesh("cpu", (self.world_size,))
def setUp(self):
super().setUp()
torch.distributed.init_process_group(
# TODO: test other ranks too
"fake",
rank=0,
world_size=self.world_size,
)
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
class TestLocalTensorWorld2(LocalTensorTestBase):
world_size = 2
def test_local_tensor_dtype_consistency(self):
"""Test that LocalTensor enforces dtype consistency."""
device = torch.device("cpu")
shape = (2, 3)
inconsistent_tensors = {
0: torch.randn(shape, dtype=torch.float32, device=device),
1: torch.randn(
shape, dtype=torch.float64, device=device
), # Different dtype
}
with self.assertRaises(AssertionError):
LocalTensor(inconsistent_tensors)
def test_local_tensor_creation_fails_with_grad_tensors(self):
"""Test that LocalTensor creation fails when local tensors have requires_grad=True."""
device = torch.device("cpu")
shape = (2, 3)
dtype = torch.float32
# Create sample local tensors for different ranks
local_tensors = {
0: torch.randn(shape, dtype=dtype, device=device, requires_grad=True),
1: torch.randn(shape, dtype=dtype, device=device, requires_grad=True),
}
with self.assertRaises(AssertionError):
LocalTensor(local_tensors)
# TODO: test flatten/unflatten
def test_basic_arithmetic_operations(self):
"""Test basic arithmetic operations on LocalTensors."""
device = torch.device("cpu")
shape = (2, 3)
dtype = torch.float32
# Create identical local tensors for consistency tests
base_tensor = torch.randn(shape, dtype=dtype, device=device)
identical_local_tensors = {
0: base_tensor.clone(),
1: base_tensor.clone(),
}
lt1 = LocalTensor(identical_local_tensors)
lt2 = LocalTensor(identical_local_tensors)
# Test addition
result_add = lt1 + lt2
self.assertIsInstance(result_add, LocalTensor)
self.assertEqual(len(result_add._local_tensors), 2)
# Verify the operation was applied to each local tensor
for rank in identical_local_tensors.keys():
expected = identical_local_tensors[rank] + identical_local_tensors[rank]
self.assertEqual(result_add._local_tensors[rank], expected)
# Test multiplication
result_mul = lt1 * 2.0
self.assertIsInstance(result_mul, LocalTensor)
for rank in identical_local_tensors.keys():
expected = identical_local_tensors[rank] * 2.0
self.assertEqual(result_mul._local_tensors[rank], expected)
# TODO: consider an op-info test; we don't actually need to cover all ops
# but it will help make sure views and more exotic things are done
# correctly (in standard subclass style)
def test_mixed_operations_with_regular_tensors(self):
"""Test operations between LocalTensors and regular tensors."""
device = torch.device("cpu")
shape = (2, 3)
dtype = torch.float32
# Create identical local tensors for consistency tests
base_tensor = torch.randn(shape, dtype=dtype, device=device)
identical_local_tensors = {
0: base_tensor.clone(),
1: base_tensor.clone(),
}
lt = LocalTensor(identical_local_tensors)
regular_tensor = torch.ones_like(identical_local_tensors[0])
# Test LocalTensor + regular tensor
result = lt + regular_tensor
self.assertIsInstance(result, LocalTensor)
for rank in identical_local_tensors.keys():
expected = identical_local_tensors[rank] + regular_tensor
self.assertEqual(result._local_tensors[rank], expected)
def test_local_tensor_mode(self):
"""Test LocalTensorMode functionality."""
device = torch.device("cpu")
shape = (2, 3)
dtype = torch.float32
# Create identical local tensors for consistency tests
base_tensor = torch.randn(shape, dtype=dtype, device=device)
identical_local_tensors = {
0: base_tensor.clone(),
1: base_tensor.clone(),
}
lt = LocalTensor(identical_local_tensors)
with LocalTensorMode(lt._ranks):
result = lt + 1.0
self.assertIsInstance(result, LocalTensor)
regular = torch.ones(2, 2)
regular_result = regular + 1.0
self.assertIsInstance(regular, LocalTensor)
self.assertIsInstance(regular_result, LocalTensor)
def test_empty_local_tensors(self):
"""Test behavior with empty local tensors dict."""
# TODO: raise a better error here
with self.assertRaises(StopIteration): # next() on empty iterator
LocalTensor({})
def test_collectives_within_local_tensor_mode(self):
"""Test that collective operations work within LocalTensorMode context."""
test_tensors = {
0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
}
lt = LocalTensor(test_tensors)
fake_pg = torch.distributed.distributed_c10d._get_default_group()
with LocalTensorMode(lt._ranks):
# Test all_reduce within mode
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
dist.all_reduce(lt_sum, group=fake_pg)
expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]])
for rank in test_tensors.keys():
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
# Test broadcast within mode
lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
dist.broadcast(lt_broadcast, src=0, group=fake_pg)
for rank in test_tensors.keys():
self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0])
# Test that regular operations still work
result = lt + 1.0
self.assertIsInstance(result, LocalTensor)
def test_scalar_mul_reduction_bug(self):
with LocalTensorMode(self.world_size):
mesh = self.build_device_mesh()
tensor = torch.tensor([10, 10]).float()
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
y = dt.sum() * 1 # noqa: F841
tensor = torch.arange(10).reshape(10, 1).float().requires_grad_()
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
print(dt.sum() * 1, dt.sum() * 2, dt.sum() * 3)
def test_uneven_sharding_mean_bug(self):
with LocalTensorMode(self.world_size):
mesh = self.build_device_mesh()
tensor = torch.arange(12).reshape(-1, 4).float()
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
mean = dt.mean()
self.assertEqual(mean.placements, [Replicate()])
full = mean.full_tensor()
self.assertEqual(tensor.mean(), full)
def test_uneven_sharding_prod(self):
with LocalTensorMode(self.world_size):
mesh = self.build_device_mesh()
tensor = (torch.arange(12) + 1).reshape(-1, 4).float()
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
x = dt.prod()
full = x.full_tensor()
self.assertEqual(tensor.prod(), full)
def test_even_sharding_mean_is_partial(self):
with LocalTensorMode(self.world_size):
mesh = self.build_device_mesh()
tensor = torch.arange(16).reshape(4, 4).float()
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
mean = dt.mean()
full = mean.full_tensor()
self.assertEqual(tensor.mean(), full)
self.assertEqual(mean.placements, [Partial("avg")])
class TestLocalTensorWorld3(LocalTensorTestBase):
world_size = 3
def test_collective_reduction_operations(self):
"""Test different reduction operations for all_reduce."""
# Create different tensors for each rank with simple values for testing
test_tensors = {
0: torch.tensor([[1.0, 4.0], [2.0, 5.0]]),
1: torch.tensor([[2.0, 1.0], [3.0, 6.0]]),
2: torch.tensor([[3.0, 2.0], [1.0, 4.0]]),
}
fake_pg = torch.distributed.distributed_c10d._get_default_group()
# Test SUM reduction
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg)
expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors
for rank in test_tensors.keys():
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
# Test MAX reduction
lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg)
expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) # Max across all tensors
for rank in test_tensors.keys():
self.assertEqual(lt_max._local_tensors[rank], expected_max)
# Test MIN reduction
lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg)
expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors
for rank in test_tensors.keys():
self.assertEqual(lt_min._local_tensors[rank], expected_min)
def test_all_reduce_collective(self):
"""Test that all_reduce collective operation works correctly with LocalTensor."""
# Create different tensors for each rank
different_tensors = {
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
}
fake_pg = torch.distributed.distributed_c10d._get_default_group()
# Test all_reduce with SUM (default)
lt_sum = LocalTensor({k: v.clone() for k, v in different_tensors.items()})
lt_sum = lt_sum + 1
dist.all_reduce(lt_sum, group=fake_pg)
# Verify all ranks have the sum of all tensors (after adding 1 to each)
expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]])
for rank in different_tensors.keys():
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
def test_broadcast_collective(self):
"""Test that broadcast collective operation works correctly with LocalTensor."""
# Create different tensors for each rank
different_tensors = {
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
}
fake_pg = torch.distributed.distributed_c10d._get_default_group()
# Test broadcast from rank 1
lt_broadcast = LocalTensor({k: v.clone() for k, v in different_tensors.items()})
dist.broadcast(lt_broadcast, src=1, group=fake_pg)
# Verify all ranks have rank 1's original tensor
expected_broadcast = different_tensors[1]
for rank in different_tensors.keys():
self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast)
def test_all_gather_collective(self):
"""Test that all_gather collective operation works correctly with LocalTensor."""
# Create different tensors for each rank
different_tensors = {
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
}
fake_pg = torch.distributed.distributed_c10d._get_default_group()
# Test all_gather
lt_gather = LocalTensor(different_tensors)
tensor_list = [torch.zeros_like(lt_gather) for _ in range(3)]
dist.all_gather(tensor_list, lt_gather, group=fake_pg)
# Verify each position in tensor_list contains the corresponding rank's tensor
self.assertEqual(tensor_list[0], different_tensors[0])
self.assertEqual(tensor_list[1], different_tensors[1])
self.assertEqual(tensor_list[2], different_tensors[2])
class TestLocalTensorWorld4(LocalTensorTestBase):
world_size = 4
def test_dtensor_cat(self):
with LocalTensorMode(self.world_size):
device_mesh = self.build_device_mesh()
t1 = torch.arange(16).view(4, 4).float()
d1 = distribute_tensor(t1, device_mesh, [Replicate()])
t2 = (torch.arange(16) + 16).view(4, 4).float()
d2 = distribute_tensor(t2, device_mesh, [Shard(0)])
local_res = torch.cat([t1, t2], dim=-1)
dist_res = torch.cat([d1, d2], dim=-1)
full_tensor = dist_res.full_tensor()
self.assertEqual(full_tensor, local_res)
class TestLocalTensorWorld8(LocalTensorTestBase):
world_size = 8
def test_dtensor_addmm(self):
with LocalTensorMode(self.world_size):
device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]
replica_spec = [Replicate()]
tensor_to_shard = torch.randn(12, 8)
mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
tensor_to_replicate = torch.randn(8, 4)
mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
input_tensor = torch.randn(4)
input = distribute_tensor(input_tensor, device_mesh, replica_spec)
dist_res = torch.addmm(input, mat1, mat2)
local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
full_tensor = dist_res.full_tensor()
self.assertEqual(full_tensor, local_res)
if __name__ == "__main__":
run_tests()

View File

@ -7915,9 +7915,13 @@ torch.cuda.synchronize()
nt = torch.nested.nested_tensor(
[
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
if dtype is torch.bool
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
(
torch.randint(
2, (n, *post_seq_len_shape), device=device, dtype=dtype
)
if dtype is torch.bool
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
)
for n in range(2, 9)
],
layout=torch.jagged,
@ -7966,9 +7970,13 @@ torch.cuda.synchronize()
nt = torch.nested.nested_tensor(
[
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
if dtype is torch.bool
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
(
torch.randint(
2, (n, *post_seq_len_shape), device=device, dtype=dtype
)
if dtype is torch.bool
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
)
for n in range(2, 9)
],
layout=torch.jagged,
@ -8713,7 +8721,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
# min() / max(): weird bug
XFailRule(
error_type=AttributeError,
error_msg="'ConstantIntNode' object has no attribute 'add'",
error_msg="'NestedIntNode' object has no attribute 'add'",
op_match_fn=lambda device, op: (
op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
),
@ -8730,7 +8738,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
# copysign(): formula is broken for (T, NT) broadcasting
XFailRule(
error_type=AttributeError,
error_msg="'ConstantIntNode' object has no attribute 'add'",
error_msg="'NestedIntNode' object has no attribute 'add'",
op_match_fn=lambda device, op: (op.full_name == "copysign"),
sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
name="broken_copysign_compile_backward",

View File

@ -15,7 +15,11 @@ TORCH_LIBRARY(c10d, m) {
m.class_<Work>("Work")
.def(torch::init<>())
.def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
m.class_<ReduceOp>("ReduceOp")
.def(torch::init<>())
.def("op", [](const c10::intrusive_ptr<ReduceOp>& self) -> int64_t {
return self->op_;
});
m.def(
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(

View File

@ -0,0 +1,747 @@
from ast import Call
"""
A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks. A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally. When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards. A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.
NB: this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
autograd needs multithreading to keep up!) (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs,
but this is currently a non-goal.)
We do not directly handle MPMD. However, in practice even in SPMD you may
encounter divergence in behavior per rank (for example, uneven sharding
across ranks). To support scenarios like this, we provide a helper decorator
that allows you to run a function with no side effects for each LocalTensor
shard and combine results back into LocalTensor or LocalIntNode.
NB: This is a torch dispatch Tensor subclass, as we want to assume that autograd
is SPMD, so we run it once, and dispatch the inner autograd calls to the individual
local shards.
NOTE ABOUT MESH: This subclass requires collectives that are issued to it to
respect a DeviceMesh like abstraction. The reason for this is that when
DTensor issues us a collective for a particular rank, you will be asked to do
this on a specific process group which involves some ranks. However, this
will only be for the LOCAL PG that this particular rank is participating in;
there will be a bunch of other PGs for other nodes that you don't get to see.
We need to be able to reverse engineer all of the collectives that don't
involve the current local rank here to actually issue them. This can be done
two ways: (1) looking at the participating local ranks in the PG and computing
the complement which specifies all the other collectives you have to run, or
(2) retrieving the device mesh axis corresponding to the PG for this rank, and
then running all the fibers for this.
"""
import contextlib
import functools
import operator
import os
import sys
from collections import defaultdict
from collections.abc import Sequence
from types import TracebackType
from typing import Any, Callable, Generator, Optional, Union
import torch
from torch import Size, SymBool, SymInt, Tensor
from torch._C import DispatchKey, DispatchKeySet
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
from torch.distributed import DeviceMesh
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.fx.experimental._constant_symnode import ConstantIntNode
from torch.nested._internal.nested_int import NestedIntNode
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
from torch.utils.checkpoint import get_device_states, set_device_states
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
from . import _c10d
def _int_on_rank(i: "LocalIntNode | ConstantIntNode", r: int) -> int:
if isinstance(i, LocalIntNode):
return i._local_ints[r]
elif isinstance(i, ConstantIntNode):
return i.val
else:
raise AssertionError(type(i))
def _check_for_subclass(flat_args: Sequence[object]) -> bool:
return any(_check_for_subclass_arg(x) for x in flat_args)
def _check_for_subclass_arg(x: object) -> bool:
return (
not isinstance(x, LocalTensor)
and isinstance(x, Tensor)
and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
)
def _map_to_rank_local_val(val: Any, rank: int) -> Any:
if isinstance(val, LocalTensor):
return val._local_tensors[rank]
if isinstance(val, SymInt) and isinstance(val.node, LocalIntNode):
return val.node._local_ints[rank]
return val
def _for_each_rank_run_func(
func: Callable[..., Any],
ranks: frozenset[int],
args: Sequence[Any],
kwargs: dict[str, Any],
*,
alias: bool = True,
) -> Any:
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
cpu_state = torch.get_rng_state()
devices, states = get_device_states((args, kwargs))
flat_rank_rets = {}
for r in sorted(ranks):
torch.set_rng_state(cpu_state)
set_device_states(devices, states)
rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
rank_ret = func(*rank_args, **rank_kwargs)
flat_rank_rets[r] = rank_ret
rr_key = next(iter(flat_rank_rets.keys()))
rr_val = flat_rank_rets[rr_key]
if isinstance(rr_val, Tensor):
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
elif isinstance(rr_val, (list, tuple)):
ret_list = []
for i in range(len(rr_val)):
rets = {r: flat_rank_rets[r][i] for r in sorted(ranks)}
v_it = iter(rets.values())
v = next(v_it)
if isinstance(v, Tensor):
ret_list.append(LocalTensor(rets))
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
ret_list.append(torch.SymInt(LocalIntNode(rets)))
else:
assert all(v == v2 for v2 in v_it)
ret_list.append(v)
ret = type(rr_val)(ret_list)
else:
v_it = iter(flat_rank_rets.values())
v = next(v_it)
if all(v == v2 for v2 in v_it):
return v
if isinstance(v, int):
return torch.SymInt(LocalIntNode(flat_rank_rets))
raise AssertionError(f"Unexpected return type {type(v)}")
if alias:
return return_and_correct_aliasing(func, args, kwargs, ret)
else:
return ret
def _get_extra_dispatch_keys(t: torch.Tensor) -> DispatchKeySet:
extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0)
if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Conjugate):
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Conjugate)
if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Negative):
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative)
return extra_dispatch_keys
class LocalIntNode:
"""
Like a LocalTensor, but for an int. We can't use a 0D tensor to represent this
because often only a SymInt is accepted where we wish to use this.
"""
def __new__(cls, local_ints: dict[int, int]) -> "ConstantIntNode | LocalIntNode": # type: ignore[misc]
if len(set(local_ints.values())) == 1:
return ConstantIntNode(next(iter(local_ints.values())))
return super().__new__(cls)
def __init__(self, local_ints: dict[int, int]):
self._local_ints = local_ints
def maybe_as_int(self) -> Optional[int]:
return None
def is_int(self) -> bool:
return True
def is_float(self) -> bool:
return False
def is_bool(self) -> bool:
return False
def is_nested_int(self) -> bool:
return False
def clone(self) -> "LocalIntNode":
return self
def _str(self) -> str:
return f"LocalIntNode({self._local_ints})"
def __str__(self) -> str:
return self._str()
def __repr__(self) -> str:
return self._str()
def _graph_repr(self) -> str:
return self._str()
def is_symbolic(self) -> bool:
return False
def is_constant(self) -> bool:
return False
def sym_max(
self, other: "LocalIntNode | ConstantIntNode"
) -> "LocalIntNode | ConstantIntNode":
return LocalIntNode(
{
r: max(self._local_ints[r], _int_on_rank(other, r))
for r in self._local_ints
}
)
def add(
self, other: "LocalIntNode | ConstantIntNode"
) -> "LocalIntNode | ConstantIntNode":
return LocalIntNode(
{r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints}
)
def sub(
self, other: "LocalIntNode | ConstantIntNode"
) -> "LocalIntNode | ConstantIntNode":
return LocalIntNode(
{r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints}
)
def mul(
self, other: "LocalIntNode | ConstantIntNode"
) -> "LocalIntNode | ConstantIntNode":
return LocalIntNode(
{r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
)
def eq(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints}
return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r)))
def gt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
assert len(r) == 1, (self, other)
return torch._C._get_constant_bool_symnode(next(iter(r)))
def lt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints}
assert len(r) == 1, (self, other)
return torch._C._get_constant_bool_symnode(next(iter(r)))
def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode":
return ConstantIntNode(num)
class LocalTensor(torch.Tensor):
"""
LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
(Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from
global rank ids to their corresponding local Tensor shards. Operations performed on a LocalTensor
are applied independently to each local shard, mimicking distributed computation. Collectives
and other distributed operations are handled by mapping them to the local shards as appropriate.
Note:
This class is primarily intended for debugging and simulating distributed tensor computations
on a single process.
"""
# Map from global rank to the local tensor.
_local_tensors: dict[int, torch.Tensor]
# Precomputed for speed set of keys from the local tensor map.
_ranks: frozenset[int]
__slots__ = ["_local_tensors", "_ranks"]
@staticmethod
@torch._disable_dynamo
def __new__(
cls,
local_tensors: dict[int, torch.Tensor],
) -> "LocalTensor":
if any(t.requires_grad for t in local_tensors.values()):
raise AssertionError(
"Internal local_tensors require grad, but we will ignore those autograd graph. "
"Make a custom autograd function and make sure you detach the inner tensors."
)
it = iter(local_tensors.values())
first_local_tensor = next(it)
first_shape = first_local_tensor.shape
first_stride = first_local_tensor.stride()
dtype = first_local_tensor.dtype
device = first_local_tensor.device
layout = first_local_tensor.layout
extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
# Assert that all tensors have the same dtype, layout and dispatch keys. Due
# to uneven sharding, it is possible that tensors will have different shapes.
for local_tensor in it:
assert dtype == local_tensor.dtype, (
"Tensors representing LocalTensor shards must have the same dtype"
)
assert layout == local_tensor.layout, (
"Tensors representing LocalTensor shards must have the same layout"
)
assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
"Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
)
# Compute shape/stride. We allow for non-SPMD'ness here
local_shapes: dict[int, dict[int, int]] = defaultdict(
dict
) # dim => rank => size
local_strides: dict[int, dict[int, int]] = defaultdict(
dict
) # dim => rank => size
for r, local_tensor in local_tensors.items():
for d, size in enumerate(local_tensor.shape):
local_shapes[d][r] = size
local_strides[d][r] = local_tensor.stride(d)
shape = [
(
first_shape[d]
if len(set(local_shapes[d])) == 1
else torch.SymInt(LocalIntNode(local_shapes[d]))
)
for d in range(len(first_shape))
]
strides = [
(
first_stride[d]
if len(set(local_strides[d])) == 1
else torch.SymInt(LocalIntNode(local_strides[d]))
)
for d in range(len(first_shape))
]
r = torch.Tensor._make_wrapper_subclass(
cls,
shape,
strides=strides,
dtype=dtype,
device=device,
layout=layout,
requires_grad=False,
_extra_dispatch_keys=extra_dispatch_keys,
)
local_tensors = {
r: v if not isinstance(v, AsyncCollectiveTensor) else v.wait()
for r, v in local_tensors.items()
}
r._local_tensors = local_tensors
r._ranks = frozenset(local_tensors.keys())
return r
@torch._disable_dynamo
@mark_subclass_constructor_exportable_experimental # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any):
super().__init__()
def __repr__(self) -> str: # type: ignore[override]
parts = []
for k, v in self._local_tensors.items():
parts.append(f" {k}: {v}")
tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)"
def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
"""
protocol to inform how to flatten a DTensor to local tensor
for PT2 tracing
"""
return ["_local_tensors"], ()
@staticmethod
def __tensor_unflatten__(
inner_tensors: dict[str, Any],
flatten_spec: tuple[Any, ...],
outer_size: torch.Size,
outer_stride: tuple[int, ...],
) -> "LocalTensor":
assert flatten_spec is not None, (
"Expecting spec to be not None from `__tensor_flatten__` return value!"
)
local_tensors = inner_tensors["_local_tensors"]
return LocalTensor(local_tensors)
@classmethod
@torch._disable_dynamo
def __torch_dispatch__( # type: ignore[override]
cls,
func: Any,
types: tuple[Any, ...],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
) -> Any:
if kwargs is None:
kwargs = {}
# This is horribly inefficient
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
local_tensor = None
for arg in flat_args:
if isinstance(arg, LocalTensor):
local_tensor = arg
break
assert local_tensor is not None, (
"At least one of the arguments must be a LocalTensor"
)
# Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
has_unrecognized_types = _check_for_subclass(flat_args)
if has_unrecognized_types:
unrecognized_types = [
type(x) for x in flat_args if _check_for_subclass_arg(x)
]
not_implemented_log.debug(
"LocalTensor unrecognized subclass(es): %s", unrecognized_types
)
return NotImplemented
with LocalTensorMode(local_tensor._ranks):
return func(*args, **kwargs)
def tolist(self) -> list[Any]:
"""
Reconcile and convert result to list.
"""
return self.reconcile().tolist()
def reconcile(self) -> torch.Tensor:
"""
Reconciles the LocalTensor into a single torch.Tensor by ensuring all local
shards are identical and returning a detached clone of one of them.
Note:
This method is useful for extracting a representative tensor from a LocalTensor
when all shards are expected to be the same, such as after a collective operation
that synchronizes all ranks.
"""
# Force all local tensor shards across ranks to be the same
it = iter(self._local_tensors.values())
t1 = next(it)
for t2 in it:
assert torch.equal(t1, t2), (
"LocalTensor shards must be the same to reconcile"
)
cl = t1.clone().detach()
cl.requires_grad_(self.requires_grad)
return cl
_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
class LocalTensorMode(TorchDispatchMode):
"""
A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution
for LocalTensor objects across a set of ranks.
LocalTensorMode enables PyTorch operations to be transparently applied to each
local shard of a LocalTensor, as if they were distributed across multiple ranks.
When active, this mode intercepts tensor operations and dispatches them to each
rank's local tensor, collecting and wrapping the results as LocalTensors. It also
handles collective operations by mapping them to local implementations.
This mode is primarily intended for debugging and simulating distributed tensor
computations on a single process, rather than for high-performance distributed
training. It maintains a stack of active modes, patches DeviceMesh coordinate
resolution, and provides utilities for temporarily disabling the mode or mapping
functions over ranks.
"""
# What ranks this local tensor mode is operating over
def __init__(self, ranks: Union[int, frozenset[int]]):
if isinstance(ranks, int):
# assume is world size
self.ranks = frozenset(range(ranks))
else:
assert isinstance(ranks, frozenset)
self.ranks = ranks
self._disable = False
self._old_get_coordinate = None
def __enter__(self) -> "LocalTensorMode":
self._disable = False
self._patch_device_mesh()
_LOCAL_TENSOR_MODE.append(self)
return super().__enter__()
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._disable = True
self._unpatch_device_mesh()
_LOCAL_TENSOR_MODE.pop()
super().__exit__(exc_type, exc_val, exc_tb)
def __torch_dispatch__(
self,
func: Any,
types: tuple[Any, ...],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
) -> Any:
if kwargs is None:
kwargs = {}
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
# Find all LocalTensor arguments to determine ranks
local_tensors = [a for a in flat_args if isinstance(a, LocalTensor)]
# Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
has_unrecognized_types = _check_for_subclass(flat_args)
if has_unrecognized_types:
unrecognized_types = [
type(x) for x in flat_args if _check_for_subclass_arg(x)
]
not_implemented_log.debug(
"LocalTensorMode unrecognized subclass(es): %s", unrecognized_types
)
return NotImplemented
# Factory functions convert into LocalTensor, so we don't have to
# transmute a Tensor into a LocalTensor if mutation happens...
# But if you do an operation on a Tensor, do NOT wrap it into a
# LocalTensor. This helps prevent accidents when you're doing Tensor
# operations on the inner non-wrapped tensors.
if not local_tensors:
if self._disable or any(isinstance(a, Tensor) for a in flat_args):
return func(*args, **kwargs)
# For LocalTensors, verify they have compatible ranks
for a in flat_args:
if isinstance(a, LocalTensor):
assert a._ranks == self.ranks, (
f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
)
if func.namespace == "c10d":
if func is torch.ops.c10d.allreduce_.default:
return _c10d._local_all_reduce_(*args, **kwargs)
elif func is torch.ops.c10d.allreduce_coalesced_.default:
return _c10d._local_allreduce_coalesced_(*args, **kwargs)
elif func is torch.ops.c10d.reduce_scatter_tensor_coalesced_.default:
return _c10d._local_reduce_scatter_tensor_coalesced_(*args, **kwargs)
elif func is torch.ops.c10d.scatter_.default:
return _c10d._local_scatter_(*args, **kwargs)
elif func is torch.ops.c10d.broadcast_.default:
return _c10d._local_broadcast_(*args, **kwargs)
elif func is torch.ops.c10d.allgather_.default:
return _c10d._local_all_gather_(*args, **kwargs)
elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default:
return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs)
elif func is torch.ops.c10d.gather_.default:
return _c10d._local_gather_(*args, **kwargs)
elif func is torch.ops.c10d.alltoall_.default:
return _c10d._local_alltoall_(*args, **kwargs)
elif func is torch.ops.c10d.alltoall_base_.default:
return _c10d._local_alltoall_base_(*args, **kwargs)
elif func is torch.ops.c10d.barrier.default:
return _c10d._local_barrier(*args, **kwargs)
elif func is torch.ops.c10d.monitored_barrier_.default:
return _c10d._local_monitored_barrier_(*args, **kwargs)
elif func is torch.ops.c10d.send.default:
return _c10d._local_send(*args, **kwargs)
elif func is torch.ops.c10d.recv_.default:
return _c10d._local_recv_(*args, **kwargs)
elif func is torch.ops.c10d.recv_any_source_.default:
return _c10d._local_recv_any_source_(*args, **kwargs)
raise NotImplementedError(f"{func} not implemented")
if func.namespace == "_c10d_functional" or func.namespace == "_dtensor":
with LocalTensorMode(self.ranks):
return func._op_dk(
DispatchKey.CompositeExplicitAutograd, *args, **kwargs
)
if func.namespace == "_c10d_functional_autograd":
raise NotImplementedError(f"{func} not implemented")
if func.namespace == "symm_mem":
raise NotImplementedError(f"{func} not implemented")
return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True)
@contextlib.contextmanager
def disable(self) -> Generator[None, None, None]:
"""
Disables LocalTensorMode temporarily. Primarily intended for performing
rank-specific computations and merging results back before re-enabling LocalTensorMode.
"""
old = self._disable
self._disable = True
self._unpatch_device_mesh()
try:
yield
finally:
self._disable = old
self._patch_device_mesh()
def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor:
"""
Creates a LocalTensor instance by mapping each rank id to its local shard.
"""
with self.disable():
return LocalTensor({r: cb(r) for r in self.ranks})
def _patch_device_mesh(self) -> None:
assert self._old_get_coordinate is None
self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment]
DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate # type: ignore[method-assign]
def _unpatch_device_mesh(self) -> None:
assert self._old_get_coordinate is not None
DeviceMesh.get_coordinate = self._old_get_coordinate
self._old_get_coordinate = None
class _LocalDeviceMesh:
"""
Holds implementations of DeviceMesh functionality that must be patched while running
under LocalTensorMode.
"""
@staticmethod
def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
lm = local_tensor_mode()
assert lm is not None, "Unexpectedly not in LocalTensorMode"
rank_coords = (self.mesh == lm.rank_map(lambda r: torch.tensor(r))).nonzero()
# NB: unlike the regular mechanism, we don't allow for MPMD
assert rank_coords.size(0) == 1
assert isinstance(rank_coords[0], LocalTensor)
coords: list[dict[int, int]] = [{} for _ in range(rank_coords.size(1))]
for r, v in rank_coords[0]._local_tensors.items():
for i, x in enumerate(v.tolist()):
coords[i][r] = x
out = [torch.SymInt(LocalIntNode(c)) for c in coords]
return out # type: ignore[return-value]
def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any:
"""
Reconciles arguments by converting any LocalTensor instances in the input
arguments to their underlying torch.Tensor representation.
This function is typically used to prepare arguments for functions that
expect standard torch.Tensor objects, by flattening the input arguments,
replacing LocalTensor instances with their reconciled (standard tensor)
versions, and then reconstructing the original argument structure.
Args:
args: Positional arguments, possibly containing LocalTensor instances.
kwargs: Keyword arguments, possibly containing LocalTensor instances.
Returns:
Any: The arguments with all LocalTensor instances replaced by their reconciled torch.Tensor equivalents,
preserving the original structure.
"""
if kwargs is None:
kwargs = {}
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
reconciled_args = [
a.reconcile() if isinstance(a, LocalTensor) else a for a in flat_args
]
return pytree.tree_unflatten(reconciled_args, args_spec)
def local_tensor_mode() -> Optional[LocalTensorMode]:
"""
Returns the current active LocalTensorMode if one exists.
This function checks the global stack of LocalTensorMode instances. If there
is at least one LocalTensorMode active, it returns the most recently entered
(top of the stack) LocalTensorMode. If no LocalTensorMode is active, it returns None.
Returns:
Optional[LocalTensorMode]: The current LocalTensorMode if active, else None.
"""
if len(_LOCAL_TENSOR_MODE) > 0:
return _LOCAL_TENSOR_MODE[-1]
return None
def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator that ensures a function is executed for each local tensor shard
when running under LocalTensorMode. If not in LocalTensorMode, the function
is executed normally. When in LocalTensorMode, the function is run for each
rank, and the results are collected appropriately.
This decorator is useful for functions that exhibit non-SPMD behavior, such
as those requiring rank-specific actions. For example, a function that computes
an offset into the input tensor based on the rank.
Note that the function being decorated must not have any side effects and
contain operations for a single rank only. For example, wrapping a function
that performs a collective operation will not work.
Args:
func (Callable[..., Any]): The function to be decorated.
Returns:
Callable[..., Any]: The wrapped function that handles LocalTensorMode logic.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
lm = local_tensor_mode()
if lm is None:
return func(*args, **kwargs)
ret = None
with lm.disable():
ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)
lm = local_tensor_mode()
assert lm is not None
return ret
return wrapper

View File

@ -0,0 +1,669 @@
import functools
import math
import operator
from typing import Sequence
import torch
from torch._C import ScriptObject
from torch._C._distributed_c10d import FakeWork
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed.distributed_c10d import (
_get_default_group,
ProcessGroup,
ReduceOp,
Work,
)
# NOTE: Most of the c10d collectives often take a Tensor[] (or Tensor[][])
# when you would expect Tensor (or Tensor[]). In fact, there will only ever
# be one Tensor in this case; the old signature was to support dispatching a
# collective on multiple devices (ala DataParallel) but we don't support that
# API anymore. Note that we are not 100% consistent about this; some more
# modern collectives like _allgather_base_ got rid of the unnecessary list.
# When in doubt, consult the code that dispatches to the collective on the PG
# in distributed_c10d.py e.g., work = group.allgather([tensor_list], [tensor],
opts) indicates it's always a list.
def _gcd_list(numbers: Sequence[int]) -> int:
return 0 if not numbers else functools.reduce(math.gcd, numbers)
def _indices_to_layout(indices: list[int]) -> tuple[tuple[int, ...], tuple[int, ...]]:
# Base case: A single index represents a point, not a dimension.
if len(indices) <= 1:
return (), ()
# The smallest stride is likely the GCD of the differences between consecutive indices.
# For a sorted, unique list, all differences will be positive.
diffs = [indices[i] - indices[i - 1] for i in range(1, len(indices))]
last_stride = _gcd_list(diffs)
assert last_stride != 0, (
# This case should not be reached if indices are unique and sorted.
"Cannot determine stride; indices may not be unique."
)
# Identify the starting index of each "row" in the last dimension.
# An index starts a new row if the preceding index (index - stride) is not present.
indices_set = set(indices)
higher_dim_indices = [indices[0]]
for index in indices[1:]:
if (index - last_stride) not in indices_set:
higher_dim_indices.append(index)
# From the number of rows, we can deduce the shape of the last dimension.
assert len(indices) % len(higher_dim_indices) == 0, (
"Indices do not form a regular grid. "
f"Found {len(higher_dim_indices)} subgroups for {len(indices)} total elements."
)
last_shape = len(indices) // len(higher_dim_indices)
# Recurse on the higher-dimensional indices (the start of each row).
higher_shapes, higher_strides = _indices_to_layout(higher_dim_indices)
# Combine the results from the recursion with the current dimension's results.
final_shapes = higher_shapes + (last_shape,)
final_strides = higher_strides + (last_stride,)
return final_shapes, final_strides
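# Editor's worked example (not part of this commit): the sorted, zero-based
# ranks {0, 1, 4, 5} form a 2 x 2 grid with strides (4, 1), while a stride-2
# comb of four ranks collapses to a single dimension:
#
#   _indices_to_layout([0, 1, 4, 5])  ->  ((2, 2), (4, 1))
#   _indices_to_layout([0, 2, 4, 6])  ->  ((4,), (2,))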
def _prepare_collective_groups(
process_group_so: ScriptObject,
) -> tuple[list[int], list[int], int]:
process_group = ProcessGroup.unbox(process_group_so)
ranks = torch.distributed.get_process_group_ranks(process_group)
assert ranks
# TODO: We can handle permutations but the layout inference algorithm will
# lose the permutation so we will have to reapply it
assert ranks == sorted(ranks), ranks
offset = ranks[0]
ranks = [r - offset for r in ranks]
shape, strides = _indices_to_layout(ranks)
layout = _MeshLayout(shape, strides)
global_pg = _get_default_group()
group_offsets = layout.complement(global_pg.size()).all_ranks_from_zero()
return ranks, group_offsets, offset
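# Editor's worked example (not part of this commit, offsets assumed from the
# layout-complement semantics): with a default world of 8 ranks and a process
# group spanning global ranks [4, 5, 6, 7], this helper returns
# ranks=[0, 1, 2, 3] (zero-based within the group), offset=4 (the global rank
# of the group's first member), and group_offsets enumerating where each
# replica of that group begins in the world (here presumably [0, 4]), so the
# collectives below can be applied to every replica of the group.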
def _local_broadcast_(
tensors: list[torch.Tensor],
process_group_so: ScriptObject,
root_rank: int,
root_tensor: int,
async_op: bool = True,
timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
# "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"
from . import LocalTensor
assert len(tensors) == 1
assert root_tensor == 0
tensor = tensors[0]
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)
# We're going to assume SPMD where for every rank group the root_rank is
# the same relative to others
relative_root_rank = root_rank - offset
assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the broadcast on them
group_ranks = [group_offset + r for r in ranks]
source_rank = group_offset + relative_root_rank
source_tensor = tensor._local_tensors[source_rank]
# Broadcast the source tensor to all ranks in this group
for rank in group_ranks:
if source_rank != rank:
tensor._local_tensors[rank].copy_(source_tensor)
work = FakeWork()
work_so = Work.boxed(work)
return (tensors, work_so)
def _local_reduce(
reduce_op: ReduceOp,
tensors: list[torch.Tensor],
) -> torch.Tensor:
if reduce_op == ReduceOp.SUM:
op = operator.add
elif reduce_op == ReduceOp.AVG:
op = None
elif reduce_op == ReduceOp.PRODUCT:
op = operator.mul
elif reduce_op == ReduceOp.MIN:
op = torch.minimum
elif reduce_op == ReduceOp.MAX:
op = torch.maximum
elif reduce_op == ReduceOp.BAND:
op = torch.bitwise_and
elif reduce_op == ReduceOp.BOR:
op = torch.bitwise_or
elif reduce_op == ReduceOp.BXOR:
op = torch.bitwise_xor
elif reduce_op == ReduceOp.PREMUL_SUM:
raise NotImplementedError("PREMUL_SUM: need to add binding for scaling factor")
else:
raise NotImplementedError(f"ReduceOp {reduce_op} not implemented")
if reduce_op == ReduceOp.AVG:
return functools.reduce(operator.add, tensors) / len(tensors)
else:
assert op is not None
return functools.reduce(op, tensors)
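# Editor's worked example (not part of this commit): _local_reduce folds the
# per-rank shards with the requested reduction, e.g.
#
#   _local_reduce(ReduceOp.SUM, [torch.tensor([1., 2.]), torch.tensor([3., 4.])])
#       -> tensor([4., 6.])
#   _local_reduce(ReduceOp.MAX, [torch.tensor([1., 5.]), torch.tensor([3., 4.])])
#       -> tensor([3., 5.])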
def _local_all_reduce_(
tensors: list[torch.Tensor],
process_group_so: ScriptObject,
reduce_op_so: ScriptObject,
sparse_indices: torch.Tensor | None = None,
async_op: bool = True,
timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
# "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "__torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, "
# "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
from . import LocalTensor
assert len(tensors) == 1
tensor = tensors[0]
reduce_op = reduce_op_so.op() # type: ignore[attr-defined]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the allreduce on them
group_ranks = [group_offset + r for r in ranks]
# Collect tensors from the specified ranks in this group
group_tensors = []
for rank in group_ranks:
group_tensors.append(tensor._local_tensors[rank])
# Perform the reduction operation
reduced_tensor = _local_reduce(reduce_op, group_tensors)
# Update all tensors in the group with the reduced result
for rank in group_ranks:
tensor._local_tensors[rank].copy_(reduced_tensor)
work = FakeWork()
work_so = Work.boxed(work)
return (tensors, work_so)
def _local_allreduce_coalesced_(
tensors: list[torch.Tensor],
process_group_so: ScriptObject,
reduce_op_so: ScriptObject,
async_op: bool = True,
timeout: int = -1,
) -> ScriptObject:
# "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
from . import LocalTensor
reduce_op = reduce_op_so.op() # type: ignore[attr-defined]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the allreduce on all tensors together
group_ranks = [group_offset + r for r in ranks]
# For each tensor, perform the reduction operation
for tensor in tensors:
assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
# Collect tensors from the specified ranks in this group
group_tensors = []
for rank in group_ranks:
group_tensors.append(tensor._local_tensors[rank])
# Perform the reduction operation
reduced_tensor = _local_reduce(reduce_op, group_tensors)
# Update all tensors in the group with the reduced result
for rank in group_ranks:
tensor._local_tensors[rank].copy_(reduced_tensor)
work = FakeWork()
work_so = Work.boxed(work)
return work_so
def _local_reduce_scatter_tensor_coalesced_(
output_tensors: list[torch.Tensor],
input_tensors: list[torch.Tensor],
process_group_so: ScriptObject,
reduce_op_so: ScriptObject,
async_op: bool = True,
timeout: int = -1,
) -> ScriptObject:
# "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
# "__torch__.torch.classes.c10d.ProcessGroup process_group, "
# "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, "
# "int timeout=-1) -> __torch__.torch.classes.c10d.Work"
from . import LocalTensor
reduce_op = reduce_op_so.op() # type: ignore[attr-defined]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the allreduce on all tensors together
group_ranks = [group_offset + r for r in ranks]
# For each tensor, perform the reduction operation
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
assert isinstance(input_tensor, LocalTensor), (
"Input tensor must be a LocalTensor"
)
assert isinstance(output_tensor, LocalTensor), (
"Output tensor must be a LocalTensor"
)
# Collect tensors from the specified ranks in this group
group_inputs = []
for rank in group_ranks:
group_inputs.append(input_tensor._local_tensors[rank])
# Perform the reduction operation
reduced_input = _local_reduce(reduce_op, group_inputs)
            reduced_input_splits = torch.split(
                reduced_input, reduced_input.size(0) // len(group_ranks), dim=0
            )
            # Update all tensors in the group with the reduced result; each
            # rank receives the chunk at its position within the group.
            for i, rank in enumerate(group_ranks):
                output_tensor._local_tensors[rank].copy_(reduced_input_splits[i])
work = FakeWork()
work_so = Work.boxed(work)
return work_so
def _local_all_gather_(
output_tensors: list[list[torch.Tensor]],
input_tensors: list[torch.Tensor],
process_group_so: ScriptObject,
async_op: bool = True,
timeout: int = -1,
) -> tuple[list[list[torch.Tensor]], ScriptObject]:
# "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
# "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
# "int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
from . import LocalTensor
assert len(output_tensors) == 1
assert len(input_tensors) == 1
input_tensor = input_tensors[0]
output_tensors = output_tensors[0]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
for i in range(len(output_tensors)):
assert isinstance(output_tensors[i], LocalTensor), (
"Output tensor must be a LocalTensor"
)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the all_gather on them
group_ranks = [group_offset + r for r in ranks]
# For each rank in the group, gather from their input tensor
for i, rank_i in enumerate(group_ranks):
output_tensors[i].copy_(input_tensor._local_tensors[rank_i])
work = FakeWork()
work_so = Work.boxed(work)
return ([output_tensors], work_so)
def _local_allgather_into_tensor_coalesced_(
output_tensors: list[torch.Tensor],
input_tensors: list[torch.Tensor],
process_group_so: ScriptObject,
async_op: bool = True,
) -> ScriptObject:
# "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
# "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) "
# "-> __torch__.torch.classes.c10d.Work"
from . import LocalTensor
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
# Each output tensor should be sized to hold all gathered inputs
# outputs[i] will contain all inputs[i] from all ranks
assert len(output_tensors) == len(input_tensors), (
f"Number of outputs ({len(output_tensors)}) must match number of inputs ({len(input_tensors)})"
)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the allgather_into_tensor on them
group_ranks = [group_offset + r for r in ranks]
# For each input/output pair
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
assert isinstance(input_tensor, LocalTensor), (
"Input tensor must be a LocalTensor"
)
assert isinstance(output_tensor, LocalTensor), (
"Output tensor must be a LocalTensor"
)
# Gather input_tensor from all ranks into output_tensor
# The output should be a concatenation of all inputs along the first dimension
gathered_tensors = []
for rank in group_ranks:
gathered_tensors.append(input_tensor._local_tensors[rank])
# Concatenate along first dimension and copy to output
if gathered_tensors:
concatenated = torch.cat(gathered_tensors, dim=0)
for rank in group_ranks:
output_tensor._local_tensors[rank].copy_(concatenated)
work = FakeWork()
work_so = Work.boxed(work)
return work_so
def _local_gather_(
output_tensors: list[list[torch.Tensor]],
input_tensors: list[torch.Tensor],
process_group_so: ScriptObject,
root_rank: int,
async_op: bool = True,
timeout: int = -1,
) -> ScriptObject:
# "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
# "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
# "bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
raise NotImplementedError(
"LocalTensor does not support MPMD operations like gather "
"(only root rank receives data). Use SPMD collective operations like allgather instead."
)
def _local_scatter_(
output_tensors: list[torch.Tensor],
input_tensors: list[list[torch.Tensor]],
process_group_so: ScriptObject,
root_rank: int,
async_op: bool = True,
timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
# "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, "
# "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
# "bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
from . import LocalTensor
assert len(output_tensors) == 1
assert len(input_tensors) == 1
output_tensor = output_tensors[0]
input_tensors = input_tensors[0]
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)
# We're going to assume SPMD where for every rank group the root_rank is
# the same relative to others
relative_root_rank = root_rank - offset
assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"
assert len(ranks) == len(input_tensors), (ranks, input_tensors)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the scatter on them
group_ranks = [group_offset + r for r in ranks]
# Root rank scatters its input tensors to all ranks in this group
for i, rank in enumerate(group_ranks):
input_tensor = input_tensors[i]
assert isinstance(input_tensor, LocalTensor)
# Each rank i gets the i-th input tensor from the root
source_tensor = input_tensor._local_tensors[
group_offset + relative_root_rank
]
output_tensor._local_tensors[rank].copy_(source_tensor)
work = FakeWork()
work_so = Work.boxed(work)
return (output_tensors, work_so)
def _local_alltoall_(
output_tensors: list[torch.Tensor],
input_tensors: list[torch.Tensor],
process_group_so: ScriptObject,
async_op: bool = True,
timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
# "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, "
# "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
# "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)";
from . import LocalTensor
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
assert len(input_tensors) == len(output_tensors) == len(ranks), (
f"Number of input tensors ({len(input_tensors)}), "
f"output tensors ({len(output_tensors)}), and ranks ({len(ranks)}) must match"
)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the alltoall on them
group_ranks = [group_offset + r for r in ranks]
        # In alltoall, rank i sends input_tensors[j] to rank j and receives into output_tensors[j] from rank j
for i, rank_i in enumerate(group_ranks):
output_tensor = output_tensors[i]
assert isinstance(output_tensor, LocalTensor), (
"Output tensor must be a LocalTensor"
)
for j, rank_j in enumerate(group_ranks):
input_tensor = input_tensors[j]
assert isinstance(input_tensor, LocalTensor), (
"Input tensor must be a LocalTensor"
)
# Rank i's j-th input tensor goes to rank j's i-th output tensor
source_tensor = input_tensor._local_tensors[rank_i]
output_tensor._local_tensors[rank_j].copy_(source_tensor)
work = FakeWork()
work_so = Work.boxed(work)
return (output_tensors, work_so)
def _local_alltoall_base_(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
process_group_so: ScriptObject,
output_split_sizes: list[int],
input_split_sizes: list[int],
async_op: bool = True,
timeout: int = -1,
) -> ScriptObject:
# "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";
from . import LocalTensor
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"
# Convert split sizes to lists if they aren't already
if output_split_sizes is not None:
output_split_sizes = list(output_split_sizes)
if input_split_sizes is not None:
input_split_sizes = list(input_split_sizes)
for group_offset in group_offsets:
# For the tensors in this group [group_offset + r for r in ranks]
# perform the alltoall_base on them
group_ranks = [group_offset + r for r in ranks]
for i, rank_i in enumerate(group_ranks):
# Split input tensor from rank_i according to input_split_sizes
rank_tensor = input_tensor._local_tensors[rank_i]
if input_split_sizes is not None and len(input_split_sizes) > 0:
# Split the input tensor
input_splits = torch.split(rank_tensor, input_split_sizes, dim=0)
else:
# No split sizes specified, split evenly
split_size = rank_tensor.size(0) // len(group_ranks)
input_splits = torch.split(rank_tensor, split_size, dim=0)
# Send each split to the corresponding rank
for j, rank_j in enumerate(group_ranks):
if j < len(input_splits):
split_tensor = input_splits[j]
# Determine where to place this split in the output tensor
if output_split_sizes is not None and len(output_split_sizes) > 0:
# Calculate offset based on output split sizes
output_offset = sum(output_split_sizes[:i]) if i > 0 else 0
end_offset = (
output_offset + output_split_sizes[i]
if i < len(output_split_sizes)
else output_tensor._local_tensors[rank_j].size(0)
)
else:
# No output split sizes, use even splits
split_size = output_tensor._local_tensors[rank_j].size(
0
) // len(group_ranks)
output_offset = i * split_size
end_offset = min(
(i + 1) * split_size,
output_tensor._local_tensors[rank_j].size(0),
)
# Copy the split to the appropriate section of the output tensor
output_section = output_tensor._local_tensors[rank_j][
output_offset:end_offset
]
if output_section.numel() > 0:
# Reshape split_tensor to match output_section if necessary
if split_tensor.size() != output_section.size():
split_tensor = split_tensor.view(output_section.size())
output_section.copy_(split_tensor)
work = FakeWork()
work_so = Work.boxed(work)
return work_so
def _local_barrier(
tensor: torch.Tensor,
process_group_so: ScriptObject,
device_ids: list[int],
async_op: bool = True,
timeout: int = -1,
) -> ScriptObject:
# "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";
from . import LocalTensor
# Barrier is a synchronization primitive - in local simulation,
# we don't need to do any actual work since all "ranks" are in the same process
# Just validate that the tensor is a LocalTensor
assert isinstance(tensor, LocalTensor)
# In a real distributed setting, barrier would synchronize all processes
# In local simulation, this is essentially a no-op since all ranks are local
work = FakeWork()
work_so = Work.boxed(work)
return work_so
def _local_monitored_barrier_(
tensor: torch.Tensor,
process_group_so: ScriptObject,
device_ids: list[int],
timeout: int,
wait_all_ranks: bool,
) -> None:
# "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int[] device_ids, int timeout, bool wait_all_ranks) -> ()";
from . import LocalTensor
# Monitored barrier is a synchronization primitive with monitoring - in local simulation,
# we don't need to do any actual work since all "ranks" are in the same process
# Just validate that the tensor is a LocalTensor
assert isinstance(tensor, LocalTensor)
# In a real distributed setting, monitored barrier would synchronize all processes
# and provide monitoring capabilities. In local simulation, this is essentially a no-op
# since all ranks are local and no actual synchronization is needed
return
def _local_send(
tensors: list[torch.Tensor],
process_group_so: ScriptObject,
dst: int,
tag: int,
) -> ScriptObject:
# "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int dst, int tag) -> __torch__.torch.classes.c10d.Work";
raise NotImplementedError(
"LocalTensor does not support MPMD operations like send. "
"Use SPMD collective operations instead."
)
def _local_recv_(
tensors: list[torch.Tensor],
process_group_so: ScriptObject,
src: int,
tag: int,
) -> ScriptObject:
# "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int src, int tag) -> __torch__.torch.classes.c10d.Work";
raise NotImplementedError(
"LocalTensor does not support MPMD operations like recv. "
"Use SPMD collective operations instead."
)
def _local_recv_any_source_(
tensors: list[torch.Tensor], process_group_so: ScriptObject, tag: int
) -> ScriptObject:
# "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
# "int tag) -> __torch__.torch.classes.c10d.Work";
raise NotImplementedError(
"LocalTensor does not support MPMD operations like recv_any_source. "
"Use SPMD collective operations instead."
)

View File

@ -10,6 +10,7 @@ import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch._logging import warning_once
from torch.distributed._local_tensor import local_tensor_mode
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
_get_group_size_by_name,
@ -40,7 +41,7 @@ def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
if mesh.device_type == "cpu":
if mesh.device_type == "cpu" and local_tensor_mode() is None:
# Gloo does not support alltoall, so falling back to allgather + chunk
warning_once(
logger,

View File

@ -165,7 +165,7 @@ class OpDispatcher:
raise
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed for {op_info.schema}"
f"{e}\n\nSharding propagation failed for {op_info.schema}"
) from e
output_sharding = op_info.output_sharding

View File

@ -319,6 +319,10 @@ LINEAR_REDUCTION_OP_MAP = {
aten.all.dim: "sum",
aten.sum.default: "sum",
aten.sum.dim_IntList: "sum",
aten.any.default: "sum",
aten.any.dim: "sum",
aten.any.out: "sum",
# These are only valid when there is no padding
aten.prod.default: "product",
aten.prod.dim_int: "product",
aten.prod.int_out: "product",
@ -332,9 +336,6 @@ LINEAR_REDUCTION_OP_MAP = {
aten.min.default: "min",
aten.min.dim: "min",
aten.min.out: "min",
aten.any.default: "sum",
aten.any.dim: "sum",
aten.any.out: "sum",
aten.amax.default: "max",
aten.amax.out: "max",
aten.amin.default: "min",

View File

@ -383,6 +383,7 @@ def redistribute_local_tensor(
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)

View File

@ -149,7 +149,7 @@ def _compute_local_shape_and_global_offset(
ordered_placements = _explicit_order_placements(mesh_shape, placements)
local_shape = list(global_shape)
# We'll compute the data for where the shard beings on a per-dim basis.
# We'll compute the data for where the shard begins on a per-dim basis.
# However, a single dim can be sharded multiple times, so we will end up
# doing a Sum(size*stride) like computation to determine the location of our
# shard for each of the shardings on that dim.
@ -170,6 +170,14 @@ def _compute_local_shape_and_global_offset(
local_shape[shard_dim] = shard_size
shard_global_offset = global_offset[shard_dim] + not_none(shard_offset)
zero_global_offset = global_shape[shard_dim]
if isinstance(shard_global_offset, torch.SymInt) and not isinstance(
zero_global_offset, torch.SymInt
):
zero_global_offset = torch.SymInt(zero_global_offset)
global_offset[shard_dim] = torch.sym_ite(
shard_size == 0,
# Special case to fill in a standardized non-garbage value for
@ -179,11 +187,11 @@ def _compute_local_shape_and_global_offset(
# Note that you can end up with zero-size shards that are
# still otherwise in bounds for the tensor (TODO: give an
# example).
global_shape[shard_dim],
zero_global_offset,
# As we successively shard the same dimension, we keep
# advancing our pointer beyond our original offset until we
# get to the final chunk start.
global_offset[shard_dim] + not_none(shard_offset),
shard_global_offset,
)
# NOTE: the offset compute relies on the local shard index and it has no

View File

@ -6,6 +6,7 @@ from typing import cast, Optional
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._local_tensor import maybe_run_for_local_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._collective_utils import (
fill_empty_tensor_to_shards,
@ -128,6 +129,7 @@ class Shard(Placement):
)
@staticmethod
@maybe_run_for_local_tensor
def local_shard_size_and_offset(
curr_local_size: int,
num_chunks: int,
@ -170,6 +172,20 @@ class Shard(Placement):
) -> tuple[int, Optional[int]]:
return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)
@staticmethod
@maybe_run_for_local_tensor
def _maybe_unpad_tensor_with_sizes(
dim, local_tensor, pad_sizes, mesh_dim_local_rank, make_contiguous
) -> torch.Tensor:
# Only unpad if the local_tensor was padded on the dimension.
if pad_sizes[mesh_dim_local_rank] > 0:
local_tensor = unpad_tensor(
local_tensor, dim, pad_sizes[mesh_dim_local_rank]
)
if make_contiguous:
local_tensor = local_tensor.contiguous()
return local_tensor
@staticmethod
def _make_shard_tensor(
dim: int,
@ -198,24 +214,28 @@ class Shard(Placement):
dim, tensor, num_chunks, with_padding=False, contiguous=True
)
return scatter_list[mesh_dim_local_rank]
return Shard._select_shard(scatter_list, mesh_dim_local_rank)
scatter_list, pad_sizes = Shard._make_split_tensor(
dim, tensor, num_chunks, with_padding=True, contiguous=True
)
output = torch.empty_like(scatter_list[mesh_dim_local_rank])
it = iter(scatter_list)
first = next(it)
# Tensors in the scatter list are expected to have the same shape because
# split is requested with padding.
assert all(first.shape == v.shape for v in it)
output = torch.empty_like(first)
# perform scatter from the src_data_rank as data source when it is not None
mesh_scatter(
output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank
)
# Only unpad if the local_tensor was padded on the dimension.
if pad_sizes[mesh_dim_local_rank] > 0:
output = unpad_tensor(output, dim, pad_sizes[mesh_dim_local_rank])
# Unpad might return a view, hence we need to remake it contiguous
output = output.contiguous()
return output
return Shard._maybe_unpad_tensor_with_sizes(
dim, output, pad_sizes, mesh_dim_local_rank, True
)
def _shard_tensor(
self,
@ -245,6 +265,7 @@ class Shard(Placement):
return tensor
is_padded = tensor.size(self.dim) % num_chunks != 0
pad_sizes = None
if is_padded:
scattered_list, pad_sizes = Shard._make_split_tensor(
self.dim, tensor, num_chunks, with_padding=True, contiguous=True
@ -258,9 +279,47 @@ class Shard(Placement):
)
if is_padded:
output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined]
assert pad_sizes is not None
output = Shard._maybe_unpad_tensor_with_sizes(
self.dim, output, pad_sizes, my_coordinate[mesh_dim], False
)
return output
@maybe_run_for_local_tensor
def _maybe_pad_tensor(
self,
local_tensor: torch.Tensor,
logical_dim_size: int,
num_chunks: int,
) -> torch.Tensor:
is_padded = logical_dim_size % num_chunks != 0
if is_padded:
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
pad_size = full_chunk_size - local_tensor.size(self.dim)
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
if not local_tensor.is_contiguous():
local_tensor = local_tensor.contiguous()
return local_tensor
@maybe_run_for_local_tensor
def _maybe_unpad_tensor(
self,
local_tensor: torch.Tensor,
logical_dim_size: int,
num_chunks: int,
) -> torch.Tensor:
is_padded = logical_dim_size % num_chunks != 0
if is_padded:
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size)
return local_tensor
def _to_replicate_tensor(
self,
local_tensor: torch.Tensor,
@ -273,28 +332,27 @@ class Shard(Placement):
is replicated on the previously sharded mesh dimension
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
logical_dim_size = current_logical_shape[self.dim]
is_padded = logical_dim_size % num_chunks != 0
if is_padded:
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
pad_size = full_chunk_size - local_tensor.size(self.dim)
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
if not local_tensor.is_contiguous():
local_tensor = local_tensor.contiguous()
local_tensor = self._maybe_pad_tensor(
local_tensor, logical_dim_size, num_chunks
)
result = funcol.all_gather_tensor(
local_tensor,
gather_dim=self.dim,
group=(mesh, mesh_dim),
)
if is_padded:
unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
result = unpad_tensor(result, self.dim, unpad_size)
result = self._maybe_unpad_tensor(result, logical_dim_size, num_chunks)
return result
@staticmethod
@maybe_run_for_local_tensor
def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor:
return shards[shard_index].clone()
def _replicate_to_shard(
self,
local_tensor: torch.Tensor,
@ -313,7 +371,8 @@ class Shard(Placement):
with_padding=False,
contiguous=False,
)
return shards[shard_index].clone()
return Shard._select_shard(shards, shard_index)
def _to_new_shard_dim(
self,

View File

@ -41,6 +41,9 @@ class ConstantIntNode:
def _graph_repr(self) -> str:
return self._str()
def add(self, other: Any) -> Any:
return other.add(self)
def mul(self, other: Any) -> Any:
return other.mul(self)

View File

@ -13,6 +13,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._local_tensor import LocalTensor
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
@ -660,7 +661,7 @@ class DTensorConverter:
def to_dist_tensor(
self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
) -> torch.Tensor:
if type(t) is torch.Tensor or type(t) is nn.Parameter:
if type(t) is torch.Tensor or type(t) is nn.Parameter or type(t) is LocalTensor:
if self.is_supported_tensor(t):
self.hit += 1
if t.ndim == 0:
@ -669,7 +670,7 @@ class DTensorConverter:
else:
# distribute non-scalar tensors
r = distribute_tensor(t, mesh, placements)
if type(t) is nn.Parameter:
if isinstance(t, nn.Parameter):
r = nn.Parameter( # type: ignore[assignment]
r, requires_grad=r.requires_grad
)

View File

@ -1292,7 +1292,7 @@ SAC_IGNORED_OPS = {
# With subclasses involved, these metadata ops become dispatchable, this
# can result in incorrectness if these ops are selected cached.
torch.ops.prim.device.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) # type: ignore[has-type]
class _CachingTorchDispatchMode(TorchDispatchMode):