mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			ciflow/b20
			...
			huba/local
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b1648d4bc9 | 
@ -3,13 +3,23 @@
 | 
			
		||||
 | 
			
		||||
import pathlib
 | 
			
		||||
import tempfile
 | 
			
		||||
import types
 | 
			
		||||
import unittest
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from numpy.testing import assert_array_equal
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
import torch.distributed.distributed_c10d as c10d
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
 | 
			
		||||
from torch.distributed._local_tensor import (
 | 
			
		||||
    LocalIntNode,
 | 
			
		||||
    LocalTensorMode,
 | 
			
		||||
    maybe_run_for_local_tensor,
 | 
			
		||||
)
 | 
			
		||||
from torch.distributed.device_mesh import init_device_mesh
 | 
			
		||||
from torch.distributed.tensor import (
 | 
			
		||||
    DeviceMesh,
 | 
			
		||||
@ -44,6 +54,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 | 
			
		||||
c10d_functional = torch.ops.c10d_functional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@maybe_run_for_local_tensor
def map_tensor_for_rank(tensor, rank, func):
    """Call ``func(tensor, rank)`` and hand back whatever it returns.

    The call is wrapped by ``maybe_run_for_local_tensor``; the tests in this
    file use it to take a rank-dependent slice of a reference tensor.
    """
    mapped = func(tensor, rank)
    return mapped
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyMLP(torch.nn.Module):
 | 
			
		||||
    def __init__(self, device):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@ -592,7 +607,12 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
 | 
			
		||||
        self.assertEqual(sharded_tensor.placements, placements)
 | 
			
		||||
        local_tensor = sharded_tensor.to_local()
 | 
			
		||||
        self.assertEqual(local_tensor, full_tensor[range(self.rank, self.rank + 1), :])
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            local_tensor,
 | 
			
		||||
            map_tensor_for_rank(
 | 
			
		||||
                full_tensor, self.rank, lambda ft, r: ft[range(r, r + 1), :]
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Shard by column
 | 
			
		||||
        placements = [Shard(1)]
 | 
			
		||||
@ -600,7 +620,12 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
 | 
			
		||||
        self.assertEqual(sharded_tensor.placements, placements)
 | 
			
		||||
        local_tensor = sharded_tensor.to_local()
 | 
			
		||||
        self.assertEqual(local_tensor, full_tensor[:, range(self.rank, self.rank + 1)])
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            local_tensor,
 | 
			
		||||
            map_tensor_for_rank(
 | 
			
		||||
                full_tensor, self.rank, lambda ft, r: ft[:, range(r, r + 1)]
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # assert full tensor is not changed
 | 
			
		||||
        self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
 | 
			
		||||
@ -620,6 +645,105 @@ class DTensorTest(DTensorTestBase):
 | 
			
		||||
        self.assertEqual(local_tensor.item(), self.rank)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LocalDTensorTest(DTensorTest):
    # Variant of DTensorTest that uses a "fake" process group and LocalTensorMode.
    # Many inherited tests are replaced by no-ops below; test_shard_tensor and
    # test_shard_tensor_2d re-run the parent implementation inside the mode.

    def get_local_tensor_mode(self):
        # One mode spanning every rank in [0, world_size).
        all_ranks = frozenset(range(self.world_size))
        return LocalTensorMode(all_ranks)

    @property
    def rank(self):
        # Built fresh on every access: a SymInt whose node maps each rank to itself.
        rank_to_value = {idx: idx for idx in range(self.world_size)}
        return torch.SymInt(LocalIntNode(rank_to_value))

    @rank.setter
    def rank(self, rank):
        # Assignments are accepted and discarded; the getter never reads them.
        pass

    def join_or_run(self, fn):
        # Invoke ``fn`` directly (its return value is dropped) and expose the
        # wrapper as a method bound to this instance, carrying ``fn``'s metadata.
        def run_inline(self):
            fn()

        return types.MethodType(wraps(fn)(run_inline), self)

    def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
        # ``eager_init`` and ``backend`` are not consulted: the group is always
        # created with the "fake" backend as rank 0.
        dist.init_process_group(backend="fake", rank=0, world_size=self.world_size)
        self._pg = c10d._get_default_group()

    def destroy_pg(self, device_id: Optional[int] = None) -> None:
        # ``device_id`` is not consulted; tear down the group saved by init_pg.
        dist.destroy_process_group(self._pg)
        self._pg = None

    def _spawn_processes(self) -> None:
        # Intentionally a no-op in this variant.
        pass

    # ---- Inherited tests overridden with no-ops ----

    def test_dtensor_constructor(self):
        pass

    def test_meta_dtensor(self):
        pass

    def test_modules_w_meta_dtensor(self):
        pass

    def test_dtensor_stride(self):
        pass

    def test_from_local(self):
        pass

    def test_from_local_uneven_sharding(self):
        pass

    def test_from_local_uneven_sharding_raise_error(self):
        pass

    def test_from_local_negative_dim(self):
        pass

    def test_to_local(self):
        pass

    def test_to_local_grad_hint(self):
        pass

    def test_full_tensor_sync(self):
        pass

    def test_full_tensor_grad_hint(self):
        pass

    def test_dtensor_new_empty_strided(self):
        pass

    def test_dtensor_async_output(self):
        pass

    def test_from_local_then_to_local(self):
        pass

    def test_dtensor_spec_read_only_after_set(self):
        pass

    def test_dtensor_spec_hash(self):
        pass

    def test_dtensor_properties(self):
        pass

    def test_dtensor_save_load(self):
        pass

    def test_dtensor_save_load_import(self):
        pass

    # ---- Inherited tests executed inside the local tensor mode ----

    def test_shard_tensor_2d(self):
        with self.get_local_tensor_mode():
            super().test_shard_tensor_2d()

    def test_shard_tensor(self):
        with self.get_local_tensor_mode():
            super().test_shard_tensor()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DTensorMeshTest(DTensorTestBase):
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self):
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates
 | 
			
		||||
# Owner(s): ["oncall: distributed"]
 | 
			
		||||
 | 
			
		||||
import copy
 | 
			
		||||
import re
 | 
			
		||||
import unittest
 | 
			
		||||
import warnings
 | 
			
		||||
@ -8,6 +9,7 @@ import warnings
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
import torch.testing._internal.common_methods_invocations as common_ops
 | 
			
		||||
from torch.distributed._local_tensor import LocalTensorMode, reconcile_args
 | 
			
		||||
from torch.distributed.tensor import (
 | 
			
		||||
    distribute_tensor,
 | 
			
		||||
    DTensor,
 | 
			
		||||
@ -21,7 +23,7 @@ from torch.testing._internal.common_device_type import (
 | 
			
		||||
    ops,
 | 
			
		||||
)
 | 
			
		||||
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
 | 
			
		||||
from torch.testing._internal.common_utils import run_tests, suppress_warnings
 | 
			
		||||
from torch.testing._internal.common_utils import run_tests, suppress_warnings, TestCase
 | 
			
		||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
 | 
			
		||||
    DTensorConverter,
 | 
			
		||||
    DTensorOpTestBase,
 | 
			
		||||
@ -49,7 +51,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
 | 
			
		||||
    return (op_name, variant_name, device_type, dtypes, False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def skipOps(test_case_name, base_test_name, to_skip):
 | 
			
		||||
def skipOps(op_db, test_case_name, base_test_name, to_skip):
 | 
			
		||||
    all_opinfos = op_db
 | 
			
		||||
    for xfail in to_skip:
 | 
			
		||||
        op_name, variant_name, device_type, dtypes, expected_failure = xfail
 | 
			
		||||
@ -88,6 +90,34 @@ def skipOps(test_case_name, base_test_name, to_skip):
 | 
			
		||||
    return wrapped
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repurpose_ops(op_db, base_test_name, derived_test_name):
    """
    Deep-copy an op info database and retarget its class-scoped decorators.

    Decorators are only applied when their class name matches the running test
    class exactly (base classes are not considered), so any decorator that was
    written for the base test class has to be renamed before it can take effect
    in a derived test class.

    This is used to build two test classes (a multi-threaded flavor and a local
    tensor flavor) that share one test body but have different skip/xfail rules.

    Args:
        op_db: List of OpInfo objects to be repurposed.
        base_test_name: The original test class name to be replaced.
        derived_test_name: The new test class name to set in decorators.

    Returns:
        list: A new list holding copies of the OpInfo objects; the inputs are
        left untouched.
    """
    retargeted = []
    for original in op_db:
        clone = copy.deepcopy(original)
        # Walk a snapshot of the decorators and rename only those bound to the base class.
        for deco in tuple(clone.decorators):
            if not hasattr(deco, "cls_name"):
                continue
            if deco.cls_name == base_test_name:
                deco.cls_name = derived_test_name
        retargeted.append(clone)
    return retargeted
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# To regenerate this list of failures, turn on dry_run in the function below,
 | 
			
		||||
# check_dtensor_func(self, test, op, dry_run=True), then run something
 | 
			
		||||
# like python test/distributed/tensor/test_dtensor_ops.py > failed.expect
 | 
			
		||||
@ -162,7 +192,6 @@ dtensor_fails = {
 | 
			
		||||
    xfail("fmin"),
 | 
			
		||||
    xfail("frexp"),
 | 
			
		||||
    xfail("full"),
 | 
			
		||||
    xfail("full_like"),
 | 
			
		||||
    xfail("geometric"),
 | 
			
		||||
    xfail("geqrf"),
 | 
			
		||||
    xfail("grid_sampler_2d"),
 | 
			
		||||
@ -226,7 +255,6 @@ dtensor_fails = {
 | 
			
		||||
    xfail("masked_select"),
 | 
			
		||||
    xfail("masked.argmax"),
 | 
			
		||||
    xfail("masked.argmin"),
 | 
			
		||||
    xfail("masked.cumprod"),
 | 
			
		||||
    xfail("masked.logsumexp"),
 | 
			
		||||
    xfail("masked.median"),
 | 
			
		||||
    xfail("matrix_exp"),
 | 
			
		||||
@ -244,8 +272,6 @@ dtensor_fails = {
 | 
			
		||||
    xfail("native_batch_norm"),
 | 
			
		||||
    xfail("narrow_copy"),
 | 
			
		||||
    xfail("ne"),
 | 
			
		||||
    xfail("new_empty"),
 | 
			
		||||
    xfail("new_empty_strided"),
 | 
			
		||||
    xfail("transpose"),
 | 
			
		||||
    xfail("nn.functional.adaptive_avg_pool1d"),
 | 
			
		||||
    xfail("nn.functional.adaptive_avg_pool2d"),
 | 
			
		||||
@ -272,8 +298,6 @@ dtensor_fails = {
 | 
			
		||||
    xfail("nn.functional.cosine_similarity"),
 | 
			
		||||
    xfail("nn.functional.ctc_loss"),
 | 
			
		||||
    xfail("nn.functional.dropout"),
 | 
			
		||||
    xfail("nn.functional.dropout2d"),
 | 
			
		||||
    xfail("nn.functional.dropout3d"),
 | 
			
		||||
    xfail("nn.functional.elu"),
 | 
			
		||||
    xfail("nn.functional.fractional_max_pool2d"),
 | 
			
		||||
    xfail("nn.functional.fractional_max_pool3d"),
 | 
			
		||||
@ -307,7 +331,6 @@ dtensor_fails = {
 | 
			
		||||
    xfail("nn.functional.multi_margin_loss"),
 | 
			
		||||
    xfail("nn.functional.multilabel_margin_loss"),
 | 
			
		||||
    xfail("nn.functional.multilabel_soft_margin_loss"),
 | 
			
		||||
    xfail("nn.functional.multi_head_attention_forward"),
 | 
			
		||||
    xfail("nn.functional.pad", "reflect"),
 | 
			
		||||
    xfail("nn.functional.pad", "replicate"),
 | 
			
		||||
    xfail("nn.functional.pad", "replicate_negative"),
 | 
			
		||||
@ -482,13 +505,21 @@ dtensor_fails = {
 | 
			
		||||
    skip("_segment_reduce", "offsets"),
 | 
			
		||||
    # TODO: fix the following ops
 | 
			
		||||
    skip("squeeze"),
 | 
			
		||||
    # These must be skipped as their contents are nondeterministic
 | 
			
		||||
    skip("empty"),
 | 
			
		||||
    skip("empty_strided"),
 | 
			
		||||
    skip("empty_like"),
 | 
			
		||||
    skip("empty_permuted"),
 | 
			
		||||
    skip("new_empty"),
 | 
			
		||||
    skip("new_empty_strided"),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
dtensor_multi_threaded_fails = {
 | 
			
		||||
    xfail("full_like"),
 | 
			
		||||
    xfail("nn.functional.dropout2d"),
 | 
			
		||||
    xfail("nn.functional.dropout3d"),
 | 
			
		||||
    xfail("masked.cumprod"),
 | 
			
		||||
    skip("nn.functional.multi_head_attention_forward"),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Add a list of ops that are currently failing BW pass
 | 
			
		||||
skip_bw = [
 | 
			
		||||
@ -507,7 +538,13 @@ OP_DB_WORLD_SIZE = 4
 | 
			
		||||
DEVICE_TYPE = "cpu"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
class TestDTensorOps(TestCase):
 | 
			
		||||
    __test__ = False
 | 
			
		||||
 | 
			
		||||
    def __init_subclass__(cls, **kwargs):
 | 
			
		||||
        super().__init_subclass__(**kwargs)
 | 
			
		||||
        cls.__test__ = True
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def world_size(self) -> int:
 | 
			
		||||
        return OP_DB_WORLD_SIZE
 | 
			
		||||
@ -535,14 +572,6 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
 | 
			
		||||
        self.check_dtensor_func(test, op)
 | 
			
		||||
 | 
			
		||||
    # only allow float dtype for now, we can relax this constraint
 | 
			
		||||
    # when we feel it is necessary later (i.e. when adding quantization support).
 | 
			
		||||
    @suppress_warnings
 | 
			
		||||
    @ops(op_db, allowed_dtypes=(torch.float,))
 | 
			
		||||
    @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
 | 
			
		||||
    def test_dtensor_op_db(self, dtype, op):
 | 
			
		||||
        self.run_opinfo_test(dtype, op)
 | 
			
		||||
 | 
			
		||||
    def assert_ref_dtensor_equal(self, dtensor_rs, rs):
 | 
			
		||||
        flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
 | 
			
		||||
        flat_rs = pytree.tree_leaves(rs)
 | 
			
		||||
@ -567,6 +596,9 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
 | 
			
		||||
            self.assertEqualOnRank(dtensor_r, r)
 | 
			
		||||
 | 
			
		||||
    def assertEqualOnRank(self, x, y, msg=None, *, rank=0) -> None:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def run_dtensor_crossref(self, func, args, kwargs):
 | 
			
		||||
        to_dtensor = DTensorConverter(self.mesh, args, kwargs)
 | 
			
		||||
 | 
			
		||||
@ -580,7 +612,8 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
                return res
 | 
			
		||||
 | 
			
		||||
        # TODO: also handle cases where func raise an exception
 | 
			
		||||
        rs = func(*args, **kwargs)
 | 
			
		||||
        op_args, op_kwargs = reconcile_args(args, kwargs)
 | 
			
		||||
        rs = func(*op_args, **op_kwargs)
 | 
			
		||||
        rs = concat_res_if_necessary(func, rs)
 | 
			
		||||
 | 
			
		||||
        def to_replicate(e: object) -> object:
 | 
			
		||||
@ -635,12 +668,12 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
                        self.assert_ref_dtensor_equal(dtensor_rs, rs)
 | 
			
		||||
                    else:
 | 
			
		||||
                        raise RuntimeError(
 | 
			
		||||
                            f"failed to convert args to DTensor; "
 | 
			
		||||
                            f"Failed to convert args to DTensor; "
 | 
			
		||||
                            f"originally (*{args}, **{kwargs})"
 | 
			
		||||
                        )
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
 | 
			
		||||
                        f"{str(e)}\n\nFailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
 | 
			
		||||
                    ) from e
 | 
			
		||||
        return rs
 | 
			
		||||
 | 
			
		||||
@ -656,7 +689,7 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
                else:
 | 
			
		||||
                    print(f"xfail('{opinfo.name}'),")
 | 
			
		||||
 | 
			
		||||
    def test_one_hot(self):
 | 
			
		||||
    def run_one_hot(self):
 | 
			
		||||
        ops = [op for op in op_db if op.name == "nn.functional.one_hot"]
 | 
			
		||||
        assert len(ops) == 1
 | 
			
		||||
        op = ops[0]
 | 
			
		||||
@ -668,7 +701,7 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
            sample_inputs_filter=lambda s: s.kwargs["num_classes"] != -1,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_mean(self):
 | 
			
		||||
    def run_mean(self):
 | 
			
		||||
        self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
 | 
			
		||||
 | 
			
		||||
        shape = [2 * self.world_size + 1, 2 * self.world_size]
 | 
			
		||||
@ -692,6 +725,7 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
                full_tensor = mean.full_tensor()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))
 | 
			
		||||
 | 
			
		||||
            if is_evenly_shardable:
 | 
			
		||||
                self.assertTrue("P->R" in debug_mode.debug_string())
 | 
			
		||||
            else:
 | 
			
		||||
@ -720,9 +754,76 @@ class TestDTensorOps(DTensorOpTestBase):
 | 
			
		||||
            _ = torch.ops.aten.embedding.default(weight_dtensor, input_dtensor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
 | 
			
		||||
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))
 | 
			
		||||
class TestMultiThreadedDTensorOps(DTensorOpTestBase, TestDTensorOps):
    # Copy of the op database whose "TestDTensorOps" decorators are renamed to this class.
    _op_db = repurpose_ops(op_db, "TestDTensorOps", "TestMultiThreadedDTensorOps")

    # Skip/xfail rules: the shared list plus the entries specific to this flavor.
    @suppress_warnings
    @ops(_op_db, allowed_dtypes=(torch.float,))
    @skipOps(
        _op_db,
        "TestMultiThreadedDTensorOps",
        "test_dtensor_op_db",
        dtensor_fails.union(dtensor_multi_threaded_fails),
    )
    def test_dtensor_op_db(self, dtype, op):
        self.run_opinfo_test(dtype, op)

    # The bodies of the two tests below live in the shared base as run_* helpers.
    def test_mean(self):
        self.run_mean()

    def test_one_hot(self):
        self.run_one_hot()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLocalDTensorOps(TestDTensorOps):
    # Local-tensor flavor of the shared DTensor op tests: a "fake" process group
    # is created per test and the test bodies run under LocalTensorMode.

    # Copy of the op database whose "TestDTensorOps" decorators are renamed to this class.
    _op_db = repurpose_ops(op_db, "TestDTensorOps", "TestLocalDTensorOps")

    def _local_tensor_mode(self):
        # Single place that builds the mode over ranks 0 .. world_size - 1.
        return LocalTensorMode(frozenset(range(self.world_size)))

    def setUp(self) -> None:
        super().setUp()
        torch.distributed.init_process_group("fake", rank=0, world_size=self.world_size)
        self.fake_pg = torch.distributed.distributed_c10d._get_default_group()

    def tearDown(self):
        # Always attempt to destroy the process group, even when the parent
        # tearDown raises; otherwise the default group created in setUp would
        # outlive this test.
        try:
            super().tearDown()
        finally:
            try:
                dist.destroy_process_group()
            except AssertionError:
                # Best-effort cleanup: an AssertionError from destroy is deliberately ignored.
                pass

    @suppress_warnings
    @ops(_op_db, allowed_dtypes=(torch.float,))
    @skipOps(
        _op_db,
        "TestLocalDTensorOps",
        "test_dtensor_op_db",
        dtensor_fails,
    )
    def test_dtensor_op_db(self, dtype, op):
        self.run_opinfo_test(dtype, op)

    def test_mean(self):
        with self._local_tensor_mode():
            self.run_mean()

    def test_one_hot(self):
        # No explicit mode here; this calls the shared run_one_hot helper.
        self.run_one_hot()

    def run_opinfo_test(
        self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True
    ):
        # Same as the parent implementation, but executed inside the local tensor mode.
        with self._local_tensor_mode():
            super().run_opinfo_test(dtype, op, requires_grad, sample_inputs_filter)

    def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
        # ``rank`` is accepted for interface compatibility and not used here.
        self.assertEqual(x, y, msg)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
 | 
			
		||||
instantiate_device_type_tests(
 | 
			
		||||
    TestMultiThreadedDTensorOps, globals(), only_for=(DEVICE_TYPE,)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
instantiate_device_type_tests(TestLocalDTensorOps, globals(), only_for=(DEVICE_TYPE,))
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    run_tests()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										415
									
								
								test/distributed/test_local_tensor.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										415
									
								
								test/distributed/test_local_tensor.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,415 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates
 | 
			
		||||
# Owner(s): ["oncall: distributed"]
 | 
			
		||||
 | 
			
		||||
from contextlib import nullcontext
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
from torch.distributed._local_tensor import (
 | 
			
		||||
    local_tensor_mode,
 | 
			
		||||
    LocalTensor,
 | 
			
		||||
    LocalTensorMode,
 | 
			
		||||
)
 | 
			
		||||
from torch.distributed.tensor import (
 | 
			
		||||
    DeviceMesh,
 | 
			
		||||
    distribute_tensor,
 | 
			
		||||
    init_device_mesh,
 | 
			
		||||
    Partial,
 | 
			
		||||
    Replicate,
 | 
			
		||||
    Shard,
 | 
			
		||||
)
 | 
			
		||||
from torch.testing._internal.common_utils import run_tests, TestCase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LocalTensorTestBase(TestCase):
    """Base class for LocalTensor tests; each test runs with a "fake" process group."""

    @staticmethod
    def _rank_msg(rank, msg):
        """Return a message callback that prefixes failures with ``rank {rank}: ``.

        ``msg`` may be ``None`` (use the generated message), a string, or a
        callable taking the generated message and returning a string.
        """

        def describe(default_msg):
            if msg is None:
                detail = default_msg
            elif callable(msg):
                detail = msg(default_msg)
            else:
                detail = msg
            return f"rank {rank}: {detail}"

        return describe

    def assertEqual(self, lhs, rhs, msg=None, **kwargs):
        """Compare two values, unpacking LocalTensors rank by rank.

        * Both LocalTensor: their rank sets must match, then the per-rank
          tensors are compared pairwise.
        * Exactly one LocalTensor: each of its per-rank tensors is compared
          against the other value.
        * Otherwise: defer to the parent ``assertEqual``.

        ``msg`` and any extra keyword arguments are forwarded to the parent
        comparison in every branch. The comparison runs with the active
        LocalTensorMode (if any) disabled.
        """
        mode = local_tensor_mode()
        with nullcontext() if mode is None else mode.disable():
            lhs_is_local = isinstance(lhs, LocalTensor)
            rhs_is_local = isinstance(rhs, LocalTensor)

            if not (lhs_is_local or rhs_is_local):
                return super().assertEqual(lhs, rhs, msg, **kwargs)

            if lhs_is_local and rhs_is_local:
                super().assertEqual(lhs._ranks, rhs._ranks)
                for r in lhs._ranks:
                    super().assertEqual(
                        lhs._local_tensors[r],
                        rhs._local_tensors[r],
                        self._rank_msg(r, msg),
                        **kwargs,
                    )
            else:
                # Keep the LocalTensor as the first operand, as before.
                local, other = (lhs, rhs) if lhs_is_local else (rhs, lhs)
                for r in local._ranks:
                    super().assertEqual(
                        local._local_tensors[r],
                        other,
                        self._rank_msg(r, msg),
                        **kwargs,
                    )

    @property
    def world_size(self):
        raise NotImplementedError("override world-size in your subclass")

    def build_device_mesh(self) -> DeviceMesh:
        """Create a one-dimensional CPU device mesh of ``world_size`` entries."""
        return init_device_mesh("cpu", (self.world_size,))

    def setUp(self):
        super().setUp()
        torch.distributed.init_process_group(
            # TODO: test other ranks too
            "fake",
            rank=0,
            world_size=self.world_size,
        )

    def tearDown(self):
        # Always attempt to destroy the process group, even when the parent
        # tearDown raises; otherwise the default group created in setUp would
        # outlive this test.
        try:
            super().tearDown()
        finally:
            try:
                dist.destroy_process_group()
            except AssertionError:
                # Best-effort cleanup: an AssertionError from destroy is deliberately ignored.
                pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLocalTensorWorld2(LocalTensorTestBase):
 | 
			
		||||
    world_size = 2
 | 
			
		||||
 | 
			
		||||
    def test_local_tensor_dtype_consistency(self):
        """Test that LocalTensor enforces dtype consistency."""
        cpu = torch.device("cpu")
        dims = (2, 3)

        # Rank 0 holds float32 data while rank 1 holds float64 data.
        rank0_tensor = torch.randn(dims, dtype=torch.float32, device=cpu)
        rank1_tensor = torch.randn(dims, dtype=torch.float64, device=cpu)

        # Building a LocalTensor from mismatched dtypes must trip an assertion.
        self.assertRaises(
            AssertionError, LocalTensor, {0: rank0_tensor, 1: rank1_tensor}
        )
 | 
			
		||||
 | 
			
		||||
    def test_local_tensor_creation_fails_with_grad_tensors(self):
        """Test that LocalTensor creation fails when local tensors have requires_grad=True."""
        cpu = torch.device("cpu")

        # One float32 tensor of shape (2, 3) per rank, each requiring grad.
        grad_shards = {
            rank: torch.randn(
                (2, 3), dtype=torch.float32, device=cpu, requires_grad=True
            )
            for rank in (0, 1)
        }

        self.assertRaises(AssertionError, LocalTensor, grad_shards)

        # TODO: test flatten/unflatten
 | 
			
		||||
 | 
			
		||||
    def test_basic_arithmetic_operations(self):
        """Test basic arithmetic operations on LocalTensors."""
        seed_tensor = torch.randn(
            (2, 3), dtype=torch.float32, device=torch.device("cpu")
        )
        # Every rank gets its own copy of the same values.
        shards = {rank: seed_tensor.clone() for rank in (0, 1)}

        lt1 = LocalTensor(shards)
        lt2 = LocalTensor(shards)

        # Addition: the result stays a LocalTensor with one entry per rank,
        # and each entry equals the per-rank sum.
        summed = lt1 + lt2
        self.assertIsInstance(summed, LocalTensor)
        self.assertEqual(len(summed._local_tensors), 2)
        for rank, shard in shards.items():
            self.assertEqual(summed._local_tensors[rank], shard + shard)

        # Multiplication by a Python scalar is applied per rank as well.
        scaled = lt1 * 2.0
        self.assertIsInstance(scaled, LocalTensor)
        for rank, shard in shards.items():
            self.assertEqual(scaled._local_tensors[rank], shard * 2.0)
 | 
			
		||||
 | 
			
		||||
    # TODO: consider an op-info test; we don't actually need to cover all ops
 | 
			
		||||
    # but it will help make sure views and more exotic things are done
 | 
			
		||||
    # correctly (in standard subclass style)
 | 
			
		||||
 | 
			
		||||
    def test_mixed_operations_with_regular_tensors(self):
        """Test operations between LocalTensors and regular tensors."""
        seed_tensor = torch.randn(
            (2, 3), dtype=torch.float32, device=torch.device("cpu")
        )
        # Every rank gets its own copy of the same values.
        shards = {rank: seed_tensor.clone() for rank in (0, 1)}

        lt = LocalTensor(shards)
        plain_ones = torch.ones_like(shards[0])

        # LocalTensor + regular tensor yields a LocalTensor whose per-rank
        # entries equal the per-rank sums.
        combined = lt + plain_ones
        self.assertIsInstance(combined, LocalTensor)

        for rank, shard in shards.items():
            self.assertEqual(combined._local_tensors[rank], shard + plain_ones)
 | 
			
		||||
 | 
			
		||||
    def test_local_tensor_mode(self):
        """Test LocalTensorMode functionality."""
        seed_tensor = torch.randn(
            (2, 3), dtype=torch.float32, device=torch.device("cpu")
        )
        # Every rank gets its own copy of the same values.
        lt = LocalTensor({rank: seed_tensor.clone() for rank in (0, 1)})

        with LocalTensorMode(lt._ranks):
            shifted = lt + 1.0
            self.assertIsInstance(shifted, LocalTensor)

            # Inside the mode, a tensor from a plain factory call and the result
            # of arithmetic on it are both expected to be LocalTensors.
            regular = torch.ones(2, 2)
            regular_result = regular + 1.0
            for produced in (regular, regular_result):
                self.assertIsInstance(produced, LocalTensor)
 | 
			
		||||
 | 
			
		||||
    def test_empty_local_tensors(self):
        """Constructing a LocalTensor from an empty dict raises StopIteration."""
        # TODO: raise a better error here
        # The exception comes from next() on an empty iterator.
        self.assertRaises(StopIteration, LocalTensor, {})
 | 
			
		||||
 | 
			
		||||
    def test_collectives_within_local_tensor_mode(self):
        """Run all_reduce, broadcast and plain arithmetic under LocalTensorMode."""
        test_tensors = {
            0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
            1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
        }

        def fresh_local_tensor():
            # Collectives write in place, so each one gets its own clones.
            return LocalTensor({k: v.clone() for k, v in test_tensors.items()})

        lt = LocalTensor(test_tensors)
        fake_pg = torch.distributed.distributed_c10d._get_default_group()

        with LocalTensorMode(lt._ranks):
            # all_reduce (default op): every rank ends up with the element-wise sum.
            lt_sum = fresh_local_tensor()
            dist.all_reduce(lt_sum, group=fake_pg)

            expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]])
            for rank in test_tensors:
                self.assertEqual(lt_sum._local_tensors[rank], expected_sum)

            # broadcast from rank 0: every rank ends up with rank 0's values.
            lt_broadcast = fresh_local_tensor()
            dist.broadcast(lt_broadcast, src=0, group=fake_pg)

            rank0_values = test_tensors[0]
            for rank in test_tensors:
                self.assertEqual(lt_broadcast._local_tensors[rank], rank0_values)

            # Regular operations still work inside the mode.
            shifted = lt + 1.0
            self.assertIsInstance(shifted, LocalTensor)
 | 
			
		||||
 | 
			
		||||
    def test_scalar_mul_reduction_bug(self):
        """Smoke test: scale the sum of a row-sharded DTensor by Python scalars.

        Nothing is asserted; the test only checks that these expressions
        evaluate (and print) without raising under LocalTensorMode.
        """
        with LocalTensorMode(self.world_size):
            mesh = self.build_device_mesh()

            two_elems = torch.tensor([10, 10]).float()
            dt = distribute_tensor(two_elems, device_mesh=mesh, placements=[Shard(0)])
            _ = dt.sum() * 1

            column = torch.arange(10).reshape(10, 1).float().requires_grad_()
            dt = distribute_tensor(column, device_mesh=mesh, placements=[Shard(0)])

            # Evaluated left to right, exactly one sum() per factor.
            scaled = [dt.sum() * factor for factor in (1, 2, 3)]
            print(*scaled)
 | 
			
		||||
 | 
			
		||||
    def test_uneven_sharding_mean_bug(self):
        """Mean of a row-sharded (3, 4) DTensor is replicated and matches eager."""
        with LocalTensorMode(self.world_size):
            mesh = self.build_device_mesh()
            reference = torch.arange(12).reshape(-1, 4).float()

            sharded = distribute_tensor(
                reference, device_mesh=mesh, placements=[Shard(0)]
            )

            mean = sharded.mean()
            self.assertEqual(mean.placements, [Replicate()])
            self.assertEqual(reference.mean(), mean.full_tensor())
 | 
			
		||||
 | 
			
		||||
    def test_uneven_sharding_prod(self):
        """Product of a row-sharded (3, 4) DTensor matches the eager product."""
        with LocalTensorMode(self.world_size):
            mesh = self.build_device_mesh()
            # Values 1..12, so no zero factor hides a wrong result.
            reference = (torch.arange(12) + 1).reshape(-1, 4).float()

            sharded = distribute_tensor(
                reference, device_mesh=mesh, placements=[Shard(0)]
            )

            product = sharded.prod()
            self.assertEqual(reference.prod(), product.full_tensor())
 | 
			
		||||
 | 
			
		||||
    def test_even_sharding_mean_is_partial(self):
        """Mean of a row-sharded (4, 4) DTensor matches eager and is Partial("avg")."""
        with LocalTensorMode(self.world_size):
            mesh = self.build_device_mesh()
            reference = torch.arange(16).reshape(4, 4).float()

            sharded = distribute_tensor(
                reference, device_mesh=mesh, placements=[Shard(0)]
            )

            mean = sharded.mean()
            # The value is checked first, then the placement of the un-reduced result.
            gathered = mean.full_tensor()
            self.assertEqual(reference.mean(), gathered)
            self.assertEqual(mean.placements, [Partial("avg")])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLocalTensorWorld3(LocalTensorTestBase):
    """Collective tests on LocalTensors that hold shards for three ranks."""

    world_size = 3

    @staticmethod
    def _distinct_rank_tensors():
        """Return a fresh {rank: (2, 3) tensor} dict with different values per rank."""
        return {
            0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
            2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
        }

    @staticmethod
    def _cloned_local_tensor(per_rank):
        """Wrap clones of ``per_rank`` so in-place collectives leave the inputs intact."""
        return LocalTensor({k: v.clone() for k, v in per_rank.items()})

    def test_collective_reduction_operations(self):
        """Test different reduction operations for all_reduce."""
        # Simple per-rank values so the expected results can be read off by hand.
        test_tensors = {
            0: torch.tensor([[1.0, 4.0], [2.0, 5.0]]),
            1: torch.tensor([[2.0, 1.0], [3.0, 6.0]]),
            2: torch.tensor([[3.0, 2.0], [1.0, 4.0]]),
        }

        fake_pg = torch.distributed.distributed_c10d._get_default_group()

        reductions = (
            # Sum of all tensors
            (dist.ReduceOp.SUM, torch.tensor([[6.0, 7.0], [6.0, 15.0]])),
            # Max across all tensors
            (dist.ReduceOp.MAX, torch.tensor([[3.0, 4.0], [3.0, 6.0]])),
            # Min across all tensors
            (dist.ReduceOp.MIN, torch.tensor([[1.0, 1.0], [1.0, 4.0]])),
        )

        for reduce_op, expected in reductions:
            reduced = self._cloned_local_tensor(test_tensors)
            dist.all_reduce(reduced, op=reduce_op, group=fake_pg)
            for rank in test_tensors:
                self.assertEqual(reduced._local_tensors[rank], expected)

    def test_all_reduce_collective(self):
        """Test that all_reduce collective operation works correctly with LocalTensor."""
        different_tensors = self._distinct_rank_tensors()
        fake_pg = torch.distributed.distributed_c10d._get_default_group()

        # all_reduce with SUM (default), applied to the result of a regular op.
        lt_sum = self._cloned_local_tensor(different_tensors) + 1
        dist.all_reduce(lt_sum, group=fake_pg)

        # All ranks hold the sum of all tensors (after adding 1 to each).
        expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]])
        for rank in different_tensors:
            self.assertEqual(lt_sum._local_tensors[rank], expected_sum)

    def test_broadcast_collective(self):
        """Test that broadcast collective operation works correctly with LocalTensor."""
        different_tensors = self._distinct_rank_tensors()
        fake_pg = torch.distributed.distributed_c10d._get_default_group()

        # Broadcast from rank 1.
        lt_broadcast = self._cloned_local_tensor(different_tensors)
        dist.broadcast(lt_broadcast, src=1, group=fake_pg)

        # All ranks hold rank 1's original tensor.
        expected_broadcast = different_tensors[1]
        for rank in different_tensors:
            self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast)

    def test_all_gather_collective(self):
        """Test that all_gather collective operation works correctly with LocalTensor."""
        different_tensors = self._distinct_rank_tensors()
        fake_pg = torch.distributed.distributed_c10d._get_default_group()

        lt_gather = LocalTensor(different_tensors)
        tensor_list = [torch.zeros_like(lt_gather) for _ in range(3)]

        dist.all_gather(tensor_list, lt_gather, group=fake_pg)

        # Position i of tensor_list holds rank i's tensor.
        for rank, expected in different_tensors.items():
            self.assertEqual(tensor_list[rank], expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLocalTensorWorld4(LocalTensorTestBase):
    """DTensor tests that run under LocalTensorMode with four simulated ranks."""

    world_size = 4

    def test_dtensor_cat(self):
        """Concatenating a replicated and a row-sharded DTensor matches eager cat."""
        with LocalTensorMode(self.world_size):
            device_mesh = self.build_device_mesh()

            left = torch.arange(16).view(4, 4).float()
            left_dt = distribute_tensor(left, device_mesh, [Replicate()])
            right = (torch.arange(16) + 16).view(4, 4).float()
            right_dt = distribute_tensor(right, device_mesh, [Shard(0)])

            # The eager reference is computed first, then the distributed result.
            expected = torch.cat([left, right], dim=-1)
            dist_res = torch.cat([left_dt, right_dt], dim=-1)
            self.assertEqual(dist_res.full_tensor(), expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestLocalTensorWorld8(LocalTensorTestBase):
    """DTensor tests that run under LocalTensorMode with eight simulated ranks."""

    world_size = 8

    def test_dtensor_addmm(self):
        """addmm on a row-sharded mat1 with replicated mat2 and bias matches eager."""
        with LocalTensorMode(self.world_size):
            device_mesh = self.build_device_mesh()

            row_sharded = [Shard(0)]
            replicated = [Replicate()]

            # The random inputs are drawn in the same order as they are distributed.
            tensor_to_shard = torch.randn(12, 8)
            mat1 = distribute_tensor(tensor_to_shard, device_mesh, row_sharded)
            tensor_to_replicate = torch.randn(8, 4)
            mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replicated)
            bias_tensor = torch.randn(4)
            bias = distribute_tensor(bias_tensor, device_mesh, replicated)

            dist_res = torch.addmm(bias, mat1, mat2)
            expected = torch.addmm(bias_tensor, tensor_to_shard, tensor_to_replicate)
            self.assertEqual(dist_res.full_tensor(), expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
 | 
			
		||||
@ -7915,9 +7915,13 @@ torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        nt = torch.nested.nested_tensor(
 | 
			
		||||
            [
 | 
			
		||||
                torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
 | 
			
		||||
                if dtype is torch.bool
 | 
			
		||||
                else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
 | 
			
		||||
                (
 | 
			
		||||
                    torch.randint(
 | 
			
		||||
                        2, (n, *post_seq_len_shape), device=device, dtype=dtype
 | 
			
		||||
                    )
 | 
			
		||||
                    if dtype is torch.bool
 | 
			
		||||
                    else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
 | 
			
		||||
                )
 | 
			
		||||
                for n in range(2, 9)
 | 
			
		||||
            ],
 | 
			
		||||
            layout=torch.jagged,
 | 
			
		||||
@ -7966,9 +7970,13 @@ torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        nt = torch.nested.nested_tensor(
 | 
			
		||||
            [
 | 
			
		||||
                torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
 | 
			
		||||
                if dtype is torch.bool
 | 
			
		||||
                else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
 | 
			
		||||
                (
 | 
			
		||||
                    torch.randint(
 | 
			
		||||
                        2, (n, *post_seq_len_shape), device=device, dtype=dtype
 | 
			
		||||
                    )
 | 
			
		||||
                    if dtype is torch.bool
 | 
			
		||||
                    else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
 | 
			
		||||
                )
 | 
			
		||||
                for n in range(2, 9)
 | 
			
		||||
            ],
 | 
			
		||||
            layout=torch.jagged,
 | 
			
		||||
@ -8713,7 +8721,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
 | 
			
		||||
    # min() / max(): weird bug
 | 
			
		||||
    XFailRule(
 | 
			
		||||
        error_type=AttributeError,
 | 
			
		||||
        error_msg="'ConstantIntNode' object has no attribute 'add'",
 | 
			
		||||
        error_msg="'NestedIntNode' object has no attribute 'add'",
 | 
			
		||||
        op_match_fn=lambda device, op: (
 | 
			
		||||
            op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
 | 
			
		||||
        ),
 | 
			
		||||
@ -8730,7 +8738,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
 | 
			
		||||
    # copysign(): formula is broken for (T, NT) broadcasting
 | 
			
		||||
    XFailRule(
 | 
			
		||||
        error_type=AttributeError,
 | 
			
		||||
        error_msg="'ConstantIntNode' object has no attribute 'add'",
 | 
			
		||||
        error_msg="'NestedIntNode' object has no attribute 'add'",
 | 
			
		||||
        op_match_fn=lambda device, op: (op.full_name == "copysign"),
 | 
			
		||||
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
 | 
			
		||||
        name="broken_copysign_compile_backward",
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,11 @@ TORCH_LIBRARY(c10d, m) {
 | 
			
		||||
  m.class_<Work>("Work")
 | 
			
		||||
      .def(torch::init<>())
 | 
			
		||||
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
 | 
			
		||||
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
 | 
			
		||||
  m.class_<ReduceOp>("ReduceOp")
 | 
			
		||||
      .def(torch::init<>())
 | 
			
		||||
      .def("op", [](const c10::intrusive_ptr<ReduceOp>& self) -> int64_t {
 | 
			
		||||
        return self->op_;
 | 
			
		||||
      });
 | 
			
		||||
  m.def(
 | 
			
		||||
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
 | 
			
		||||
  m.def(
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										747
									
								
								torch/distributed/_local_tensor/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										747
									
								
								torch/distributed/_local_tensor/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,747 @@
 | 
			
		||||
from ast import Call
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
A LocalTensor is a tensor subclass which simulates a tensor that is
 | 
			
		||||
distributed across SPMD ranks.  A LocalTensor might be size N, but in fact
 | 
			
		||||
there are world_size shards/replicas of it stored internally.  When you do a
 | 
			
		||||
plain PyTorch operation on it, we apply the operation to each shard; when you
 | 
			
		||||
do a collective, we do the mathematically equivalent operation on the local
 | 
			
		||||
shards.  A LocalTensor is associated with a list of ranks which specify
 | 
			
		||||
which ranks it holds local tensors for.
 | 
			
		||||
 | 
			
		||||
NB, this is NOT a DataParallel like abstraction where you can run operations
 | 
			
		||||
on multiple different GPUs. It is intended purely for *debugging* purposes,
 | 
			
		||||
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
 | 
			
		||||
autograd needs multithreading to keep up!)  (It might potentially be possible
 | 
			
		||||
to trace through this with torch.compile and then compile it with CUDA graphs
 | 
			
		||||
but this is currently a non-goal.)
 | 
			
		||||
 | 
			
		||||
We do not directly handle MPMD. However in practice even in SPMD you may
 | 
			
		||||
encounter divergence in behavior per rank (for example, uneven sharding
 | 
			
		||||
across ranks). To support scenarios like this, we provide a helper decorator
 | 
			
		||||
that allows you to run a function with no side effects for each LocalTensor
 | 
			
		||||
shard and combine results back into LocalTensor or LocalIntNode.
 | 
			
		||||
 | 
			
		||||
NB: This is a torch dispatch Tensor subclass, as we want to assume that autograd
 | 
			
		||||
is SPMD, so we run it once, and dispatch the inner autograd calls to the individual
 | 
			
		||||
local shards.
 | 
			
		||||
 | 
			
		||||
NOTE ABOUT MESH:  This subclass requires collectives that are issued to it to
 | 
			
		||||
respect a DeviceMesh like abstraction.  The reason for this is that when
 | 
			
		||||
DTensor issues us a collective for a particular rank, you will be asked to do
 | 
			
		||||
this on a specific process group which involves some ranks.  However, this
 | 
			
		||||
will only be for the LOCAL PG that this particular rank is participating in;
 | 
			
		||||
there will be a bunch of other PGs for other nodes that you don't get to see.
 | 
			
		||||
We need to be able to reverse engineer all of the collectives that don't
 | 
			
		||||
involve the current local rank here to actually issue them.  This can be done
 | 
			
		||||
two ways: (1) looking at the participating local ranks in the PG and computing
 | 
			
		||||
the complement which specifies all the other collectives you have to run, or
 | 
			
		||||
(2) retrieving the device mesh axis corresponding to the PG for this rank, and
 | 
			
		||||
then running all the fibers for this.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import contextlib
 | 
			
		||||
import functools
 | 
			
		||||
import operator
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from collections.abc import Sequence
 | 
			
		||||
from types import TracebackType
 | 
			
		||||
from typing import Any, Callable, Generator, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch import Size, SymBool, SymInt, Tensor
 | 
			
		||||
from torch._C import DispatchKey, DispatchKeySet
 | 
			
		||||
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
 | 
			
		||||
from torch.distributed import DeviceMesh
 | 
			
		||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
 | 
			
		||||
from torch.fx.experimental._constant_symnode import ConstantIntNode
 | 
			
		||||
from torch.nested._internal.nested_int import NestedIntNode
 | 
			
		||||
from torch.utils import _pytree as pytree
 | 
			
		||||
from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
 | 
			
		||||
from torch.utils.checkpoint import get_device_states, set_device_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from . import _c10d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _int_on_rank(i: "LocalIntNode | ConstantIntNode", r: int) -> int:
    """Return the plain int that node ``i`` stands for on rank ``r``.

    A ``LocalIntNode`` is looked up by rank; a ``ConstantIntNode`` yields its
    single value whatever the rank.

    Raises:
        AssertionError: if ``i`` is neither of the two node types.
    """
    if isinstance(i, LocalIntNode):
        return i._local_ints[r]
    if isinstance(i, ConstantIntNode):
        return i.val
    raise AssertionError(type(i))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _check_for_subclass(flat_args: Sequence[object]) -> bool:
 | 
			
		||||
    return any(_check_for_subclass_arg(x) for x in flat_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _check_for_subclass_arg(x: object) -> bool:
    """Return True if ``x`` is a Tensor subclass instance other than the known ones.

    LocalTensor instances and non-Tensors give False; so do objects whose
    exact type is ``Tensor``, ``torch.nn.Parameter`` or ``torch.nn.Buffer``.
    """
    if isinstance(x, LocalTensor):
        return False
    if not isinstance(x, Tensor):
        return False
    # Exact type comparison: subclasses of Parameter/Buffer still count.
    return type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _map_to_rank_local_val(val: Any, rank: int) -> Any:
    """Project ``val`` onto a single rank.

    A LocalTensor becomes its local tensor for ``rank``; a SymInt backed by a
    LocalIntNode becomes that rank's int. Any other value is returned unchanged.
    """
    if isinstance(val, LocalTensor):
        return val._local_tensors[rank]

    # Only SymInts carry a node that may hold per-rank ints.
    node = val.node if isinstance(val, SymInt) else None
    if isinstance(node, LocalIntNode):
        return node._local_ints[rank]

    return val
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _for_each_rank_run_func(
    func: Callable[..., Any],
    ranks: frozenset[int],
    args: Sequence[Any],
    kwargs: dict[str, Any],
    *,
    alias: bool = True,
) -> Any:
    """Run ``func`` once per rank on rank-local arguments and merge the results.

    For each rank (in ascending order), every leaf of ``(args, kwargs)`` is
    mapped with ``_map_to_rank_local_val`` and ``func`` is called on the result.

    The per-rank results are merged as follows:
      * Tensor results are wrapped in one ``LocalTensor``.
      * list/tuple results are merged position by position: Tensors become a
        ``LocalTensor``, ints that differ across ranks become a ``SymInt``
        backed by ``LocalIntNode``, and anything else must be equal on all ranks.
      * Any other result is returned as-is when equal on all ranks; differing
        ints become a ``SymInt`` backed by ``LocalIntNode``.

    Args:
        func: Callable invoked once per rank.
        ranks: Ranks to run ``func`` for.
        args: Positional arguments; may contain LocalTensor / LocalIntNode SymInts.
        kwargs: Keyword arguments; same treatment as ``args``.
        alias: If True, Tensor and list/tuple results are passed through
            ``return_and_correct_aliasing`` before being returned. Results of
            the "other" kind are returned directly and never take that path.

    Raises:
        AssertionError: if a non-Tensor, non-int result differs across ranks.
        StopIteration: if ``ranks`` is empty (from ``next`` on an empty iterator).
    """
    flat_args, args_spec = pytree.tree_flatten((args, kwargs))

    # Snapshot the CPU RNG state and the device states for the arguments.
    cpu_state = torch.get_rng_state()
    devices, states = get_device_states((args, kwargs))

    flat_rank_rets = {}

    for r in sorted(ranks):
        # Restore the snapshot so every rank's call starts from the same state.
        torch.set_rng_state(cpu_state)
        set_device_states(devices, states)
        rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
        rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
        rank_ret = func(*rank_args, **rank_kwargs)
        flat_rank_rets[r] = rank_ret

    # The first rank's result decides how all results are merged.
    rr_key = next(iter(flat_rank_rets.keys()))
    rr_val = flat_rank_rets[rr_key]

    if isinstance(rr_val, Tensor):
        ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
    elif isinstance(rr_val, (list, tuple)):
        ret_list = []
        for i in range(len(rr_val)):
            rets = {r: flat_rank_rets[r][i] for r in sorted(ranks)}
            v_it = iter(rets.values())
            v = next(v_it)
            if isinstance(v, Tensor):
                ret_list.append(LocalTensor(rets))
            elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
                ret_list.append(torch.SymInt(LocalIntNode(rets)))
            else:
                # v_it is shared with the elif test above: if that all() already
                # consumed it (equal ints), this assert sees an empty iterator
                # and passes trivially.
                assert all(v == v2 for v2 in v_it)
                ret_list.append(v)
        # Rebuild the same container type (list or tuple) as func returned.
        ret = type(rr_val)(ret_list)
    else:
        v_it = iter(flat_rank_rets.values())
        v = next(v_it)
        if all(v == v2 for v2 in v_it):
            return v
        if isinstance(v, int):
            return torch.SymInt(LocalIntNode(flat_rank_rets))
        raise AssertionError(f"Unexpected return type {type(v)}")

    if alias:
        return return_and_correct_aliasing(func, args, kwargs, ret)
    else:
        return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_extra_dispatch_keys(t: torch.Tensor) -> DispatchKeySet:
 | 
			
		||||
    extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0)
 | 
			
		||||
    if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Conjugate):
 | 
			
		||||
        extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Conjugate)
 | 
			
		||||
    if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Negative):
 | 
			
		||||
        extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative)
 | 
			
		||||
    return extra_dispatch_keys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LocalIntNode:
 | 
			
		||||
    """
 | 
			
		||||
    Like a LocalTensor, but for an int.  We can't use a 0D tensor to represent this
 | 
			
		||||
    because often only a SymInt is accepted where we wish to use this.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __new__(cls, local_ints: dict[int, int]) -> "ConstantIntNode | LocalIntNode":  # type: ignore[misc]
 | 
			
		||||
        if len(set(local_ints.values())) == 1:
 | 
			
		||||
            return ConstantIntNode(next(iter(local_ints.values())))
 | 
			
		||||
        return super().__new__(cls)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, local_ints: dict[int, int]):
 | 
			
		||||
        self._local_ints = local_ints
 | 
			
		||||
 | 
			
		||||
    def maybe_as_int(self) -> Optional[int]:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def is_int(self) -> bool:
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def is_float(self) -> bool:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def is_bool(self) -> bool:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def is_nested_int(self) -> bool:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def clone(self) -> "LocalIntNode":
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def _str(self) -> str:
 | 
			
		||||
        return f"LocalIntNode({self._local_ints})"
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return self._str()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return self._str()
 | 
			
		||||
 | 
			
		||||
    def _graph_repr(self) -> str:
 | 
			
		||||
        return self._str()
 | 
			
		||||
 | 
			
		||||
    def is_symbolic(self) -> bool:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def is_constant(self) -> bool:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def sym_max(
 | 
			
		||||
        self, other: "LocalIntNode | ConstantIntNode"
 | 
			
		||||
    ) -> "LocalIntNode | ConstantIntNode":
 | 
			
		||||
        return LocalIntNode(
 | 
			
		||||
            {
 | 
			
		||||
                r: max(self._local_ints[r], _int_on_rank(other, r))
 | 
			
		||||
                for r in self._local_ints
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def add(
 | 
			
		||||
        self, other: "LocalIntNode | ConstantIntNode"
 | 
			
		||||
    ) -> "LocalIntNode | ConstantIntNode":
 | 
			
		||||
        return LocalIntNode(
 | 
			
		||||
            {r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def sub(
 | 
			
		||||
        self, other: "LocalIntNode | ConstantIntNode"
 | 
			
		||||
    ) -> "LocalIntNode | ConstantIntNode":
 | 
			
		||||
        return LocalIntNode(
 | 
			
		||||
            {r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def mul(
 | 
			
		||||
        self, other: "LocalIntNode | ConstantIntNode"
 | 
			
		||||
    ) -> "LocalIntNode | ConstantIntNode":
 | 
			
		||||
        return LocalIntNode(
 | 
			
		||||
            {r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def eq(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
 | 
			
		||||
        r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints}
 | 
			
		||||
        return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r)))
 | 
			
		||||
 | 
			
		||||
    def gt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
 | 
			
		||||
        r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
 | 
			
		||||
        assert len(r) == 1, (self, other)
 | 
			
		||||
        return torch._C._get_constant_bool_symnode(next(iter(r)))
 | 
			
		||||
 | 
			
		||||
    def lt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
 | 
			
		||||
        r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints}
 | 
			
		||||
        assert len(r) == 1, (self, other)
 | 
			
		||||
        return torch._C._get_constant_bool_symnode(next(iter(r)))
 | 
			
		||||
 | 
			
		||||
    def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode":
 | 
			
		||||
        return ConstantIntNode(num)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LocalTensor(torch.Tensor):
    """
    LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
    (Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from
    global rank ids to their corresponding local Tensor shards. Operations performed on a LocalTensor
    are applied independently to each local shard, mimicking distributed computation. Collectives
    and other distributed operations are handled by mapping them to the local shards as appropriate.

    Note:
        This class is primarily intended for debugging and simulating distributed tensor computations
        on a single process.

    """

    # Map from global rank to the local tensor.
    _local_tensors: dict[int, torch.Tensor]
    # Precomputed for speed set of keys from the local tensor map.
    _ranks: frozenset[int]
    __slots__ = ["_local_tensors", "_ranks"]

    @staticmethod
    @torch._disable_dynamo
    def __new__(
        cls,
        local_tensors: dict[int, torch.Tensor],
    ) -> "LocalTensor":
        """
        Build a wrapper tensor around the per-rank shards in ``local_tensors``.

        Args:
            local_tensors: Mapping from global rank to that rank's shard. None
                of the shards may require grad, and all of them must share
                dtype, layout and extra dispatch keys. Shapes and strides may
                differ between ranks.

        Returns:
            A LocalTensor whose size/stride entries are plain ints where every
            rank agrees and ``torch.SymInt(LocalIntNode(...))`` where they differ.

        Raises:
            AssertionError: If any shard requires grad or the shards disagree
                on dtype, layout or extra dispatch keys.
        """
        if any(t.requires_grad for t in local_tensors.values()):
            raise AssertionError(
                "Internal local_tensors require grad, but we will ignore those autograd graph. "
                "Make a custom autograd function and make sure you detach the inner tensors."
            )

        it = iter(local_tensors.values())
        first_local_tensor = next(it)

        first_shape = first_local_tensor.shape
        first_stride = first_local_tensor.stride()
        dtype = first_local_tensor.dtype
        device = first_local_tensor.device
        layout = first_local_tensor.layout

        extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)

        # Assert that all tensors have the same dtype, layout and dispatch keys. Due
        # to uneven sharding, it is possible that tensors will have different shapes.
        for local_tensor in it:
            assert dtype == local_tensor.dtype, (
                "Tensors representing LocalTensor shards must have the same dtype"
            )
            assert layout == local_tensor.layout, (
                "Tensors representing LocalTensor shards must have the same layout"
            )
            assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
                "Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
            )

        # Compute shape/stride.  We allow for non-SPMD'ness here
        local_shapes: dict[int, dict[int, int]] = defaultdict(
            dict
        )  # dim => rank => size
        local_strides: dict[int, dict[int, int]] = defaultdict(
            dict
        )  # dim => rank => stride
        for rank, local_tensor in local_tensors.items():
            for d, size in enumerate(local_tensor.shape):
                local_shapes[d][rank] = size
                local_strides[d][rank] = local_tensor.stride(d)

        # NOTE: the uniqueness check must look at the per-rank *values*.
        # Iterating the dict directly yields the rank keys, which would make
        # every dimension symbolic as soon as there is more than one rank.
        shape = [
            (
                first_shape[d]
                if len(set(local_shapes[d].values())) == 1
                else torch.SymInt(LocalIntNode(local_shapes[d]))
            )
            for d in range(len(first_shape))
        ]
        strides = [
            (
                first_stride[d]
                if len(set(local_strides[d].values())) == 1
                else torch.SymInt(LocalIntNode(local_strides[d]))
            )
            for d in range(len(first_shape))
        ]

        r = torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            strides=strides,
            dtype=dtype,
            device=device,
            layout=layout,
            requires_grad=False,
            _extra_dispatch_keys=extra_dispatch_keys,
        )

        # Store plain tensors only: pending async collective results are waited on here.
        local_tensors = {
            rank: v if not isinstance(v, AsyncCollectiveTensor) else v.wait()
            for rank, v in local_tensors.items()
        }
        r._local_tensors = local_tensors
        r._ranks = frozenset(local_tensors.keys())
        return r

    @torch._disable_dynamo
    @mark_subclass_constructor_exportable_experimental  # type: ignore[misc]
    def __init__(self, *args: Any, **kwargs: Any):
        # All state is set up in __new__; the constructor arguments are ignored here.
        super().__init__()

    def __repr__(self) -> str:  # type: ignore[override]
        """Render one ``rank: tensor`` line per local shard."""
        parts = []
        for k, v in self._local_tensors.items():
            parts.append(f"  {k}: {v}")
        tensors_str = ",\n".join(parts)
        return f"LocalTensor(\n{tensors_str}\n)"

    def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
        """
        protocol to inform how to flatten a LocalTensor to its local tensors
        for PT2 tracing
        """
        return ["_local_tensors"], ()

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: dict[str, Any],
        flatten_spec: tuple[Any, ...],
        outer_size: torch.Size,
        outer_stride: tuple[int, ...],
    ) -> "LocalTensor":
        """
        Rebuild a LocalTensor from the ``_local_tensors`` entry of ``inner_tensors``.

        ``outer_size`` and ``outer_stride`` are not used; size and stride are
        recomputed from the shards by the constructor.
        """
        assert flatten_spec is not None, (
            "Expecting spec to be not None from `__tensor_flatten__` return value!"
        )
        local_tensors = inner_tensors["_local_tensors"]
        return LocalTensor(local_tensors)

    @classmethod
    @torch._disable_dynamo
    def __torch_dispatch__(  # type: ignore[override]
        cls,
        func: Any,
        types: tuple[Any, ...],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> Any:
        """
        Re-run ``func`` under a LocalTensorMode configured with the ranks of the
        first LocalTensor argument.

        Returns ``NotImplemented`` when ``_check_for_subclass`` reports
        unrecognized tensor subclasses among the arguments.
        """
        if kwargs is None:
            kwargs = {}

        # This is horribly inefficient
        flat_args, args_spec = pytree.tree_flatten((args, kwargs))
        local_tensor = None
        for arg in flat_args:
            if isinstance(arg, LocalTensor):
                local_tensor = arg
                break

        assert local_tensor is not None, (
            "At least one of the arguments must be a LocalTensor"
        )

        # Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
        has_unrecognized_types = _check_for_subclass(flat_args)
        if has_unrecognized_types:
            unrecognized_types = [
                type(x) for x in flat_args if _check_for_subclass_arg(x)
            ]
            not_implemented_log.debug(
                "LocalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        with LocalTensorMode(local_tensor._ranks):
            return func(*args, **kwargs)

    def tolist(self) -> list[Any]:
        """
        Reconcile and convert result to list.
        """

        return self.reconcile().tolist()

    def reconcile(self) -> torch.Tensor:
        """
        Reconciles the LocalTensor into a single torch.Tensor by ensuring all local
        shards are identical and returning a detached clone of one of them.

        Note:
            This method is useful for extracting a representative tensor from a LocalTensor
            when all shards are expected to be the same, such as after a collective operation
            that synchronizes all ranks.
        """

        # Force all local tensor shards across ranks to be the same
        it = iter(self._local_tensors.values())
        t1 = next(it)
        for t2 in it:
            assert torch.equal(t1, t2), (
                "LocalTensor shards must be the same to reconcile"
            )
        cl = t1.clone().detach()
        cl.requires_grad_(self.requires_grad)
        return cl
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Stack of entered LocalTensorMode instances: LocalTensorMode.__enter__ appends,
# LocalTensorMode.__exit__ pops, and local_tensor_mode() returns the last entry.
_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LocalTensorMode(TorchDispatchMode):
    """
    A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution
    for LocalTensor objects across a set of ranks.

    LocalTensorMode enables PyTorch operations to be transparently applied to each
    local shard of a LocalTensor, as if they were distributed across multiple ranks.
    When active, this mode intercepts tensor operations and dispatches them to each
    rank's local tensor, collecting and wrapping the results as LocalTensors. It also
    handles collective operations by mapping them to local implementations.

    This mode is primarily intended for debugging and simulating distributed tensor
    computations on a single process, rather than for high-performance distributed
    training. It maintains a stack of active modes, patches DeviceMesh coordinate
    resolution, and provides utilities for temporarily disabling the mode or mapping
    functions over ranks.
    """

    # What ranks this local tensor mode is operating over
    def __init__(self, ranks: Union[int, frozenset[int]]):
        """
        Args:
            ranks: Either a world size (interpreted as ranks ``0..ranks-1``) or
                an explicit frozenset of rank ids.
        """
        if isinstance(ranks, int):
            # assume is world size
            self.ranks = frozenset(range(ranks))
        else:
            assert isinstance(ranks, frozenset)
            self.ranks = ranks
        self._disable = False
        # Original DeviceMesh.get_coordinate while this mode has it patched, else None.
        self._old_get_coordinate = None

    def __enter__(self) -> "LocalTensorMode":
        self._disable = False
        self._patch_device_mesh()
        _LOCAL_TENSOR_MODE.append(self)

        return super().__enter__()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._disable = True
        self._unpatch_device_mesh()
        _LOCAL_TENSOR_MODE.pop()
        super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_dispatch__(
        self,
        func: Any,
        types: tuple[Any, ...],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> Any:
        """
        Route ``func`` to the matching local implementation.

        ``c10d`` ops are mapped to the ``_c10d._local_*`` helpers, ops in the
        ``_c10d_functional`` / ``_dtensor`` namespaces are re-dispatched through
        their CompositeExplicitAutograd kernel, and everything else is run once
        per rank via ``_for_each_rank_run_func``.

        Raises:
            NotImplementedError: For unsupported ``c10d`` ops and for any op in
                the ``_c10d_functional_autograd`` or ``symm_mem`` namespaces.
        """
        if kwargs is None:
            kwargs = {}

        flat_args, args_spec = pytree.tree_flatten((args, kwargs))

        # Find all LocalTensor arguments to determine ranks
        local_tensors = [a for a in flat_args if isinstance(a, LocalTensor)]

        # Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
        has_unrecognized_types = _check_for_subclass(flat_args)
        if has_unrecognized_types:
            unrecognized_types = [
                type(x) for x in flat_args if _check_for_subclass_arg(x)
            ]
            not_implemented_log.debug(
                "LocalTensorMode unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        # Factory functions convert into LocalTensor, so we don't have to
        # transmute a Tensor into a LocalTensor if mutation happens...
        # But if you do an operation on a Tensor, do NOT wrap it into a
        # LocalTensor.  This helps prevent accidents when you're doing Tensor
        # operations on the inner non-wrapped tensors.
        if not local_tensors:
            if self._disable or any(isinstance(a, Tensor) for a in flat_args):
                return func(*args, **kwargs)

        # For LocalTensors, verify they have compatible ranks
        for a in flat_args:
            if isinstance(a, LocalTensor):
                assert a._ranks == self.ranks, (
                    f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
                )

        if func.namespace == "c10d":
            if func is torch.ops.c10d.allreduce_.default:
                return _c10d._local_all_reduce_(*args, **kwargs)
            elif func is torch.ops.c10d.allreduce_coalesced_.default:
                return _c10d._local_allreduce_coalesced_(*args, **kwargs)
            elif func is torch.ops.c10d.reduce_scatter_tensor_coalesced_.default:
                return _c10d._local_reduce_scatter_tensor_coalesced_(*args, **kwargs)
            elif func is torch.ops.c10d.scatter_.default:
                return _c10d._local_scatter_(*args, **kwargs)
            elif func is torch.ops.c10d.broadcast_.default:
                return _c10d._local_broadcast_(*args, **kwargs)
            elif func is torch.ops.c10d.allgather_.default:
                return _c10d._local_all_gather_(*args, **kwargs)
            elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default:
                return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs)
            elif func is torch.ops.c10d.gather_.default:
                return _c10d._local_gather_(*args, **kwargs)
            elif func is torch.ops.c10d.alltoall_.default:
                return _c10d._local_alltoall_(*args, **kwargs)
            elif func is torch.ops.c10d.alltoall_base_.default:
                return _c10d._local_alltoall_base_(*args, **kwargs)
            elif func is torch.ops.c10d.barrier.default:
                return _c10d._local_barrier(*args, **kwargs)
            elif func is torch.ops.c10d.monitored_barrier_.default:
                return _c10d._local_monitored_barrier_(*args, **kwargs)
            elif func is torch.ops.c10d.send.default:
                return _c10d._local_send(*args, **kwargs)
            elif func is torch.ops.c10d.recv_.default:
                return _c10d._local_recv_(*args, **kwargs)
            elif func is torch.ops.c10d.recv_any_source_.default:
                return _c10d._local_recv_any_source_(*args, **kwargs)
            raise NotImplementedError(f"{func} not implemented")

        if func.namespace == "_c10d_functional" or func.namespace == "_dtensor":
            with LocalTensorMode(self.ranks):
                return func._op_dk(
                    DispatchKey.CompositeExplicitAutograd, *args, **kwargs
                )

        if func.namespace == "_c10d_functional_autograd":
            raise NotImplementedError(f"{func} not implemented")

        if func.namespace == "symm_mem":
            raise NotImplementedError(f"{func} not implemented")

        return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True)

    @contextlib.contextmanager
    def disable(self) -> Generator[None, None, None]:
        """
        Disables LocalTensorMode temporarily. Primarily is intended to be used to perform
        rank specific computations and merge results back before enabling LocalTensorMode back.

        The context manager is re-entrant: a nested ``disable()`` leaves the
        DeviceMesh patch state untouched and only the outermost one restores it.
        """

        old = self._disable
        # Only undo (and later redo) the DeviceMesh patch if this mode currently
        # holds it. Without this guard a nested disable() would trip the
        # assertions in _unpatch_device_mesh / _patch_device_mesh.
        was_patched = self._old_get_coordinate is not None
        self._disable = True
        if was_patched:
            self._unpatch_device_mesh()
        try:
            yield
        finally:
            self._disable = old
            if was_patched:
                self._patch_device_mesh()

    def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor:
        """
        Creates a LocalTensor instance by mapping rank id to its local shard.
        """

        with self.disable():
            return LocalTensor({r: cb(r) for r in self.ranks})

    def _patch_device_mesh(self) -> None:
        # Remember the current implementation so _unpatch_device_mesh can restore it.
        assert self._old_get_coordinate is None
        self._old_get_coordinate = DeviceMesh.get_coordinate  # type: ignore[assignment]
        DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate  # type: ignore[method-assign]

    def _unpatch_device_mesh(self) -> None:
        # Restore whatever DeviceMesh.get_coordinate was when this mode patched it.
        assert self._old_get_coordinate is not None
        DeviceMesh.get_coordinate = self._old_get_coordinate
        self._old_get_coordinate = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _LocalDeviceMesh:
    """
    Namespace for the DeviceMesh method implementations that LocalTensorMode
    swaps in while it is active.
    """

    @staticmethod
    def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
        """
        Return, per mesh dimension, the coordinate of every rank of the active
        LocalTensorMode as a ``torch.SymInt`` backed by a ``LocalIntNode``.
        """
        mode = local_tensor_mode()
        assert mode is not None, "Unexpectedly not in LocalTensorMode"

        # Compare the mesh against a LocalTensor holding each rank's own id,
        # then locate the matching position for every rank.
        rank_coords = (self.mesh == mode.rank_map(lambda r: torch.tensor(r))).nonzero()
        # NB: unlike the regular mechanism, we don't allow for MPMD
        assert rank_coords.size(0) == 1
        first_row = rank_coords[0]
        assert isinstance(first_row, LocalTensor)

        # One {rank: coordinate} mapping per mesh dimension.
        num_dims = rank_coords.size(1)
        per_dim: list[dict[int, int]] = [{} for _ in range(num_dims)]
        for rank, shard in first_row._local_tensors.items():
            for dim, coordinate in enumerate(shard.tolist()):
                per_dim[dim][rank] = coordinate

        return [torch.SymInt(LocalIntNode(mapping)) for mapping in per_dim]  # type: ignore[return-value]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any:
    """
    Replace every LocalTensor found in ``args`` / ``kwargs`` with its
    reconciled plain ``torch.Tensor``.

    The ``(args, kwargs)`` pair is flattened with pytree, each LocalTensor leaf
    is swapped for the result of its ``reconcile()`` method, and the original
    nesting is rebuilt around the new leaves.

    Args:
        args: Positional arguments, possibly containing LocalTensor instances.
        kwargs: Keyword arguments, possibly containing LocalTensor instances.
            ``None`` is treated as an empty dict.

    Returns:
        Any: The rebuilt ``(args, kwargs)`` structure with all LocalTensor
             leaves replaced by their reconciled tensors.
    """

    def _to_plain(leaf: Any) -> Any:
        # Non-LocalTensor leaves pass through untouched.
        if isinstance(leaf, LocalTensor):
            return leaf.reconcile()
        return leaf

    keyword_args = {} if kwargs is None else kwargs
    leaves, spec = pytree.tree_flatten((args, keyword_args))
    return pytree.tree_unflatten([_to_plain(leaf) for leaf in leaves], spec)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def local_tensor_mode() -> Optional[LocalTensorMode]:
    """
    Return the innermost active LocalTensorMode, if any.

    The global ``_LOCAL_TENSOR_MODE`` stack is consulted: its last entry (the
    most recently entered mode) is returned, or None when the stack is empty.

    Returns:
        Optional[LocalTensorMode]: The current LocalTensorMode if active, else None.
    """
    if not _LOCAL_TENSOR_MODE:
        return None
    return _LOCAL_TENSOR_MODE[-1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator that runs ``func`` once per rank when a LocalTensorMode is active.

    Outside of LocalTensorMode the wrapped function is simply called. Inside
    it, the mode is temporarily disabled and ``func`` is executed for each of
    the mode's ranks through ``_for_each_rank_run_func``, whose result is
    returned.

    This is meant for functions with non-SPMD behavior, i.e. ones that take
    rank specific actions such as computing an offset into the input tensor
    from the rank.

    The decorated function must be free of side effects and must only contain
    operations for a single rank; wrapping a function that performs a
    collective operation will not work.

    Args:
        func (Callable[..., Any]): The function to be decorated.

    Returns:
        Callable[..., Any]: The wrapped function that handles LocalTensorMode logic.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        active_mode = local_tensor_mode()
        if active_mode is not None:
            with active_mode.disable():
                per_rank_result = _for_each_rank_run_func(
                    func, active_mode.ranks, args, kwargs, alias=False
                )
            # The mode must still be active once the per-rank run is over.
            assert local_tensor_mode() is not None
            return per_rank_result
        return func(*args, **kwargs)

    return wrapper
 | 
			
		||||
							
								
								
									
										669
									
								
								torch/distributed/_local_tensor/_c10d.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										669
									
								
								torch/distributed/_local_tensor/_c10d.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,669 @@
 | 
			
		||||
import functools
 | 
			
		||||
import math
 | 
			
		||||
import operator
 | 
			
		||||
from typing import Sequence
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch._C import ScriptObject
 | 
			
		||||
from torch._C._distributed_c10d import FakeWork
 | 
			
		||||
from torch.distributed._mesh_layout import _MeshLayout
 | 
			
		||||
from torch.distributed.distributed_c10d import (
 | 
			
		||||
    _get_default_group,
 | 
			
		||||
    ProcessGroup,
 | 
			
		||||
    ReduceOp,
 | 
			
		||||
    Work,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# NOTE: Most of the c10d collectives often take a Tensor[] (or Tensor[][])
 | 
			
		||||
# when you would expect Tensor (or Tensor[]).  In fact, there will only ever
 | 
			
		||||
# be one Tensor in this case; the old signature was to support dispatching a
 | 
			
		||||
# collective on multiple devices (ala DataParallel) but we don't support that
 | 
			
		||||
# API anymore.  Note that we are not 100% consistent about this; some more
 | 
			
		||||
# modern collectives like _allgather_base_ got rid of the unnecessary list.
 | 
			
		||||
# When in doubt, consult the code that dispatches to the collective on the PG
 | 
			
		||||
# in distributed_c10d.py e.g., work = group.allgather([tensor_list], [tensor],
 | 
			
		||||
# opts) indicates it's always a list.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _gcd_list(numbers: Sequence[int]) -> int:
 | 
			
		||||
    return 0 if not numbers else functools.reduce(math.gcd, numbers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _indices_to_layout(indices: list[int]) -> tuple[tuple[int, ...], tuple[int, ...]]:
 | 
			
		||||
    # Base case: A single index represents a point, not a dimension.
 | 
			
		||||
    if len(indices) <= 1:
 | 
			
		||||
        return (), ()
 | 
			
		||||
 | 
			
		||||
    # The smallest stride is likely the GCD of the differences between consecutive indices.
 | 
			
		||||
    # For a sorted, unique list, all differences will be positive.
 | 
			
		||||
    diffs = [indices[i] - indices[i - 1] for i in range(1, len(indices))]
 | 
			
		||||
    last_stride = _gcd_list(diffs)
 | 
			
		||||
 | 
			
		||||
    assert last_stride != 0, (
 | 
			
		||||
        # This case should not be reached if indices are unique and sorted.
 | 
			
		||||
        "Cannot determine stride; indices may not be unique."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Identify the starting index of each "row" in the last dimension.
 | 
			
		||||
    # An index starts a new row if the preceding index (index - stride) is not present.
 | 
			
		||||
    indices_set = set(indices)
 | 
			
		||||
    higher_dim_indices = [indices[0]]
 | 
			
		||||
    for index in indices[1:]:
 | 
			
		||||
        if (index - last_stride) not in indices_set:
 | 
			
		||||
            higher_dim_indices.append(index)
 | 
			
		||||
 | 
			
		||||
    # From the number of rows, we can deduce the shape of the last dimension.
 | 
			
		||||
    assert len(indices) % len(higher_dim_indices) == 0, (
 | 
			
		||||
        "Indices do not form a regular grid. "
 | 
			
		||||
        f"Found {len(higher_dim_indices)} subgroups for {len(indices)} total elements."
 | 
			
		||||
    )
 | 
			
		||||
    last_shape = len(indices) // len(higher_dim_indices)
 | 
			
		||||
 | 
			
		||||
    # Recurse on the higher-dimensional indices (the start of each row).
 | 
			
		||||
    higher_shapes, higher_strides = _indices_to_layout(higher_dim_indices)
 | 
			
		||||
 | 
			
		||||
    # Combine the results from the recursion with the current dimension's results.
 | 
			
		||||
    final_shapes = higher_shapes + (last_shape,)
 | 
			
		||||
    final_strides = higher_strides + (last_stride,)
 | 
			
		||||
 | 
			
		||||
    return final_shapes, final_strides
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _prepare_collective_groups(
    process_group_so: ScriptObject,
) -> tuple[list[int], list[int], int]:
    """
    Derive the rank bookkeeping needed to emulate a collective on a process group.

    Args:
        process_group_so: Boxed process group the collective was issued on.

    Returns:
        A triple ``(ranks, group_offsets, offset)`` where ``offset`` is the
        first rank of the group, ``ranks`` are the group's ranks relative to
        ``offset``, and ``group_offsets`` comes from the complement of the
        inferred rank layout with respect to the default group's size.
    """
    pg = ProcessGroup.unbox(process_group_so)

    global_ranks = torch.distributed.get_process_group_ranks(pg)
    assert global_ranks
    # TODO: We can handle permutations but the layout inference algorithm will
    # lose the permutation so we will have to reapply it
    assert global_ranks == sorted(global_ranks), global_ranks

    # Shift the ranks so that the group starts at zero.
    offset = global_ranks[0]
    relative_ranks = [rank - offset for rank in global_ranks]

    layout = _MeshLayout(*_indices_to_layout(relative_ranks))

    world_size = _get_default_group().size()
    group_offsets = layout.complement(world_size).all_ranks_from_zero()

    return relative_ranks, group_offsets, offset
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_broadcast_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    root_rank: int,
    root_tensor: int,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    """Simulate ``broadcast_`` on a single LocalTensor, in place.

    For every rank group, the root's per-rank tensor is copied into the
    per-rank tensor of each other member of that group.

    Args:
        tensors: Exactly one tensor, which must be a ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        root_rank: Rank of the broadcast source; it is made relative by
            subtracting the group's first rank.
        root_tensor: Must be 0.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        ``(tensors, work)`` where ``work`` is a boxed ``FakeWork``.
    """
    # "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"
    from . import LocalTensor

    assert len(tensors) == 1
    assert root_tensor == 0
    payload = tensors[0]

    ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

    # SPMD assumption: the root sits at the same relative position in every
    # rank group.
    root_in_group = root_rank - offset

    assert isinstance(payload, LocalTensor), "Input tensor must be a LocalTensor"

    for base in group_offsets:
        members = [base + r for r in ranks]
        root = base + root_in_group
        root_value = payload._local_tensors[root]

        for member in members:
            if member == root:
                # The root already holds the value being broadcast.
                continue
            payload._local_tensors[member].copy_(root_value)

    return (tensors, Work.boxed(FakeWork()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_reduce(
    reduce_op: ReduceOp,
    tensors: list[torch.Tensor],
) -> torch.Tensor:
    """Fold ``tensors`` into one tensor according to ``reduce_op``.

    Args:
        reduce_op: The reduction to apply.
        tensors: Operands, combined left to right with ``functools.reduce``.

    Returns:
        The reduced tensor. For ``AVG`` this is the sum of the operands
        divided by their count.

    Raises:
        NotImplementedError: For ``PREMUL_SUM`` and for any op that is not
            handled here.
    """
    if reduce_op == ReduceOp.AVG:
        # AVG has no pairwise operator: sum everything, then divide once.
        return functools.reduce(operator.add, tensors) / len(tensors)

    if reduce_op == ReduceOp.PREMUL_SUM:
        raise NotImplementedError("PREMUL_SUM: need to add binding for scaling factor")

    # Pairs are scanned with ``==`` so the matching rule is the same equality
    # check the op type itself defines.
    pairwise_ops = (
        (ReduceOp.SUM, operator.add),
        (ReduceOp.PRODUCT, operator.mul),
        (ReduceOp.MIN, torch.minimum),
        (ReduceOp.MAX, torch.maximum),
        (ReduceOp.BAND, torch.bitwise_and),
        (ReduceOp.BOR, torch.bitwise_or),
        (ReduceOp.BXOR, torch.bitwise_xor),
    )
    for candidate, combine in pairwise_ops:
        if reduce_op == candidate:
            return functools.reduce(combine, tensors)

    raise NotImplementedError(f"ReduceOp {reduce_op} not implemented")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_all_reduce_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    reduce_op_so: ScriptObject,
    sparse_indices: torch.Tensor | None = None,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    """Simulate ``allreduce_`` on a single LocalTensor, in place.

    For every rank group, the per-rank tensors of the group's members are
    reduced with ``_local_reduce`` and the result is copied back into each
    member's per-rank tensor.

    Args:
        tensors: Exactly one tensor, which must be a ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        reduce_op_so: Boxed reduce op; its ``op()`` yields the reduction.
        sparse_indices: Not used by this implementation.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        ``(tensors, work)`` where ``work`` is a boxed ``FakeWork``.
    """
    # "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "__torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, "
    # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
    from . import LocalTensor

    assert len(tensors) == 1
    target = tensors[0]
    reduce_op = reduce_op_so.op()  # type: ignore[attr-defined]

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert isinstance(target, LocalTensor), "Input tensor must be a LocalTensor"

    for base in group_offsets:
        members = [base + r for r in ranks]

        # Reduce across the members of this group ...
        operands = [target._local_tensors[member] for member in members]
        combined = _local_reduce(reduce_op, operands)

        # ... and hand every member the same result.
        for member in members:
            target._local_tensors[member].copy_(combined)

    return (tensors, Work.boxed(FakeWork()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_allreduce_coalesced_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    reduce_op_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    """Simulate ``allreduce_coalesced_`` on a list of LocalTensors, in place.

    Each tensor is all-reduced independently: for every rank group, the
    per-rank tensors of the group's members are reduced with
    ``_local_reduce`` and the result is copied back into each member.

    Args:
        tensors: Tensors to reduce; each must be a ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        reduce_op_so: Boxed reduce op; its ``op()`` yields the reduction.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        A boxed ``FakeWork``.
    """
    # "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
    from . import LocalTensor

    reduce_op = reduce_op_so.op()  # type: ignore[attr-defined]
    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    for base in group_offsets:
        members = [base + r for r in ranks]

        for item in tensors:
            assert isinstance(item, LocalTensor), "Input tensor must be a LocalTensor"

            operands = [item._local_tensors[member] for member in members]
            combined = _local_reduce(reduce_op, operands)

            # Every member of the group ends up with the same reduced value.
            for member in members:
                item._local_tensors[member].copy_(combined)

    return Work.boxed(FakeWork())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_reduce_scatter_tensor_coalesced_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    reduce_op_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    """Simulate ``reduce_scatter_tensor_coalesced_`` on LocalTensors.

    For every rank group and every ``(input, output)`` pair, the inputs of
    the group's members are reduced with ``_local_reduce``, the result is
    split along dim 0 into chunks of ``size(0) // group_size`` rows, and the
    i-th member of the group receives the i-th chunk in its output.

    Args:
        output_tensors: Destination tensors; each must be a ``LocalTensor``.
        input_tensors: Source tensors; each must be a ``LocalTensor``. Paired
            with ``output_tensors`` by position.
        process_group_so: Boxed process group describing one rank group.
        reduce_op_so: Boxed reduce op; its ``op()`` yields the reduction.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        A boxed ``FakeWork``.
    """
    # "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, "
    # "int timeout=-1) -> __torch__.torch.classes.c10d.Work"

    from . import LocalTensor

    reduce_op = reduce_op_so.op()  # type: ignore[attr-defined]
    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the reduce_scatter on all tensor pairs together
        group_ranks = [group_offset + r for r in ranks]

        for input_tensor, output_tensor in zip(input_tensors, output_tensors):
            assert isinstance(input_tensor, LocalTensor), (
                "Input tensor must be a LocalTensor"
            )
            assert isinstance(output_tensor, LocalTensor), (
                "Output tensor must be a LocalTensor"
            )
            # Collect the inputs of every member of this group
            group_inputs = [input_tensor._local_tensors[rank] for rank in group_ranks]

            # Perform the reduction operation
            reduced_input = _local_reduce(reduce_op, group_inputs)

            reduced_input_splits = torch.split(
                reduced_input, reduced_input.size(0) // len(group_ranks), dim=0
            )

            # Scatter: the chunk is selected by the member's position within
            # the group, not by its global rank. Global ranks are offset by
            # group_offset and may be strided, so using them as an index
            # picks the wrong chunk or runs past the end of the splits.
            for position, rank in enumerate(group_ranks):
                output_tensor._local_tensors[rank].copy_(
                    reduced_input_splits[position]
                )

    work = FakeWork()
    work_so = Work.boxed(work)
    return work_so
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_all_gather_(
    output_tensors: list[list[torch.Tensor]],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[list[torch.Tensor]], ScriptObject]:
    """Simulate ``allgather_`` on LocalTensors.

    For every rank group, the i-th output tensor of each member of the group
    is filled with the input held by the i-th member of that same group.

    Args:
        output_tensors: A single-element list whose element is the list of
            output tensors, one per group member; each must be a
            ``LocalTensor``.
        input_tensors: A single-element list holding the input
            ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        ``([outputs], work)`` where ``outputs`` is the inner output list and
        ``work`` is a boxed ``FakeWork``.
    """
    # "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
    # "int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");

    from . import LocalTensor

    assert len(output_tensors) == 1
    assert len(input_tensors) == 1

    input_tensor = input_tensors[0]
    output_tensors = output_tensors[0]

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
    for candidate in output_tensors:
        assert isinstance(candidate, LocalTensor), (
            "Output tensor must be a LocalTensor"
        )

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the all_gather on them
        group_ranks = [group_offset + r for r in ranks]

        # Slot i of every member's output receives the input of the group's
        # i-th member.
        for i, rank_i in enumerate(group_ranks):
            source_tensor = input_tensor._local_tensors[rank_i]
            output_tensor = output_tensors[i]
            assert isinstance(output_tensor, LocalTensor)
            # Write only into the shards of ranks that belong to this group.
            # Copying into the whole LocalTensor would let a later group
            # overwrite what earlier groups gathered.
            for rank in group_ranks:
                output_tensor._local_tensors[rank].copy_(source_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return ([output_tensors], work_so)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_allgather_into_tensor_coalesced_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    async_op: bool = True,
) -> ScriptObject:
    """Simulate ``allgather_into_tensor_coalesced_`` on LocalTensors.

    For every rank group and every ``(input, output)`` pair, the inputs of
    the group's members are concatenated along dim 0 and the concatenation
    is copied into the output of each member.

    Args:
        output_tensors: Destination tensors; each must be a ``LocalTensor``.
        input_tensors: Source tensors; each must be a ``LocalTensor``. Must
            have the same length as ``output_tensors``.
        process_group_so: Boxed process group describing one rank group.
        async_op: Not used by this implementation.

    Returns:
        A boxed ``FakeWork``.
    """
    # "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) "
    # "-> __torch__.torch.classes.c10d.Work"
    from . import LocalTensor

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    # outputs[i] receives inputs[i] gathered from every rank of the group, so
    # the two lists must pair up one to one.
    assert len(output_tensors) == len(input_tensors), (
        f"Number of outputs ({len(output_tensors)}) must match number of inputs ({len(input_tensors)})"
    )

    for base in group_offsets:
        members = [base + r for r in ranks]

        for src, dst in zip(input_tensors, output_tensors):
            assert isinstance(src, LocalTensor), (
                "Input tensor must be a LocalTensor"
            )
            assert isinstance(dst, LocalTensor), (
                "Output tensor must be a LocalTensor"
            )

            pieces = [src._local_tensors[member] for member in members]
            if not pieces:
                continue

            # Stack the pieces along the first dimension and give every
            # member of the group the same gathered result.
            gathered = torch.cat(pieces, dim=0)
            for member in members:
                dst._local_tensors[member].copy_(gathered)

    return Work.boxed(FakeWork())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_gather_(
    output_tensors: list[list[torch.Tensor]],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    root_rank: int,
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    """Reject ``gather_``: it is not supported for LocalTensor.

    All arguments are accepted only to match the op schema and are ignored.

    Raises:
        NotImplementedError: Always.
    """
    # "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
    # "bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
    reason = (
        "LocalTensor does not support MPMD operations like gather "
        "(only root rank receives data). Use SPMD collective operations like allgather instead."
    )
    raise NotImplementedError(reason)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_scatter_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[list[torch.Tensor]],
    process_group_so: ScriptObject,
    root_rank: int,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    """Simulate ``scatter_`` on LocalTensors.

    For every rank group, the i-th member's output receives the root's
    per-rank value of the i-th input tensor.

    Args:
        output_tensors: A single-element list holding the output
            ``LocalTensor``.
        input_tensors: A single-element list whose element is the list of
            input tensors, one per group member; each must be a
            ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        root_rank: Rank of the scatter source; it is made relative by
            subtracting the group's first rank.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        ``(output_tensors, work)`` where ``work`` is a boxed ``FakeWork``.
    """
    # "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
    # "bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");

    from . import LocalTensor

    assert len(output_tensors) == 1
    assert len(input_tensors) == 1
    destination = output_tensors[0]
    input_tensors = input_tensors[0]

    ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

    # SPMD assumption: the root sits at the same relative position in every
    # rank group.
    root_in_group = root_rank - offset

    assert isinstance(destination, LocalTensor), "Output tensor must be a LocalTensor"
    assert len(ranks) == len(input_tensors), (ranks, input_tensors)

    for base in group_offsets:
        members = [base + r for r in ranks]
        root = base + root_in_group

        for slot, member in enumerate(members):
            piece = input_tensors[slot]
            assert isinstance(piece, LocalTensor)
            # The member at position ``slot`` receives the root's copy of the
            # slot-th input tensor.
            destination._local_tensors[member].copy_(piece._local_tensors[root])

    return (output_tensors, Work.boxed(FakeWork()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_alltoall_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    """Simulate ``alltoall_`` on lists of LocalTensors.

    Within each rank group, the value that the i-th member holds for
    ``input_tensors[j]`` is copied into the j-th member's shard of
    ``output_tensors[i]``.

    Args:
        output_tensors: One output per group member; each must be a
            ``LocalTensor``.
        input_tensors: One input per group member; each must be a
            ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        ``(output_tensors, work)`` where ``work`` is a boxed ``FakeWork``.
    """
    # "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
    # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)";

    from . import LocalTensor

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert len(input_tensors) == len(output_tensors) == len(ranks), (
        f"Number of input tensors ({len(input_tensors)}), "
        f"output tensors ({len(output_tensors)}), and ranks ({len(ranks)}) must match"
    )

    for base in group_offsets:
        members = [base + r for r in ranks]

        for sender_pos, sender in enumerate(members):
            received = output_tensors[sender_pos]
            assert isinstance(received, LocalTensor), (
                "Output tensor must be a LocalTensor"
            )
            for receiver_pos, receiver in enumerate(members):
                sent = input_tensors[receiver_pos]
                assert isinstance(sent, LocalTensor), (
                    "Input tensor must be a LocalTensor"
                )
                # Sender's receiver_pos-th input lands in the receiver's
                # sender_pos-th output.
                received._local_tensors[receiver].copy_(sent._local_tensors[sender])

    return (output_tensors, Work.boxed(FakeWork()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_alltoall_base_(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    process_group_so: ScriptObject,
    output_split_sizes: list[int],
    input_split_sizes: list[int],
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    """Simulate ``alltoall_base_`` on a pair of LocalTensors.

    Within each rank group, every member's input is split along dim 0; the
    j-th split is copied into the j-th member's output, in the row range
    reserved for the sending member.

    Args:
        output_tensor: Destination ``LocalTensor``.
        input_tensor: Source ``LocalTensor``.
        process_group_so: Boxed process group describing one rank group.
        output_split_sizes: Row counts per sender in the output. When empty
            or ``None`` the output rows are divided evenly by group size.
        input_split_sizes: Row counts per receiver in the input. When empty
            or ``None`` the input rows are divided evenly by group size.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        A boxed ``FakeWork``.
    """
    # "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";

    from . import LocalTensor

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
    assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"

    # Normalize the split sizes to plain lists (None is left untouched).
    if output_split_sizes is not None:
        output_split_sizes = list(output_split_sizes)
    if input_split_sizes is not None:
        input_split_sizes = list(input_split_sizes)

    for base in group_offsets:
        members = [base + r for r in ranks]
        group_size = len(members)

        for sender_pos, sender in enumerate(members):
            sender_rows = input_tensor._local_tensors[sender]

            # Explicit sizes win; otherwise divide the rows evenly.
            if input_split_sizes:
                chunks = torch.split(sender_rows, input_split_sizes, dim=0)
            else:
                chunks = torch.split(
                    sender_rows, sender_rows.size(0) // group_size, dim=0
                )

            # zip stops at the shorter sequence, so members without a
            # matching chunk are skipped.
            for receiver, chunk in zip(members, chunks):
                if output_split_sizes:
                    start = sum(output_split_sizes[:sender_pos])
                    if sender_pos < len(output_split_sizes):
                        stop = start + output_split_sizes[sender_pos]
                    else:
                        stop = output_tensor._local_tensors[receiver].size(0)
                else:
                    rows_each = (
                        output_tensor._local_tensors[receiver].size(0) // group_size
                    )
                    start = sender_pos * rows_each
                    stop = min(
                        (sender_pos + 1) * rows_each,
                        output_tensor._local_tensors[receiver].size(0),
                    )

                window = output_tensor._local_tensors[receiver][start:stop]
                if window.numel() == 0:
                    continue
                # Make the chunk's shape match the destination window
                # before copying when the two differ.
                if chunk.size() != window.size():
                    chunk = chunk.view(window.size())
                window.copy_(chunk)

    return Work.boxed(FakeWork())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_barrier(
    tensor: torch.Tensor,
    process_group_so: ScriptObject,
    device_ids: list[int],
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    """Simulate ``barrier``: validate the tensor and return a work handle.

    Args:
        tensor: Must be a ``LocalTensor``.
        process_group_so: Not used by this implementation.
        device_ids: Not used by this implementation.
        async_op: Not used by this implementation.
        timeout: Not used by this implementation.

    Returns:
        A boxed ``FakeWork``.
    """
    # "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";

    from . import LocalTensor

    # All simulated "ranks" live in this one process, so there is nothing to
    # synchronize; the only check is that the tensor is a LocalTensor.
    assert isinstance(tensor, LocalTensor)

    return Work.boxed(FakeWork())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_monitored_barrier_(
    tensor: torch.Tensor,
    process_group_so: ScriptObject,
    device_ids: list[int],
    timeout: int,
    wait_all_ranks: bool,
) -> None:
    """Simulate ``monitored_barrier_``: validate the tensor and do nothing else.

    Args:
        tensor: Must be a ``LocalTensor``.
        process_group_so: Not used by this implementation.
        device_ids: Not used by this implementation.
        timeout: Not used by this implementation.
        wait_all_ranks: Not used by this implementation.
    """
    # "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int[] device_ids, int timeout, bool wait_all_ranks) -> ()";

    from . import LocalTensor

    # All simulated "ranks" live in this one process, so neither
    # synchronization nor monitoring is needed; the only check is that the
    # tensor is a LocalTensor.
    assert isinstance(tensor, LocalTensor)
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_send(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    dst: int,
    tag: int,
) -> ScriptObject:
    """Reject ``send``: it is not supported for LocalTensor.

    All arguments are accepted only to match the op schema and are ignored.

    Raises:
        NotImplementedError: Always.
    """
    # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int dst, int tag) -> __torch__.torch.classes.c10d.Work";
    reason = (
        "LocalTensor does not support MPMD operations like send. "
        "Use SPMD collective operations instead."
    )
    raise NotImplementedError(reason)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_recv_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    src: int,
    tag: int,
) -> ScriptObject:
    """Reject ``recv_``: it is not supported for LocalTensor.

    All arguments are accepted only to match the op schema and are ignored.

    Raises:
        NotImplementedError: Always.
    """
    # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int src, int tag) -> __torch__.torch.classes.c10d.Work";
    reason = (
        "LocalTensor does not support MPMD operations like recv. "
        "Use SPMD collective operations instead."
    )
    raise NotImplementedError(reason)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _local_recv_any_source_(
    tensors: list[torch.Tensor], process_group_so: ScriptObject, tag: int
) -> ScriptObject:
    """Reject ``recv_any_source_``: it is not supported for LocalTensor.

    All arguments are accepted only to match the op schema and are ignored.

    Raises:
        NotImplementedError: Always.
    """
    # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int tag) -> __torch__.torch.classes.c10d.Work";
    reason = (
        "LocalTensor does not support MPMD operations like recv_any_source. "
        "Use SPMD collective operations instead."
    )
    raise NotImplementedError(reason)
 | 
			
		||||
@ -10,6 +10,7 @@ import torch.distributed._functional_collectives as funcol
 | 
			
		||||
import torch.distributed.tensor._dtensor_spec as dtensor_spec
 | 
			
		||||
from torch._C._distributed_c10d import _resolve_process_group
 | 
			
		||||
from torch._logging import warning_once
 | 
			
		||||
from torch.distributed._local_tensor import local_tensor_mode
 | 
			
		||||
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
 | 
			
		||||
from torch.distributed.distributed_c10d import (
 | 
			
		||||
    _get_group_size_by_name,
 | 
			
		||||
@ -40,7 +41,7 @@ def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
 | 
			
		||||
    if mesh.device_type == "cpu":
 | 
			
		||||
    if mesh.device_type == "cpu" and local_tensor_mode() is None:
 | 
			
		||||
        # Gloo does not support alltoall, so falling back to allgather + chunk
 | 
			
		||||
        warning_once(
 | 
			
		||||
            logger,
 | 
			
		||||
 | 
			
		||||
@ -165,7 +165,7 @@ class OpDispatcher:
 | 
			
		||||
                raise
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                f"Sharding propagation failed for {op_info.schema}"
 | 
			
		||||
                f"{e}\n\nSharding propagation failed for {op_info.schema}"
 | 
			
		||||
            ) from e
 | 
			
		||||
 | 
			
		||||
        output_sharding = op_info.output_sharding
 | 
			
		||||
 | 
			
		||||
@ -319,6 +319,10 @@ LINEAR_REDUCTION_OP_MAP = {
 | 
			
		||||
    aten.all.dim: "sum",
 | 
			
		||||
    aten.sum.default: "sum",
 | 
			
		||||
    aten.sum.dim_IntList: "sum",
 | 
			
		||||
    aten.any.default: "sum",
 | 
			
		||||
    aten.any.dim: "sum",
 | 
			
		||||
    aten.any.out: "sum",
 | 
			
		||||
    # These are only valid when there is no padding
 | 
			
		||||
    aten.prod.default: "product",
 | 
			
		||||
    aten.prod.dim_int: "product",
 | 
			
		||||
    aten.prod.int_out: "product",
 | 
			
		||||
@ -332,9 +336,6 @@ LINEAR_REDUCTION_OP_MAP = {
 | 
			
		||||
    aten.min.default: "min",
 | 
			
		||||
    aten.min.dim: "min",
 | 
			
		||||
    aten.min.out: "min",
 | 
			
		||||
    aten.any.default: "sum",
 | 
			
		||||
    aten.any.dim: "sum",
 | 
			
		||||
    aten.any.out: "sum",
 | 
			
		||||
    aten.amax.default: "max",
 | 
			
		||||
    aten.amax.out: "max",
 | 
			
		||||
    aten.amin.default: "min",
 | 
			
		||||
 | 
			
		||||
@ -383,6 +383,7 @@ def redistribute_local_tensor(
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        f"redistribute from {current} to {target} not supported yet"
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            elif target.is_shard():
 | 
			
		||||
                # Case 2: target is Shard
 | 
			
		||||
                target_placement = cast(Shard, target)
 | 
			
		||||
 | 
			
		||||
@ -149,7 +149,7 @@ def _compute_local_shape_and_global_offset(
 | 
			
		||||
    ordered_placements = _explicit_order_placements(mesh_shape, placements)
 | 
			
		||||
 | 
			
		||||
    local_shape = list(global_shape)
 | 
			
		||||
    # We'll compute the data for where the shard beings on a per-dim basis.
 | 
			
		||||
    # We'll compute the data for where the shard begins on a per-dim basis.
 | 
			
		||||
    # However, a single dim can be sharded multiple times, so we will end up
 | 
			
		||||
    # doing a Sum(size*stride) like computation to determine the location of our
 | 
			
		||||
    # shard for each of the shardings on that dim.
 | 
			
		||||
@ -170,6 +170,14 @@ def _compute_local_shape_and_global_offset(
 | 
			
		||||
 | 
			
		||||
            local_shape[shard_dim] = shard_size
 | 
			
		||||
 | 
			
		||||
            shard_global_offset = global_offset[shard_dim] + not_none(shard_offset)
 | 
			
		||||
 | 
			
		||||
            zero_global_offset = global_shape[shard_dim]
 | 
			
		||||
            if isinstance(shard_global_offset, torch.SymInt) and not isinstance(
 | 
			
		||||
                zero_global_offset, torch.SymInt
 | 
			
		||||
            ):
 | 
			
		||||
                zero_global_offset = torch.SymInt(zero_global_offset)
 | 
			
		||||
 | 
			
		||||
            global_offset[shard_dim] = torch.sym_ite(
 | 
			
		||||
                shard_size == 0,
 | 
			
		||||
                # Special case to fill in a standardized non-garbage value for
 | 
			
		||||
@ -179,11 +187,11 @@ def _compute_local_shape_and_global_offset(
 | 
			
		||||
                # Note that you can end up with zero-size shards that are
 | 
			
		||||
                # still otherwise in bounds for the tensor (TODO: give an
 | 
			
		||||
                # example).
 | 
			
		||||
                global_shape[shard_dim],
 | 
			
		||||
                zero_global_offset,
 | 
			
		||||
                # As we successively shard the same dimension, we keep
 | 
			
		||||
                # advancing our pointer beyond our original offset until we
 | 
			
		||||
                # get to the final chunk start.
 | 
			
		||||
                global_offset[shard_dim] + not_none(shard_offset),
 | 
			
		||||
                shard_global_offset,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    # NOTE: the offset compute relies on the local shard index and it has no
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@ from typing import cast, Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed._functional_collectives as funcol
 | 
			
		||||
from torch.distributed._local_tensor import maybe_run_for_local_tensor
 | 
			
		||||
from torch.distributed.device_mesh import DeviceMesh
 | 
			
		||||
from torch.distributed.tensor._collective_utils import (
 | 
			
		||||
    fill_empty_tensor_to_shards,
 | 
			
		||||
@ -128,6 +129,7 @@ class Shard(Placement):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @maybe_run_for_local_tensor
 | 
			
		||||
    def local_shard_size_and_offset(
 | 
			
		||||
        curr_local_size: int,
 | 
			
		||||
        num_chunks: int,
 | 
			
		||||
@ -170,6 +172,20 @@ class Shard(Placement):
 | 
			
		||||
    ) -> tuple[int, Optional[int]]:
 | 
			
		||||
        return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @maybe_run_for_local_tensor
 | 
			
		||||
    def _maybe_unpad_tensor_with_sizes(
 | 
			
		||||
        dim, local_tensor, pad_sizes, mesh_dim_local_rank, make_contiguous
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        # Only unpad if the local_tensor was padded on the dimension.
 | 
			
		||||
        if pad_sizes[mesh_dim_local_rank] > 0:
 | 
			
		||||
            local_tensor = unpad_tensor(
 | 
			
		||||
                local_tensor, dim, pad_sizes[mesh_dim_local_rank]
 | 
			
		||||
            )
 | 
			
		||||
            if make_contiguous:
 | 
			
		||||
                local_tensor = local_tensor.contiguous()
 | 
			
		||||
        return local_tensor
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _make_shard_tensor(
 | 
			
		||||
        dim: int,
 | 
			
		||||
@ -198,24 +214,28 @@ class Shard(Placement):
 | 
			
		||||
                dim, tensor, num_chunks, with_padding=False, contiguous=True
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return scatter_list[mesh_dim_local_rank]
 | 
			
		||||
            return Shard._select_shard(scatter_list, mesh_dim_local_rank)
 | 
			
		||||
 | 
			
		||||
        scatter_list, pad_sizes = Shard._make_split_tensor(
 | 
			
		||||
            dim, tensor, num_chunks, with_padding=True, contiguous=True
 | 
			
		||||
        )
 | 
			
		||||
        output = torch.empty_like(scatter_list[mesh_dim_local_rank])
 | 
			
		||||
 | 
			
		||||
        it = iter(scatter_list)
 | 
			
		||||
        first = next(it)
 | 
			
		||||
        # Tensors in the scatter list are expected to have the same shape because
 | 
			
		||||
        # split is requested with padding.
 | 
			
		||||
        assert all(first.shape == v.shape for v in it)
 | 
			
		||||
 | 
			
		||||
        output = torch.empty_like(first)
 | 
			
		||||
 | 
			
		||||
        # perform scatter from the src_data_rank as data source when it is not None
 | 
			
		||||
        mesh_scatter(
 | 
			
		||||
            output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Only unpad if the local_tensor was padded on the dimension.
 | 
			
		||||
        if pad_sizes[mesh_dim_local_rank] > 0:
 | 
			
		||||
            output = unpad_tensor(output, dim, pad_sizes[mesh_dim_local_rank])
 | 
			
		||||
            # Unpad might return a view, hence we need to remake it contiguous
 | 
			
		||||
            output = output.contiguous()
 | 
			
		||||
        return output
 | 
			
		||||
        return Shard._maybe_unpad_tensor_with_sizes(
 | 
			
		||||
            dim, output, pad_sizes, mesh_dim_local_rank, True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _shard_tensor(
 | 
			
		||||
        self,
 | 
			
		||||
@ -245,6 +265,7 @@ class Shard(Placement):
 | 
			
		||||
            return tensor
 | 
			
		||||
 | 
			
		||||
        is_padded = tensor.size(self.dim) % num_chunks != 0
 | 
			
		||||
        pad_sizes = None
 | 
			
		||||
        if is_padded:
 | 
			
		||||
            scattered_list, pad_sizes = Shard._make_split_tensor(
 | 
			
		||||
                self.dim, tensor, num_chunks, with_padding=True, contiguous=True
 | 
			
		||||
@ -258,9 +279,47 @@ class Shard(Placement):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if is_padded:
 | 
			
		||||
            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
 | 
			
		||||
            assert pad_sizes is not None
 | 
			
		||||
            output = Shard._maybe_unpad_tensor_with_sizes(
 | 
			
		||||
                self.dim, output, pad_sizes, my_coordinate[mesh_dim], False
 | 
			
		||||
            )
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    @maybe_run_for_local_tensor
 | 
			
		||||
    def _maybe_pad_tensor(
 | 
			
		||||
        self,
 | 
			
		||||
        local_tensor: torch.Tensor,
 | 
			
		||||
        logical_dim_size: int,
 | 
			
		||||
        num_chunks: int,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        is_padded = logical_dim_size % num_chunks != 0
 | 
			
		||||
 | 
			
		||||
        if is_padded:
 | 
			
		||||
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
 | 
			
		||||
            pad_size = full_chunk_size - local_tensor.size(self.dim)
 | 
			
		||||
            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
 | 
			
		||||
 | 
			
		||||
        if not local_tensor.is_contiguous():
 | 
			
		||||
            local_tensor = local_tensor.contiguous()
 | 
			
		||||
 | 
			
		||||
        return local_tensor
 | 
			
		||||
 | 
			
		||||
    @maybe_run_for_local_tensor
 | 
			
		||||
    def _maybe_unpad_tensor(
 | 
			
		||||
        self,
 | 
			
		||||
        local_tensor: torch.Tensor,
 | 
			
		||||
        logical_dim_size: int,
 | 
			
		||||
        num_chunks: int,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        is_padded = logical_dim_size % num_chunks != 0
 | 
			
		||||
 | 
			
		||||
        if is_padded:
 | 
			
		||||
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
 | 
			
		||||
            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
 | 
			
		||||
            local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size)
 | 
			
		||||
 | 
			
		||||
        return local_tensor
 | 
			
		||||
 | 
			
		||||
    def _to_replicate_tensor(
 | 
			
		||||
        self,
 | 
			
		||||
        local_tensor: torch.Tensor,
 | 
			
		||||
@ -273,28 +332,27 @@ class Shard(Placement):
 | 
			
		||||
        is replicated on the previously sharded mesh dimension
 | 
			
		||||
        """
 | 
			
		||||
        num_chunks = mesh.size(mesh_dim=mesh_dim)
 | 
			
		||||
 | 
			
		||||
        logical_dim_size = current_logical_shape[self.dim]
 | 
			
		||||
        is_padded = logical_dim_size % num_chunks != 0
 | 
			
		||||
 | 
			
		||||
        if is_padded:
 | 
			
		||||
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
 | 
			
		||||
            pad_size = full_chunk_size - local_tensor.size(self.dim)
 | 
			
		||||
            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
 | 
			
		||||
 | 
			
		||||
        if not local_tensor.is_contiguous():
 | 
			
		||||
            local_tensor = local_tensor.contiguous()
 | 
			
		||||
        local_tensor = self._maybe_pad_tensor(
 | 
			
		||||
            local_tensor, logical_dim_size, num_chunks
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        result = funcol.all_gather_tensor(
 | 
			
		||||
            local_tensor,
 | 
			
		||||
            gather_dim=self.dim,
 | 
			
		||||
            group=(mesh, mesh_dim),
 | 
			
		||||
        )
 | 
			
		||||
        if is_padded:
 | 
			
		||||
            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
 | 
			
		||||
            result = unpad_tensor(result, self.dim, unpad_size)
 | 
			
		||||
 | 
			
		||||
        result = self._maybe_unpad_tensor(result, logical_dim_size, num_chunks)
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @maybe_run_for_local_tensor
 | 
			
		||||
    def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor:
 | 
			
		||||
        return shards[shard_index].clone()
 | 
			
		||||
 | 
			
		||||
    def _replicate_to_shard(
 | 
			
		||||
        self,
 | 
			
		||||
        local_tensor: torch.Tensor,
 | 
			
		||||
@ -313,7 +371,8 @@ class Shard(Placement):
 | 
			
		||||
            with_padding=False,
 | 
			
		||||
            contiguous=False,
 | 
			
		||||
        )
 | 
			
		||||
        return shards[shard_index].clone()
 | 
			
		||||
 | 
			
		||||
        return Shard._select_shard(shards, shard_index)
 | 
			
		||||
 | 
			
		||||
    def _to_new_shard_dim(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -41,6 +41,9 @@ class ConstantIntNode:
 | 
			
		||||
    def _graph_repr(self) -> str:
 | 
			
		||||
        return self._str()
 | 
			
		||||
 | 
			
		||||
    def add(self, other: Any) -> Any:
 | 
			
		||||
        return other.add(self)
 | 
			
		||||
 | 
			
		||||
    def mul(self, other: Any) -> Any:
 | 
			
		||||
        return other.mul(self)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,7 @@ import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch.distributed._local_tensor import LocalTensor
 | 
			
		||||
from torch.distributed.tensor import (
 | 
			
		||||
    DeviceMesh,
 | 
			
		||||
    distribute_tensor,
 | 
			
		||||
@ -660,7 +661,7 @@ class DTensorConverter:
 | 
			
		||||
    def to_dist_tensor(
 | 
			
		||||
        self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        if type(t) is torch.Tensor or type(t) is nn.Parameter:
 | 
			
		||||
        if type(t) is torch.Tensor or type(t) is nn.Parameter or type(t) is LocalTensor:
 | 
			
		||||
            if self.is_supported_tensor(t):
 | 
			
		||||
                self.hit += 1
 | 
			
		||||
                if t.ndim == 0:
 | 
			
		||||
@ -669,7 +670,7 @@ class DTensorConverter:
 | 
			
		||||
                else:
 | 
			
		||||
                    # distribute non-scalar tensors
 | 
			
		||||
                    r = distribute_tensor(t, mesh, placements)
 | 
			
		||||
                if type(t) is nn.Parameter:
 | 
			
		||||
                if isinstance(t, nn.Parameter):
 | 
			
		||||
                    r = nn.Parameter(  # type: ignore[assignment]
 | 
			
		||||
                        r, requires_grad=r.requires_grad
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
@ -1292,7 +1292,7 @@ SAC_IGNORED_OPS = {
 | 
			
		||||
    # With subclasses involved, these metadata ops become dispatchable, this
 | 
			
		||||
    # can result in incorrectness if these ops are selected cached.
 | 
			
		||||
    torch.ops.prim.device.default,
 | 
			
		||||
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
 | 
			
		||||
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)  # type: ignore[has-type]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _CachingTorchDispatchMode(TorchDispatchMode):
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user