Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-01 22:14:53 +08:00)
Compare commits: 1 commit, branch ciflow/tru...huba/local
Commit: b1648d4bc9
@ -3,13 +3,23 @@
|
||||
|
||||
import pathlib
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.distributed_c10d as c10d
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
from torch.distributed._local_tensor import (
|
||||
LocalIntNode,
|
||||
LocalTensorMode,
|
||||
maybe_run_for_local_tensor,
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
@ -44,6 +54,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
c10d_functional = torch.ops.c10d_functional
|
||||
|
||||
|
||||
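# NB: `maybe_run_for_local_tensor` is expected to run the wrapped helper once per
# simulated rank when given LocalTensor/LocalIntNode inputs (inferred from its usage
# in the LocalDTensorTest assertions below), so `func` receives the plain per-rank
# tensor and integer rank.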
@maybe_run_for_local_tensor
|
||||
def map_tensor_for_rank(tensor, rank, func):
|
||||
return func(tensor, rank)
|
||||
|
||||
|
||||
class DummyMLP(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
@ -592,7 +607,12 @@ class DTensorTest(DTensorTestBase):
|
||||
self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
|
||||
self.assertEqual(sharded_tensor.placements, placements)
|
||||
local_tensor = sharded_tensor.to_local()
|
||||
self.assertEqual(local_tensor, full_tensor[range(self.rank, self.rank + 1), :])
|
||||
self.assertEqual(
|
||||
local_tensor,
|
||||
map_tensor_for_rank(
|
||||
full_tensor, self.rank, lambda ft, r: ft[range(r, r + 1), :]
|
||||
),
|
||||
)
|
||||
|
||||
# Shard by column
|
||||
placements = [Shard(1)]
|
||||
@ -600,7 +620,12 @@ class DTensorTest(DTensorTestBase):
|
||||
self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
|
||||
self.assertEqual(sharded_tensor.placements, placements)
|
||||
local_tensor = sharded_tensor.to_local()
|
||||
self.assertEqual(local_tensor, full_tensor[:, range(self.rank, self.rank + 1)])
|
||||
self.assertEqual(
|
||||
local_tensor,
|
||||
map_tensor_for_rank(
|
||||
full_tensor, self.rank, lambda ft, r: ft[:, range(r, r + 1)]
|
||||
),
|
||||
)
|
||||
|
||||
# assert full tensor is not changed
|
||||
self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
|
||||
@ -620,6 +645,105 @@ class DTensorTest(DTensorTestBase):
|
||||
self.assertEqual(local_tensor.item(), self.rank)
|
||||
|
||||
|
||||
class LocalDTensorTest(DTensorTest):
|
||||
def get_local_tensor_mode(self):
|
||||
return LocalTensorMode(frozenset(range(0, self.world_size)))
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
|
||||
|
||||
@rank.setter
|
||||
def rank(self, rank):
|
||||
pass
|
||||
|
||||
def join_or_run(self, fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self):
|
||||
fn()
|
||||
|
||||
return types.MethodType(wrapper, self)
|
||||
|
||||
def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
|
||||
dist.init_process_group("fake", rank=0, world_size=self.world_size)
|
||||
self._pg = c10d._get_default_group()
|
||||
|
||||
def destroy_pg(self, device_id: Optional[int] = None) -> None:
|
||||
dist.destroy_process_group(self._pg)
|
||||
self._pg = None
|
||||
|
||||
def _spawn_processes(self) -> None:
|
||||
pass
|
||||
|
||||
def test_dtensor_constructor(self):
|
||||
pass
|
||||
|
||||
def test_meta_dtensor(self):
|
||||
pass
|
||||
|
||||
def test_modules_w_meta_dtensor(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_stride(self):
|
||||
pass
|
||||
|
||||
def test_from_local(self):
|
||||
pass
|
||||
|
||||
def test_from_local_uneven_sharding(self):
|
||||
pass
|
||||
|
||||
def test_from_local_uneven_sharding_raise_error(self):
|
||||
pass
|
||||
|
||||
def test_from_local_negative_dim(self):
|
||||
pass
|
||||
|
||||
def test_to_local(self):
|
||||
pass
|
||||
|
||||
def test_to_local_grad_hint(self):
|
||||
pass
|
||||
|
||||
def test_full_tensor_sync(self):
|
||||
pass
|
||||
|
||||
def test_full_tensor_grad_hint(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_new_empty_strided(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_async_output(self):
|
||||
pass
|
||||
|
||||
def test_from_local_then_to_local(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_spec_read_only_after_set(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_spec_hash(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_properties(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_save_load(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_save_load_import(self):
|
||||
pass
|
||||
|
||||
def test_shard_tensor_2d(self):
|
||||
with self.get_local_tensor_mode():
|
||||
super().test_shard_tensor_2d()
|
||||
|
||||
def test_shard_tensor(self):
|
||||
with self.get_local_tensor_mode():
|
||||
super().test_shard_tensor()
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
||||
test/distributed/tensor/test_dtensor_ops.py
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import copy
|
||||
import re
|
||||
import unittest
|
||||
import warnings
|
||||
@ -8,6 +9,7 @@ import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.testing._internal.common_methods_invocations as common_ops
|
||||
from torch.distributed._local_tensor import LocalTensorMode, reconcile_args
|
||||
from torch.distributed.tensor import (
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
@ -21,7 +23,7 @@ from torch.testing._internal.common_device_type import (
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
|
||||
from torch.testing._internal.common_utils import run_tests, suppress_warnings
|
||||
from torch.testing._internal.common_utils import run_tests, suppress_warnings, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorConverter,
|
||||
DTensorOpTestBase,
|
||||
@ -49,7 +51,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
|
||||
return (op_name, variant_name, device_type, dtypes, False)
|
||||
|
||||
|
||||
def skipOps(test_case_name, base_test_name, to_skip):
|
||||
def skipOps(op_db, test_case_name, base_test_name, to_skip):
|
||||
all_opinfos = op_db
|
||||
for xfail in to_skip:
|
||||
op_name, variant_name, device_type, dtypes, expected_failure = xfail
|
||||
@ -88,6 +90,34 @@ def skipOps(test_case_name, base_test_name, to_skip):
|
||||
return wrapped
|
||||
|
||||
|
||||
def repurpose_ops(op_db, base_test_name, derived_test_name):
|
||||
"""
|
||||
Copies the op info database and, for decorators that applied to the base test class, updates
them to apply to the derived test class. The class update is required because decorators are
applied only if the class name matches (base classes are not considered).

Specifically, we use this function to create two test classes (one for the multi-threaded
flavor and one for the local-tensor flavor) that share a common test body but have different
skip/xfail rules.
|
||||
|
||||
Args:
|
||||
op_db: List of OpInfo objects to be repurposed.
|
||||
base_test_name: The original test class name to be replaced.
|
||||
derived_test_name: The new test class name to set in decorators.
|
||||
|
||||
Returns:
|
||||
list: A new list of OpInfo objects with updated target class names for the
|
||||
decorator.
|
||||
"""
|
||||
repurposed_ops = []
|
||||
for opinfo in op_db:
|
||||
opinfo_copy = copy.deepcopy(opinfo)
|
||||
for decorator in list(opinfo_copy.decorators):
|
||||
if hasattr(decorator, "cls_name") and decorator.cls_name == base_test_name:
|
||||
decorator.cls_name = derived_test_name
|
||||
repurposed_ops.append(opinfo_copy)
|
||||
return repurposed_ops
|
||||
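# Illustrative usage (a sketch; mirrors how the test classes below consume this helper):
#
#   local_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestLocalDTensorOps")
#   # local_op_db now carries skip/xfail decorators retargeted at the derived class,
#   # so @ops(local_op_db, ...) and skipOps(...) apply them under the new class name.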
|
||||
|
||||
# To regenerate this failure list, turn on dry_run in the function below,
# check_dtensor_func(self, test, op, dry_run=True), then run something
# like: python test/distributed/tensor/test_dtensor_ops.py > failed.expect
|
||||
@ -162,7 +192,6 @@ dtensor_fails = {
|
||||
xfail("fmin"),
|
||||
xfail("frexp"),
|
||||
xfail("full"),
|
||||
xfail("full_like"),
|
||||
xfail("geometric"),
|
||||
xfail("geqrf"),
|
||||
xfail("grid_sampler_2d"),
|
||||
@ -226,7 +255,6 @@ dtensor_fails = {
|
||||
xfail("masked_select"),
|
||||
xfail("masked.argmax"),
|
||||
xfail("masked.argmin"),
|
||||
xfail("masked.cumprod"),
|
||||
xfail("masked.logsumexp"),
|
||||
xfail("masked.median"),
|
||||
xfail("matrix_exp"),
|
||||
@ -244,8 +272,6 @@ dtensor_fails = {
|
||||
xfail("native_batch_norm"),
|
||||
xfail("narrow_copy"),
|
||||
xfail("ne"),
|
||||
xfail("new_empty"),
|
||||
xfail("new_empty_strided"),
|
||||
xfail("transpose"),
|
||||
xfail("nn.functional.adaptive_avg_pool1d"),
|
||||
xfail("nn.functional.adaptive_avg_pool2d"),
|
||||
@ -272,8 +298,6 @@ dtensor_fails = {
|
||||
xfail("nn.functional.cosine_similarity"),
|
||||
xfail("nn.functional.ctc_loss"),
|
||||
xfail("nn.functional.dropout"),
|
||||
xfail("nn.functional.dropout2d"),
|
||||
xfail("nn.functional.dropout3d"),
|
||||
xfail("nn.functional.elu"),
|
||||
xfail("nn.functional.fractional_max_pool2d"),
|
||||
xfail("nn.functional.fractional_max_pool3d"),
|
||||
@ -307,7 +331,6 @@ dtensor_fails = {
|
||||
xfail("nn.functional.multi_margin_loss"),
|
||||
xfail("nn.functional.multilabel_margin_loss"),
|
||||
xfail("nn.functional.multilabel_soft_margin_loss"),
|
||||
xfail("nn.functional.multi_head_attention_forward"),
|
||||
xfail("nn.functional.pad", "reflect"),
|
||||
xfail("nn.functional.pad", "replicate"),
|
||||
xfail("nn.functional.pad", "replicate_negative"),
|
||||
@ -482,13 +505,21 @@ dtensor_fails = {
|
||||
skip("_segment_reduce", "offsets"),
|
||||
# TODO: fix the following ops
|
||||
skip("squeeze"),
|
||||
# These must be skipped as their contents are nondeterministic
|
||||
skip("empty"),
|
||||
skip("empty_strided"),
|
||||
skip("empty_like"),
|
||||
skip("empty_permuted"),
|
||||
skip("new_empty"),
|
||||
skip("new_empty_strided"),
|
||||
}
|
||||
|
||||
dtensor_multi_threaded_fails = {
|
||||
xfail("full_like"),
|
||||
xfail("nn.functional.dropout2d"),
|
||||
xfail("nn.functional.dropout3d"),
|
||||
xfail("masked.cumprod"),
|
||||
skip("nn.functional.multi_head_attention_forward"),
|
||||
}
|
||||
|
||||
# A list of ops that are currently failing the BW pass
|
||||
skip_bw = [
|
||||
@ -507,7 +538,13 @@ OP_DB_WORLD_SIZE = 4
|
||||
DEVICE_TYPE = "cpu"
|
||||
|
||||
|
||||
class TestDTensorOps(DTensorOpTestBase):
|
||||
class TestDTensorOps(TestCase):
|
||||
__test__ = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls.__test__ = True
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return OP_DB_WORLD_SIZE
|
||||
@ -535,14 +572,6 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
|
||||
self.check_dtensor_func(test, op)
|
||||
|
||||
# only allow float dtype for now; we can relax this constraint
# when we feel it necessary later (i.e., when adding quantization support).
|
||||
@suppress_warnings
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
|
||||
def test_dtensor_op_db(self, dtype, op):
|
||||
self.run_opinfo_test(dtype, op)
|
||||
|
||||
def assert_ref_dtensor_equal(self, dtensor_rs, rs):
|
||||
flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
|
||||
flat_rs = pytree.tree_leaves(rs)
|
||||
@ -567,6 +596,9 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
|
||||
self.assertEqualOnRank(dtensor_r, r)
|
||||
|
||||
def assertEqualOnRank(self, x, y, msg=None, *, rank=0) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def run_dtensor_crossref(self, func, args, kwargs):
|
||||
to_dtensor = DTensorConverter(self.mesh, args, kwargs)
|
||||
|
||||
@ -580,7 +612,8 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
return res
|
||||
|
||||
# TODO: also handle cases where func raise an exception
|
||||
rs = func(*args, **kwargs)
|
||||
op_args, op_kwargs = reconcile_args(args, kwargs)
|
||||
rs = func(*op_args, **op_kwargs)
|
||||
rs = concat_res_if_necessary(func, rs)
|
||||
|
||||
def to_replicate(e: object) -> object:
|
||||
@ -635,12 +668,12 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
self.assert_ref_dtensor_equal(dtensor_rs, rs)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"failed to convert args to DTensor; "
|
||||
f"Failed to convert args to DTensor; "
|
||||
f"originally (*{args}, **{kwargs})"
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
|
||||
f"{str(e)}\n\nFailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
|
||||
) from e
|
||||
return rs
|
||||
|
||||
@ -656,7 +689,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
else:
|
||||
print(f"xfail('{opinfo.name}'),")
|
||||
|
||||
def test_one_hot(self):
|
||||
def run_one_hot(self):
|
||||
ops = [op for op in op_db if op.name == "nn.functional.one_hot"]
|
||||
assert len(ops) == 1
|
||||
op = ops[0]
|
||||
@ -668,7 +701,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
sample_inputs_filter=lambda s: s.kwargs["num_classes"] != -1,
|
||||
)
|
||||
|
||||
def test_mean(self):
|
||||
def run_mean(self):
|
||||
self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
|
||||
|
||||
shape = [2 * self.world_size + 1, 2 * self.world_size]
|
||||
@ -692,6 +725,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
full_tensor = mean.full_tensor()
|
||||
|
||||
self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))
|
||||
|
||||
if is_evenly_shardable:
|
||||
self.assertTrue("P->R" in debug_mode.debug_string())
|
||||
else:
|
||||
@ -720,9 +754,76 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
_ = torch.ops.aten.embedding.default(weight_dtensor, input_dtensor)
|
||||
|
||||
|
||||
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
|
||||
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))
|
||||
class TestMultiThreadedDTensorOps(DTensorOpTestBase, TestDTensorOps):
|
||||
_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestMultiThreadedDTensorOps")
|
||||
|
||||
@suppress_warnings
|
||||
@ops(_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps(
|
||||
_op_db,
|
||||
"TestMultiThreadedDTensorOps",
|
||||
"test_dtensor_op_db",
|
||||
dtensor_fails | dtensor_multi_threaded_fails,
|
||||
)
|
||||
def test_dtensor_op_db(self, dtype, op):
|
||||
self.run_opinfo_test(dtype, op)
|
||||
|
||||
def test_mean(self):
|
||||
self.run_mean()
|
||||
|
||||
def test_one_hot(self):
|
||||
self.run_one_hot()
|
||||
|
||||
|
||||
class TestLocalDTensorOps(TestDTensorOps):
|
||||
_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestLocalDTensorOps")
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.distributed.init_process_group("fake", rank=0, world_size=self.world_size)
|
||||
self.fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
try:
|
||||
dist.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
@suppress_warnings
|
||||
@ops(_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps(
|
||||
_op_db,
|
||||
"TestLocalDTensorOps",
|
||||
"test_dtensor_op_db",
|
||||
dtensor_fails,
|
||||
)
|
||||
def test_dtensor_op_db(self, dtype, op):
|
||||
self.run_opinfo_test(dtype, op)
|
||||
|
||||
def test_mean(self):
|
||||
with LocalTensorMode(frozenset(range(0, self.world_size))):
|
||||
self.run_mean()
|
||||
|
||||
def test_one_hot(self):
|
||||
self.run_one_hot()
|
||||
|
||||
def run_opinfo_test(
|
||||
self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True
|
||||
):
|
||||
with LocalTensorMode(frozenset(range(0, self.world_size))):
|
||||
super().run_opinfo_test(dtype, op, requires_grad, sample_inputs_filter)
|
||||
|
||||
def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
|
||||
self.assertEqual(x, y, msg)
|
||||
|
||||
|
||||
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
|
||||
instantiate_device_type_tests(
|
||||
TestMultiThreadedDTensorOps, globals(), only_for=(DEVICE_TYPE,)
|
||||
)
|
||||
|
||||
instantiate_device_type_tests(TestLocalDTensorOps, globals(), only_for=(DEVICE_TYPE,))
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
test/distributed/test_local_tensor.py (new file, 415 lines)
@ -0,0 +1,415 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._local_tensor import (
|
||||
local_tensor_mode,
|
||||
LocalTensor,
|
||||
LocalTensorMode,
|
||||
)
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
init_device_mesh,
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class LocalTensorTestBase(TestCase):
|
||||
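# assertEqual that understands LocalTensor: two LocalTensors are compared shard-by-shard
# per rank; a LocalTensor and a plain tensor are compared by checking every shard against
# that tensor; anything else falls back to TestCase.assertEqual.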
def assertEqual(self, lhs, rhs, **kwargs):
|
||||
mode = local_tensor_mode()
|
||||
with nullcontext() if mode is None else mode.disable():
|
||||
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
|
||||
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
|
||||
super().assertEqual(lhs._ranks, rhs._ranks)
|
||||
for r in lhs._ranks:
|
||||
super().assertEqual(
|
||||
lhs._local_tensors[r],
|
||||
rhs._local_tensors[r],
|
||||
lambda m: f"rank {r}: {m}",
|
||||
)
|
||||
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
|
||||
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
|
||||
for r in lhs._ranks:
|
||||
super().assertEqual(
|
||||
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
|
||||
)
|
||||
else:
|
||||
return super().assertEqual(lhs, rhs, **kwargs)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
raise NotImplementedError("override world-size in your subclass")
|
||||
|
||||
def build_device_mesh(self) -> DeviceMesh:
|
||||
return init_device_mesh("cpu", (self.world_size,))
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch.distributed.init_process_group(
|
||||
# TODO: test other ranks too
|
||||
"fake",
|
||||
rank=0,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
try:
|
||||
dist.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
|
||||
class TestLocalTensorWorld2(LocalTensorTestBase):
|
||||
world_size = 2
|
||||
|
||||
def test_local_tensor_dtype_consistency(self):
|
||||
"""Test that LocalTensor enforces dtype consistency."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
|
||||
inconsistent_tensors = {
|
||||
0: torch.randn(shape, dtype=torch.float32, device=device),
|
||||
1: torch.randn(
|
||||
shape, dtype=torch.float64, device=device
|
||||
), # Different dtype
|
||||
}
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
LocalTensor(inconsistent_tensors)
|
||||
|
||||
def test_local_tensor_creation_fails_with_grad_tensors(self):
|
||||
"""Test that LocalTensor creation fails when local tensors have requires_grad=True."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create sample local tensors for different ranks
|
||||
local_tensors = {
|
||||
0: torch.randn(shape, dtype=dtype, device=device, requires_grad=True),
|
||||
1: torch.randn(shape, dtype=dtype, device=device, requires_grad=True),
|
||||
}
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
LocalTensor(local_tensors)
|
||||
|
||||
# TODO: test flatten/unflatten
|
||||
|
||||
def test_basic_arithmetic_operations(self):
|
||||
"""Test basic arithmetic operations on LocalTensors."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create identical local tensors for consistency tests
|
||||
base_tensor = torch.randn(shape, dtype=dtype, device=device)
|
||||
identical_local_tensors = {
|
||||
0: base_tensor.clone(),
|
||||
1: base_tensor.clone(),
|
||||
}
|
||||
|
||||
lt1 = LocalTensor(identical_local_tensors)
|
||||
lt2 = LocalTensor(identical_local_tensors)
|
||||
|
||||
# Test addition
|
||||
result_add = lt1 + lt2
|
||||
self.assertIsInstance(result_add, LocalTensor)
|
||||
self.assertEqual(len(result_add._local_tensors), 2)
|
||||
|
||||
# Verify the operation was applied to each local tensor
|
||||
for rank in identical_local_tensors.keys():
|
||||
expected = identical_local_tensors[rank] + identical_local_tensors[rank]
|
||||
self.assertEqual(result_add._local_tensors[rank], expected)
|
||||
|
||||
# Test multiplication
|
||||
result_mul = lt1 * 2.0
|
||||
self.assertIsInstance(result_mul, LocalTensor)
|
||||
for rank in identical_local_tensors.keys():
|
||||
expected = identical_local_tensors[rank] * 2.0
|
||||
self.assertEqual(result_mul._local_tensors[rank], expected)
|
||||
|
||||
# TODO: consider an op-info test; we don't actually need to cover all ops
|
||||
# but it will help make sure views and more exotic things are done
|
||||
# correctly (in standard subclass style)
|
||||
|
||||
def test_mixed_operations_with_regular_tensors(self):
|
||||
"""Test operations between LocalTensors and regular tensors."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create identical local tensors for consistency tests
|
||||
base_tensor = torch.randn(shape, dtype=dtype, device=device)
|
||||
identical_local_tensors = {
|
||||
0: base_tensor.clone(),
|
||||
1: base_tensor.clone(),
|
||||
}
|
||||
|
||||
lt = LocalTensor(identical_local_tensors)
|
||||
regular_tensor = torch.ones_like(identical_local_tensors[0])
|
||||
|
||||
# Test LocalTensor + regular tensor
|
||||
result = lt + regular_tensor
|
||||
self.assertIsInstance(result, LocalTensor)
|
||||
|
||||
for rank in identical_local_tensors.keys():
|
||||
expected = identical_local_tensors[rank] + regular_tensor
|
||||
self.assertEqual(result._local_tensors[rank], expected)
|
||||
|
||||
def test_local_tensor_mode(self):
|
||||
"""Test LocalTensorMode functionality."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create identical local tensors for consistency tests
|
||||
base_tensor = torch.randn(shape, dtype=dtype, device=device)
|
||||
identical_local_tensors = {
|
||||
0: base_tensor.clone(),
|
||||
1: base_tensor.clone(),
|
||||
}
|
||||
|
||||
lt = LocalTensor(identical_local_tensors)
|
||||
|
||||
with LocalTensorMode(lt._ranks):
|
||||
result = lt + 1.0
|
||||
self.assertIsInstance(result, LocalTensor)
|
||||
|
||||
regular = torch.ones(2, 2)
|
||||
regular_result = regular + 1.0
|
||||
self.assertIsInstance(regular, LocalTensor)
|
||||
self.assertIsInstance(regular_result, LocalTensor)
|
||||
|
||||
def test_empty_local_tensors(self):
|
||||
"""Test behavior with empty local tensors dict."""
|
||||
# TODO: raise a better error here
|
||||
with self.assertRaises(StopIteration): # next() on empty iterator
|
||||
LocalTensor({})
|
||||
|
||||
def test_collectives_within_local_tensor_mode(self):
|
||||
"""Test that collective operations work within LocalTensorMode context."""
|
||||
test_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
|
||||
1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
|
||||
}
|
||||
lt = LocalTensor(test_tensors)
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
with LocalTensorMode(lt._ranks):
|
||||
# Test all_reduce within mode
|
||||
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_sum, group=fake_pg)
|
||||
|
||||
expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]])
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
|
||||
|
||||
# Test broadcast within mode
|
||||
lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.broadcast(lt_broadcast, src=0, group=fake_pg)
|
||||
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0])
|
||||
|
||||
# Test that regular operations still work
|
||||
result = lt + 1.0
|
||||
self.assertIsInstance(result, LocalTensor)
|
||||
|
||||
def test_scalar_mul_reduction_bug(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
tensor = torch.tensor([10, 10]).float()
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
y = dt.sum() * 1 # noqa: F841
|
||||
|
||||
tensor = torch.arange(10).reshape(10, 1).float().requires_grad_()
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
print(dt.sum() * 1, dt.sum() * 2, dt.sum() * 3)
|
||||
|
||||
def test_uneven_sharding_mean_bug(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
tensor = torch.arange(12).reshape(-1, 4).float()
|
||||
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
mean = dt.mean()
|
||||
self.assertEqual(mean.placements, [Replicate()])
|
||||
full = mean.full_tensor()
|
||||
self.assertEqual(tensor.mean(), full)
|
||||
|
||||
def test_uneven_sharding_prod(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
tensor = (torch.arange(12) + 1).reshape(-1, 4).float()
|
||||
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
x = dt.prod()
|
||||
full = x.full_tensor()
|
||||
self.assertEqual(tensor.prod(), full)
|
||||
|
||||
def test_even_sharding_mean_is_partial(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
tensor = torch.arange(16).reshape(4, 4).float()
|
||||
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
mean = dt.mean()
|
||||
full = mean.full_tensor()
|
||||
self.assertEqual(tensor.mean(), full)
|
||||
self.assertEqual(mean.placements, [Partial("avg")])
|
||||
|
||||
|
||||
class TestLocalTensorWorld3(LocalTensorTestBase):
|
||||
world_size = 3
|
||||
|
||||
def test_collective_reduction_operations(self):
|
||||
"""Test different reduction operations for all_reduce."""
|
||||
# Create different tensors for each rank with simple values for testing
|
||||
test_tensors = {
|
||||
0: torch.tensor([[1.0, 4.0], [2.0, 5.0]]),
|
||||
1: torch.tensor([[2.0, 1.0], [3.0, 6.0]]),
|
||||
2: torch.tensor([[3.0, 2.0], [1.0, 4.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test SUM reduction
|
||||
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg)
|
||||
expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
|
||||
|
||||
# Test MAX reduction
|
||||
lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg)
|
||||
expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) # Max across all tensors
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_max._local_tensors[rank], expected_max)
|
||||
|
||||
# Test MIN reduction
|
||||
lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg)
|
||||
expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_min._local_tensors[rank], expected_min)
|
||||
|
||||
def test_all_reduce_collective(self):
|
||||
"""Test that all_reduce collective operation works correctly with LocalTensor."""
|
||||
# Create different tensors for each rank
|
||||
different_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
|
||||
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test all_reduce with SUM (default)
|
||||
lt_sum = LocalTensor({k: v.clone() for k, v in different_tensors.items()})
|
||||
lt_sum = lt_sum + 1
|
||||
dist.all_reduce(lt_sum, group=fake_pg)
|
||||
|
||||
# Verify all ranks have the sum of all tensors (after adding 1 to each)
|
||||
expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]])
|
||||
for rank in different_tensors.keys():
|
||||
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
|
||||
|
||||
def test_broadcast_collective(self):
|
||||
"""Test that broadcast collective operation works correctly with LocalTensor."""
|
||||
# Create different tensors for each rank
|
||||
different_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
|
||||
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test broadcast from rank 1
|
||||
lt_broadcast = LocalTensor({k: v.clone() for k, v in different_tensors.items()})
|
||||
dist.broadcast(lt_broadcast, src=1, group=fake_pg)
|
||||
|
||||
# Verify all ranks have rank 1's original tensor
|
||||
expected_broadcast = different_tensors[1]
|
||||
for rank in different_tensors.keys():
|
||||
self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast)
|
||||
|
||||
def test_all_gather_collective(self):
|
||||
"""Test that all_gather collective operation works correctly with LocalTensor."""
|
||||
# Create different tensors for each rank
|
||||
different_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
|
||||
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test all_gather
|
||||
lt_gather = LocalTensor(different_tensors)
|
||||
tensor_list = [torch.zeros_like(lt_gather) for _ in range(3)]
|
||||
|
||||
dist.all_gather(tensor_list, lt_gather, group=fake_pg)
|
||||
|
||||
# Verify each position in tensor_list contains the corresponding rank's tensor
|
||||
self.assertEqual(tensor_list[0], different_tensors[0])
|
||||
self.assertEqual(tensor_list[1], different_tensors[1])
|
||||
self.assertEqual(tensor_list[2], different_tensors[2])
|
||||
|
||||
|
||||
class TestLocalTensorWorld4(LocalTensorTestBase):
|
||||
world_size = 4
|
||||
|
||||
def test_dtensor_cat(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
t1 = torch.arange(16).view(4, 4).float()
|
||||
d1 = distribute_tensor(t1, device_mesh, [Replicate()])
|
||||
t2 = (torch.arange(16) + 16).view(4, 4).float()
|
||||
d2 = distribute_tensor(t2, device_mesh, [Shard(0)])
|
||||
|
||||
local_res = torch.cat([t1, t2], dim=-1)
|
||||
dist_res = torch.cat([d1, d2], dim=-1)
|
||||
full_tensor = dist_res.full_tensor()
|
||||
self.assertEqual(full_tensor, local_res)
|
||||
|
||||
|
||||
class TestLocalTensorWorld8(LocalTensorTestBase):
|
||||
world_size = 8
|
||||
|
||||
def test_dtensor_addmm(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
shard_spec = [Shard(0)]
|
||||
replica_spec = [Replicate()]
|
||||
|
||||
tensor_to_shard = torch.randn(12, 8)
|
||||
mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
|
||||
tensor_to_replicate = torch.randn(8, 4)
|
||||
mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
|
||||
input_tensor = torch.randn(4)
|
||||
input = distribute_tensor(input_tensor, device_mesh, replica_spec)
|
||||
|
||||
dist_res = torch.addmm(input, mat1, mat2)
|
||||
local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
|
||||
full_tensor = dist_res.full_tensor()
|
||||
self.assertEqual(full_tensor, local_res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -7915,9 +7915,13 @@ torch.cuda.synchronize()
|
||||
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
|
||||
if dtype is torch.bool
|
||||
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
||||
(
|
||||
torch.randint(
|
||||
2, (n, *post_seq_len_shape), device=device, dtype=dtype
|
||||
)
|
||||
if dtype is torch.bool
|
||||
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
||||
)
|
||||
for n in range(2, 9)
|
||||
],
|
||||
layout=torch.jagged,
|
||||
@ -7966,9 +7970,13 @@ torch.cuda.synchronize()
|
||||
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
|
||||
if dtype is torch.bool
|
||||
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
||||
(
|
||||
torch.randint(
|
||||
2, (n, *post_seq_len_shape), device=device, dtype=dtype
|
||||
)
|
||||
if dtype is torch.bool
|
||||
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
||||
)
|
||||
for n in range(2, 9)
|
||||
],
|
||||
layout=torch.jagged,
|
||||
@ -8713,7 +8721,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
|
||||
# min() / max(): weird bug
|
||||
XFailRule(
|
||||
error_type=AttributeError,
|
||||
error_msg="'ConstantIntNode' object has no attribute 'add'",
|
||||
error_msg="'NestedIntNode' object has no attribute 'add'",
|
||||
op_match_fn=lambda device, op: (
|
||||
op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
|
||||
),
|
||||
@ -8730,7 +8738,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
|
||||
# copysign(): formula is broken for (T, NT) broadcasting
|
||||
XFailRule(
|
||||
error_type=AttributeError,
|
||||
error_msg="'ConstantIntNode' object has no attribute 'add'",
|
||||
error_msg="'NestedIntNode' object has no attribute 'add'",
|
||||
op_match_fn=lambda device, op: (op.full_name == "copysign"),
|
||||
sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
|
||||
name="broken_copysign_compile_backward",
|
||||
|
||||
@ -15,7 +15,11 @@ TORCH_LIBRARY(c10d, m) {
|
||||
m.class_<Work>("Work")
|
||||
.def(torch::init<>())
|
||||
.def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
|
||||
m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
|
||||
m.class_<ReduceOp>("ReduceOp")
|
||||
.def(torch::init<>())
|
||||
.def("op", [](const c10::intrusive_ptr<ReduceOp>& self) -> int64_t {
|
||||
return self->op_;
|
||||
});
|
||||
m.def(
|
||||
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
|
||||
torch/distributed/_local_tensor/__init__.py (new file, 747 lines)
@ -0,0 +1,747 @@
|
||||
|
||||
|
||||
|
||||
"""
|
||||
A LocalTensor is a tensor subclass which simulates a tensor that is
|
||||
distributed across SPMD ranks. A LocalTensor might be size N, but in fact
|
||||
there are world_size shards/replicas of it stored internally. When you do a
|
||||
plain PyTorch operation on it, we apply the operation to each shard; when you
|
||||
do a collective, we do the mathematically equivalent operation on the local
|
||||
shards. A LocalTensor is associated with a set of ranks and holds
one local tensor for each of those ranks.
|
||||
|
||||
NB: this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
autograd needs multithreading to keep up!). (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs,
but this is currently a non-goal.)

We do not directly handle MPMD. However, in practice, even in SPMD you may
encounter divergence in behavior per rank (for example, uneven sharding
across ranks). To support scenarios like this, we provide a helper decorator
that lets you run a side-effect-free function for each LocalTensor
shard and combine the results back into a LocalTensor or LocalIntNode.
|
||||
|
||||
NB: This is a torch dispatch Tensor subclass, as we want to assume that autograd
|
||||
is SPMD, so we run it once, and dispatch the inner autograd calls to the individual
|
||||
local shards.
|
||||
|
||||
NOTE ABOUT MESH: This subclass requires collectives that are issued to it to
|
||||
respect a DeviceMesh like abstraction. The reason for this is that when
|
||||
DTensor issues us a collective for a particular rank, you will be asked to do
|
||||
this on a specific process group which involves some ranks. However, this
|
||||
will only be for the LOCAL PG that this particular rank is participating in;
|
||||
there will be a bunch of other PGs for other nodes that you don't get to see.
|
||||
We need to be able to reverse engineer all of the collectives that don't
|
||||
involve the current local rank here in order to actually issue them. This can be done
in two ways: (1) looking at the participating local ranks in the PG and computing
the complement, which specifies all the other collectives you have to run, or
(2) retrieving the device mesh axis corresponding to the PG for this rank and
then running all the fibers for it.
|
||||
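Illustrative example (a sketch distilled from test/distributed/test_local_tensor.py,
not part of the original notes; the concrete values are arbitrary):

    import torch
    import torch.distributed as dist
    from torch.distributed._local_tensor import LocalTensor, LocalTensorMode

    dist.init_process_group("fake", rank=0, world_size=2)
    # One shard per simulated rank; shards may differ (e.g. uneven sharding).
    lt = LocalTensor({0: torch.ones(2, 2), 1: torch.full((2, 2), 2.0)})
    with LocalTensorMode(lt._ranks):
        out = lt + 1.0        # plain op: applied independently to each rank's shard
        dist.all_reduce(out)  # collective: every shard becomes the sum across ranks
    dist.destroy_process_group()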
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Size, SymBool, SymInt, Tensor
|
||||
from torch._C import DispatchKey, DispatchKeySet
|
||||
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
|
||||
from torch.distributed import DeviceMesh
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
from torch.fx.experimental._constant_symnode import ConstantIntNode
|
||||
from torch.nested._internal.nested_int import NestedIntNode
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
|
||||
from torch.utils.checkpoint import get_device_states, set_device_states
|
||||
|
||||
|
||||
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
|
||||
|
||||
|
||||
from . import _c10d
|
||||
|
||||
|
||||
def _int_on_rank(i: "LocalIntNode | ConstantIntNode", r: int) -> int:
|
||||
if isinstance(i, LocalIntNode):
|
||||
return i._local_ints[r]
|
||||
elif isinstance(i, ConstantIntNode):
|
||||
return i.val
|
||||
else:
|
||||
raise AssertionError(type(i))
|
||||
|
||||
|
||||
def _check_for_subclass(flat_args: Sequence[object]) -> bool:
|
||||
return any(_check_for_subclass_arg(x) for x in flat_args)
|
||||
|
||||
|
||||
def _check_for_subclass_arg(x: object) -> bool:
|
||||
return (
|
||||
not isinstance(x, LocalTensor)
|
||||
and isinstance(x, Tensor)
|
||||
and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
|
||||
)
|
||||
|
||||
|
||||
def _map_to_rank_local_val(val: Any, rank: int) -> Any:
|
||||
if isinstance(val, LocalTensor):
|
||||
return val._local_tensors[rank]
|
||||
if isinstance(val, SymInt) and isinstance(val.node, LocalIntNode):
|
||||
return val.node._local_ints[rank]
|
||||
return val
|
||||
|
||||
|
||||
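# Runs `func` once per rank: LocalTensor/LocalIntNode arguments are replaced by that
# rank's local values (with CPU/device RNG state reset per rank so randomness matches),
# and the per-rank results are recombined into a LocalTensor, a SymInt backed by a
# LocalIntNode, or a plain value when all ranks agree.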
def _for_each_rank_run_func(
|
||||
func: Callable[..., Any],
|
||||
ranks: frozenset[int],
|
||||
args: Sequence[Any],
|
||||
kwargs: dict[str, Any],
|
||||
*,
|
||||
alias: bool = True,
|
||||
) -> Any:
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
cpu_state = torch.get_rng_state()
|
||||
devices, states = get_device_states((args, kwargs))
|
||||
|
||||
flat_rank_rets = {}
|
||||
|
||||
for r in sorted(ranks):
|
||||
torch.set_rng_state(cpu_state)
|
||||
set_device_states(devices, states)
|
||||
rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
|
||||
rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
|
||||
rank_ret = func(*rank_args, **rank_kwargs)
|
||||
flat_rank_rets[r] = rank_ret
|
||||
|
||||
rr_key = next(iter(flat_rank_rets.keys()))
|
||||
rr_val = flat_rank_rets[rr_key]
|
||||
|
||||
if isinstance(rr_val, Tensor):
|
||||
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
|
||||
elif isinstance(rr_val, (list, tuple)):
|
||||
ret_list = []
|
||||
for i in range(len(rr_val)):
|
||||
rets = {r: flat_rank_rets[r][i] for r in sorted(ranks)}
|
||||
v_it = iter(rets.values())
|
||||
v = next(v_it)
|
||||
if isinstance(v, Tensor):
|
||||
ret_list.append(LocalTensor(rets))
|
||||
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
|
||||
ret_list.append(torch.SymInt(LocalIntNode(rets)))
|
||||
else:
|
||||
assert all(v == v2 for v2 in v_it)
|
||||
ret_list.append(v)
|
||||
ret = type(rr_val)(ret_list)
|
||||
else:
|
||||
v_it = iter(flat_rank_rets.values())
|
||||
v = next(v_it)
|
||||
if all(v == v2 for v2 in v_it):
|
||||
return v
|
||||
if isinstance(v, int):
|
||||
return torch.SymInt(LocalIntNode(flat_rank_rets))
|
||||
raise AssertionError(f"Unexpected return type {type(v)}")
|
||||
|
||||
if alias:
|
||||
return return_and_correct_aliasing(func, args, kwargs, ret)
|
||||
else:
|
||||
return ret
|
||||
|
||||
|
||||
def _get_extra_dispatch_keys(t: torch.Tensor) -> DispatchKeySet:
|
||||
extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0)
|
||||
if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Conjugate):
|
||||
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Conjugate)
|
||||
if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Negative):
|
||||
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative)
|
||||
return extra_dispatch_keys
|
||||
|
||||
|
||||
class LocalIntNode:
|
||||
"""
|
||||
Like a LocalTensor, but for an int. We can't use a 0D tensor to represent this
|
||||
because often only a SymInt is accepted where we wish to use this.
|
||||
"""
|
||||
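# Illustrative example (a sketch): LocalIntNode({0: 2, 1: 3}) models an int whose value
# differs per rank, e.g. an uneven shard size; note that __new__ below collapses to a
# ConstantIntNode whenever every rank holds the same value.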
|
||||
def __new__(cls, local_ints: dict[int, int]) -> "ConstantIntNode | LocalIntNode": # type: ignore[misc]
|
||||
if len(set(local_ints.values())) == 1:
|
||||
return ConstantIntNode(next(iter(local_ints.values())))
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, local_ints: dict[int, int]):
|
||||
self._local_ints = local_ints
|
||||
|
||||
def maybe_as_int(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
def is_int(self) -> bool:
|
||||
return True
|
||||
|
||||
def is_float(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_bool(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_nested_int(self) -> bool:
|
||||
return False
|
||||
|
||||
def clone(self) -> "LocalIntNode":
|
||||
return self
|
||||
|
||||
def _str(self) -> str:
|
||||
return f"LocalIntNode({self._local_ints})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def _graph_repr(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def is_symbolic(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_constant(self) -> bool:
|
||||
return False
|
||||
|
||||
def sym_max(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{
|
||||
r: max(self._local_ints[r], _int_on_rank(other, r))
|
||||
for r in self._local_ints
|
||||
}
|
||||
)
|
||||
|
||||
def add(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints}
|
||||
)
|
||||
|
||||
def sub(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints}
|
||||
)
|
||||
|
||||
def mul(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
|
||||
)
|
||||
|
||||
def eq(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
|
||||
r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints}
|
||||
return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r)))
|
||||
|
||||
def gt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
|
||||
r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
|
||||
assert len(r) == 1, (self, other)
|
||||
return torch._C._get_constant_bool_symnode(next(iter(r)))
|
||||
|
||||
def lt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
|
||||
r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints}
|
||||
assert len(r) == 1, (self, other)
|
||||
return torch._C._get_constant_bool_symnode(next(iter(r)))
|
||||
|
||||
def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode":
|
||||
return ConstantIntNode(num)
|
||||
|
||||
|
||||
class LocalTensor(torch.Tensor):
|
||||
"""
|
||||
LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
|
||||
(Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from
|
||||
global rank ids to their corresponding local Tensor shards. Operations performed on a LocalTensor
|
||||
are applied independently to each local shard, mimicking distributed computation. Collectives
|
||||
and other distributed operations are handled by mapping them to the local shards as appropriate.
|
||||
|
||||
Note:
|
||||
This class is primarily intended for debugging and simulating distributed tensor computations
|
||||
on a single process.
|
||||
|
||||
"""
|
||||
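# Illustrative construction (a sketch based on test/distributed/test_local_tensor.py):
# LocalTensor({0: t0, 1: t1}) wraps one tensor per rank; ordinary ops such as `lt + 1.0`
# are applied to each shard, and `lt.reconcile()` returns a single tensor once all
# shards agree.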
|
||||
# Map from global rank to the local tensor.
|
||||
_local_tensors: dict[int, torch.Tensor]
|
||||
# Set of keys from the local tensor map, precomputed for speed.
|
||||
_ranks: frozenset[int]
|
||||
__slots__ = ["_local_tensors", "_ranks"]
|
||||
|
||||
@staticmethod
|
||||
@torch._disable_dynamo
|
||||
def __new__(
|
||||
cls,
|
||||
local_tensors: dict[int, torch.Tensor],
|
||||
) -> "LocalTensor":
|
||||
if any(t.requires_grad for t in local_tensors.values()):
|
||||
raise AssertionError(
|
||||
"Internal local_tensors require grad, but we will ignore those autograd graph. "
|
||||
"Make a custom autograd function and make sure you detach the inner tensors."
|
||||
)
|
||||
|
||||
it = iter(local_tensors.values())
|
||||
first_local_tensor = next(it)
|
||||
|
||||
first_shape = first_local_tensor.shape
|
||||
first_stride = first_local_tensor.stride()
|
||||
dtype = first_local_tensor.dtype
|
||||
device = first_local_tensor.device
|
||||
layout = first_local_tensor.layout
|
||||
|
||||
extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
|
||||
|
||||
# Assert that all tensors have the same dtype, layout and dispatch keys. Due
|
||||
# to uneven sharding, it is possible that tensors will have different shapes.
|
||||
for local_tensor in it:
|
||||
assert dtype == local_tensor.dtype, (
|
||||
"Tensors representing LocalTensor shards must have the same dtype"
|
||||
)
|
||||
assert layout == local_tensor.layout, (
|
||||
"Tensors representing LocalTensor shards must have the same layout"
|
||||
)
|
||||
assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
|
||||
"Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
|
||||
)
|
||||
|
||||
# Compute shape/stride. We allow for non-SPMD'ness here
|
||||
local_shapes: dict[int, dict[int, int]] = defaultdict(
|
||||
dict
|
||||
) # dim => rank => size
|
||||
local_strides: dict[int, dict[int, int]] = defaultdict(
|
||||
dict
|
||||
) # dim => rank => size
|
||||
for r, local_tensor in local_tensors.items():
|
||||
for d, size in enumerate(local_tensor.shape):
|
||||
local_shapes[d][r] = size
|
||||
local_strides[d][r] = local_tensor.stride(d)
|
||||
shape = [
|
||||
(
|
||||
first_shape[d]
|
||||
if len(set(local_shapes[d])) == 1
|
||||
else torch.SymInt(LocalIntNode(local_shapes[d]))
|
||||
)
|
||||
for d in range(len(first_shape))
|
||||
]
|
||||
strides = [
|
||||
(
|
||||
first_stride[d]
|
||||
if len(set(local_strides[d])) == 1
|
||||
else torch.SymInt(LocalIntNode(local_strides[d]))
|
||||
)
|
||||
for d in range(len(first_shape))
|
||||
]
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
shape,
|
||||
strides=strides,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
layout=layout,
|
||||
requires_grad=False,
|
||||
_extra_dispatch_keys=extra_dispatch_keys,
|
||||
)
|
||||
|
||||
local_tensors = {
|
||||
r: v if not isinstance(v, AsyncCollectiveTensor) else v.wait()
|
||||
for r, v in local_tensors.items()
|
||||
}
|
||||
r._local_tensors = local_tensors
|
||||
r._ranks = frozenset(local_tensors.keys())
|
||||
return r
|
||||
|
||||
@torch._disable_dynamo
|
||||
@mark_subclass_constructor_exportable_experimental # type: ignore[misc]
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
parts = []
|
||||
for k, v in self._local_tensors.items():
|
||||
parts.append(f" {k}: {v}")
|
||||
tensors_str = ",\n".join(parts)
|
||||
return f"LocalTensor(\n{tensors_str}\n)"
|
||||
|
||||
def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
|
||||
"""
|
||||
Protocol to inform how to flatten a LocalTensor into its per-rank local tensors
for PT2 tracing.
|
||||
"""
|
||||
return ["_local_tensors"], ()
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(
|
||||
inner_tensors: dict[str, Any],
|
||||
flatten_spec: tuple[Any, ...],
|
||||
outer_size: torch.Size,
|
||||
outer_stride: tuple[int, ...],
|
||||
) -> "LocalTensor":
|
||||
assert flatten_spec is not None, (
|
||||
"Expecting spec to be not None from `__tensor_flatten__` return value!"
|
||||
)
|
||||
local_tensors = inner_tensors["_local_tensors"]
|
||||
return LocalTensor(local_tensors)
|
||||
|
||||
@classmethod
|
||||
@torch._disable_dynamo
|
||||
def __torch_dispatch__( # type: ignore[override]
|
||||
cls,
|
||||
func: Any,
|
||||
types: tuple[Any, ...],
|
||||
args: tuple[Any, ...] = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# This is horribly inefficient
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
local_tensor = None
|
||||
for arg in flat_args:
|
||||
if isinstance(arg, LocalTensor):
|
||||
local_tensor = arg
|
||||
break
|
||||
|
||||
assert local_tensor is not None, (
|
||||
"At least one of the arguments must be a LocalTensor"
|
||||
)
|
||||
|
||||
# Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
|
||||
has_unrecognized_types = _check_for_subclass(flat_args)
|
||||
if has_unrecognized_types:
|
||||
unrecognized_types = [
|
||||
type(x) for x in flat_args if _check_for_subclass_arg(x)
|
||||
]
|
||||
not_implemented_log.debug(
|
||||
"LocalTensor unrecognized subclass(es): %s", unrecognized_types
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
with LocalTensorMode(local_tensor._ranks):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def tolist(self) -> list[Any]:
|
||||
"""
|
||||
Reconcile and convert result to list.
|
||||
"""
|
||||
|
||||
return self.reconcile().tolist()
|
||||
|
||||
def reconcile(self) -> torch.Tensor:
|
||||
"""
|
||||
Reconciles the LocalTensor into a single torch.Tensor by ensuring all local
|
||||
shards are identical and returning a detached clone of one of them.
|
||||
|
||||
Note:
|
||||
This method is useful for extracting a representative tensor from a LocalTensor
|
||||
when all shards are expected to be the same, such as after a collective operation
|
||||
that synchronizes all ranks.
|
||||
"""
|
||||
|
||||
# Assert that all local tensor shards across ranks are the same
|
||||
it = iter(self._local_tensors.values())
|
||||
t1 = next(it)
|
||||
for t2 in it:
|
||||
assert torch.equal(t1, t2), (
|
||||
"LocalTensor shards must be the same to reconcile"
|
||||
)
|
||||
cl = t1.clone().detach()
|
||||
cl.requires_grad_(self.requires_grad)
|
||||
return cl
|
||||
|
||||
|
||||
_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
|
||||
|
||||
|
||||
class LocalTensorMode(TorchDispatchMode):
|
||||
"""
|
||||
A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution
|
||||
for LocalTensor objects across a set of ranks.
|
||||
|
||||
LocalTensorMode enables PyTorch operations to be transparently applied to each
|
||||
local shard of a LocalTensor, as if they were distributed across multiple ranks.
|
||||
When active, this mode intercepts tensor operations and dispatches them to each
|
||||
rank's local tensor, collecting and wrapping the results as LocalTensors. It also
|
||||
handles collective operations by mapping them to local implementations.
|
||||
|
||||
This mode is primarily intended for debugging and simulating distributed tensor
|
||||
computations on a single process, rather than for high-performance distributed
|
||||
training. It maintains a stack of active modes, patches DeviceMesh coordinate
|
||||
resolution, and provides utilities for temporarily disabling the mode or mapping
|
||||
functions over ranks.
|
||||
"""
|
||||
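# Illustrative usage (a sketch based on test/distributed/test_local_tensor.py;
# `world_size` is a placeholder and the DTensor imports are assumed):
#
#   with LocalTensorMode(frozenset(range(world_size))):   # an int world size also works
#       mesh = init_device_mesh("cpu", (world_size,))
#       dt = distribute_tensor(torch.arange(8).float(), mesh, [Shard(0)])
#       full = dt.full_tensor()   # DTensor ops and collectives run once per simulated rank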
|
||||
# What ranks this local tensor mode is operating over
|
||||
def __init__(self, ranks: Union[int, frozenset[int]]):
|
||||
if isinstance(ranks, int):
|
||||
# assume this is the world size
|
||||
self.ranks = frozenset(range(ranks))
|
||||
else:
|
||||
assert isinstance(ranks, frozenset)
|
||||
self.ranks = ranks
|
||||
self._disable = False
|
||||
self._old_get_coordinate = None
|
||||
|
||||
def __enter__(self) -> "LocalTensorMode":
|
||||
self._disable = False
|
||||
self._patch_device_mesh()
|
||||
_LOCAL_TENSOR_MODE.append(self)
|
||||
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self._disable = True
|
||||
self._unpatch_device_mesh()
|
||||
_LOCAL_TENSOR_MODE.pop()
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def __torch_dispatch__(
|
||||
self,
|
||||
func: Any,
|
||||
types: tuple[Any, ...],
|
||||
args: tuple[Any, ...] = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
# Find all LocalTensor arguments to determine ranks
|
||||
local_tensors = [a for a in flat_args if isinstance(a, LocalTensor)]
|
||||
|
||||
# Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
|
||||
has_unrecognized_types = _check_for_subclass(flat_args)
|
||||
if has_unrecognized_types:
|
||||
unrecognized_types = [
|
||||
type(x) for x in flat_args if _check_for_subclass_arg(x)
|
||||
]
|
||||
not_implemented_log.debug(
|
||||
"LocalTensorMode unrecognized subclass(es): %s", unrecognized_types
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
# Factory functions convert into LocalTensor, so we don't have to
|
||||
# transmute a Tensor into a LocalTensor if mutation happens...
|
||||
# But if you do an operation on a Tensor, do NOT wrap it into a
|
||||
# LocalTensor. This helps prevent accidents when you're doing Tensor
|
||||
# operations on the inner non-wrapped tensors.
|
||||
if not local_tensors:
|
||||
if self._disable or any(isinstance(a, Tensor) for a in flat_args):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# For LocalTensors, verify they have compatible ranks
|
||||
for a in flat_args:
|
||||
if isinstance(a, LocalTensor):
|
||||
assert a._ranks == self.ranks, (
|
||||
f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
|
||||
)
|
||||
|
||||
if func.namespace == "c10d":
|
||||
if func is torch.ops.c10d.allreduce_.default:
|
||||
return _c10d._local_all_reduce_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.allreduce_coalesced_.default:
|
||||
return _c10d._local_allreduce_coalesced_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.reduce_scatter_tensor_coalesced_.default:
|
||||
return _c10d._local_reduce_scatter_tensor_coalesced_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.scatter_.default:
|
||||
return _c10d._local_scatter_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.broadcast_.default:
|
||||
return _c10d._local_broadcast_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.allgather_.default:
|
||||
return _c10d._local_all_gather_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default:
|
||||
return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.gather_.default:
|
||||
return _c10d._local_gather_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.alltoall_.default:
|
||||
return _c10d._local_alltoall_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.alltoall_base_.default:
|
||||
return _c10d._local_alltoall_base_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.barrier.default:
|
||||
return _c10d._local_barrier(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.monitored_barrier_.default:
|
||||
return _c10d._local_monitored_barrier_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.send.default:
|
||||
return _c10d._local_send(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.recv_.default:
|
||||
return _c10d._local_recv_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.recv_any_source_.default:
|
||||
return _c10d._local_recv_any_source_(*args, **kwargs)
|
||||
raise NotImplementedError(f"{func} not implemented")
|
||||
|
||||
if func.namespace == "_c10d_functional" or func.namespace == "_dtensor":
|
||||
with LocalTensorMode(self.ranks):
|
||||
return func._op_dk(
|
||||
DispatchKey.CompositeExplicitAutograd, *args, **kwargs
|
||||
)
|
||||
|
||||
if func.namespace == "_c10d_functional_autograd":
|
||||
raise NotImplementedError(f"{func} not implemented")
|
||||
|
||||
if func.namespace == "symm_mem":
|
||||
raise NotImplementedError(f"{func} not implemented")
|
||||
|
||||
return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True)
|
||||
|
||||
    @contextlib.contextmanager
    def disable(self) -> Generator[None, None, None]:
        """
        Temporarily disables LocalTensorMode. Primarily intended for performing
        rank-specific computations and merging the results before re-enabling
        LocalTensorMode.
        """

        old = self._disable
        self._disable = True
        self._unpatch_device_mesh()
        try:
            yield
        finally:
            self._disable = old
            self._patch_device_mesh()

    def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor:
        """
        Creates a LocalTensor instance by mapping each rank id to its local shard.
        """

        with self.disable():
            return LocalTensor({r: cb(r) for r in self.ranks})

    def _patch_device_mesh(self) -> None:
        assert self._old_get_coordinate is None
        self._old_get_coordinate = DeviceMesh.get_coordinate  # type: ignore[assignment]
        DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate  # type: ignore[method-assign]

    def _unpatch_device_mesh(self) -> None:
        assert self._old_get_coordinate is not None
        DeviceMesh.get_coordinate = self._old_get_coordinate
        self._old_get_coordinate = None


class _LocalDeviceMesh:
    """
    Holds implementations of DeviceMesh functionality that must be patched while running
    under LocalTensorMode.
    """

    @staticmethod
    def get_coordinate(self: DeviceMesh) -> Optional[list[int]]:
        lm = local_tensor_mode()
        assert lm is not None, "Unexpectedly not in LocalTensorMode"

        rank_coords = (self.mesh == lm.rank_map(lambda r: torch.tensor(r))).nonzero()
        # NB: unlike the regular mechanism, we don't allow for MPMD
        assert rank_coords.size(0) == 1
        assert isinstance(rank_coords[0], LocalTensor)

        coords: list[dict[int, int]] = [{} for _ in range(rank_coords.size(1))]
        for r, v in rank_coords[0]._local_tensors.items():
            for i, x in enumerate(v.tolist()):
                coords[i][r] = x
        out = [torch.SymInt(LocalIntNode(c)) for c in coords]

        return out  # type: ignore[return-value]
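    # Rough behavioral sketch (assuming a 1-D mesh [0, 1] under a two-rank
    # LocalTensorMode): every rank resolves a different coordinate, so the
    # result is one SymInt per mesh dimension, each backed by a LocalIntNode
    # such as LocalIntNode({0: 0, 1: 1}).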


def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any:
    """
    Reconciles arguments by converting any LocalTensor instances in the input
    arguments to their underlying torch.Tensor representation.

    This function is typically used to prepare arguments for functions that
    expect standard torch.Tensor objects, by flattening the input arguments,
    replacing LocalTensor instances with their reconciled (standard tensor)
    versions, and then reconstructing the original argument structure.

    Args:
        args: Positional arguments, possibly containing LocalTensor instances.
        kwargs: Keyword arguments, possibly containing LocalTensor instances.

    Returns:
        Any: The arguments with all LocalTensor instances replaced by their reconciled torch.Tensor equivalents,
            preserving the original structure.
    """
    if kwargs is None:
        kwargs = {}
    flat_args, args_spec = pytree.tree_flatten((args, kwargs))
    reconciled_args = [
        a.reconcile() if isinstance(a, LocalTensor) else a for a in flat_args
    ]
    return pytree.tree_unflatten(reconciled_args, args_spec)
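# Minimal sketch of intended use (assumes every shard of ``lt`` holds the same
# value, as required by ``LocalTensor.reconcile``):
#
#   args, kwargs = reconcile_args((lt, 3), {"alpha": 2})
#   # args[0] is now a plain torch.Tensor; non-tensor leaves pass through as-is.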


def local_tensor_mode() -> Optional[LocalTensorMode]:
    """
    Returns the currently active LocalTensorMode if one exists.

    This function checks the global stack of LocalTensorMode instances. If there
    is at least one LocalTensorMode active, it returns the most recently entered
    (top of the stack) LocalTensorMode. If no LocalTensorMode is active, it returns None.

    Returns:
        Optional[LocalTensorMode]: The current LocalTensorMode if active, else None.
    """
    if len(_LOCAL_TENSOR_MODE) > 0:
        return _LOCAL_TENSOR_MODE[-1]
    return None


def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator that ensures a function is executed for each local tensor shard
    when running under LocalTensorMode. If not in LocalTensorMode, the function
    is executed normally. When in LocalTensorMode, the function is run for each
    rank, and the results are collected appropriately.

    This decorator is useful for functions that exhibit non-SPMD behavior, such
    as those requiring rank-specific actions (for example, a function that
    computes an offset into an input tensor based on the rank).

    Note that the decorated function must not have side effects and must contain
    operations for a single rank only. For example, wrapping a function that
    performs a collective operation will not work.

    Args:
        func (Callable[..., Any]): The function to be decorated.

    Returns:
        Callable[..., Any]: The wrapped function that handles LocalTensorMode logic.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        lm = local_tensor_mode()
        if lm is None:
            return func(*args, **kwargs)
        ret = None
        with lm.disable():
            ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)

        lm = local_tensor_mode()
        assert lm is not None
        return ret

    return wrapper
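
# Minimal usage sketch (``row_offset`` is a hypothetical helper shown only to
# illustrate the decorator):
#
#   @maybe_run_for_local_tensor
#   def row_offset(rank: int, rows_per_rank: int) -> int:
#       return rank * rows_per_rank
#
# Outside LocalTensorMode this is an ordinary call; under LocalTensorMode,
# passing rank-dependent arguments runs the body once per rank and collects
# the per-rank results.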

torch/distributed/_local_tensor/_c10d.py (new file, 669 lines)
@ -0,0 +1,669 @@
import functools
import math
import operator
from typing import Sequence

import torch
from torch._C import ScriptObject
from torch._C._distributed_c10d import FakeWork
from torch.distributed._mesh_layout import _MeshLayout
from torch.distributed.distributed_c10d import (
    _get_default_group,
    ProcessGroup,
    ReduceOp,
    Work,
)


# NOTE: Most of the c10d collectives often take a Tensor[] (or Tensor[][])
# when you would expect Tensor (or Tensor[]). In fact, there will only ever
# be one Tensor in this case; the old signature was to support dispatching a
# collective on multiple devices (ala DataParallel) but we don't support that
# API anymore. Note that we are not 100% consistent about this; some more
# modern collectives like _allgather_base_ got rid of the unnecessary list.
# When in doubt, consult the code that dispatches to the collective on the PG
# in distributed_c10d.py e.g., work = group.allgather([tensor_list], [tensor],
# opts) indicates its always a list.


def _gcd_list(numbers: Sequence[int]) -> int:
    return 0 if not numbers else functools.reduce(math.gcd, numbers)


def _indices_to_layout(indices: list[int]) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # Base case: A single index represents a point, not a dimension.
    if len(indices) <= 1:
        return (), ()

    # The smallest stride is likely the GCD of the differences between consecutive indices.
    # For a sorted, unique list, all differences will be positive.
    diffs = [indices[i] - indices[i - 1] for i in range(1, len(indices))]
    last_stride = _gcd_list(diffs)

    assert last_stride != 0, (
        # This case should not be reached if indices are unique and sorted.
        "Cannot determine stride; indices may not be unique."
    )

    # Identify the starting index of each "row" in the last dimension.
    # An index starts a new row if the preceding index (index - stride) is not present.
    indices_set = set(indices)
    higher_dim_indices = [indices[0]]
    for index in indices[1:]:
        if (index - last_stride) not in indices_set:
            higher_dim_indices.append(index)

    # From the number of rows, we can deduce the shape of the last dimension.
    assert len(indices) % len(higher_dim_indices) == 0, (
        "Indices do not form a regular grid. "
        f"Found {len(higher_dim_indices)} subgroups for {len(indices)} total elements."
    )
    last_shape = len(indices) // len(higher_dim_indices)

    # Recurse on the higher-dimensional indices (the start of each row).
    higher_shapes, higher_strides = _indices_to_layout(higher_dim_indices)

    # Combine the results from the recursion with the current dimension's results.
    final_shapes = higher_shapes + (last_shape,)
    final_strides = higher_strides + (last_stride,)

    return final_shapes, final_strides
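# Worked examples (illustrative): contiguous ranks [0, 1, 2, 3] yield
# shape (4,), strides (1,); strided ranks [0, 2, 4, 6] yield shape (4,),
# strides (2,); the 2-D pattern [0, 1, 4, 5] yields shape (2, 2), strides (4, 1).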


def _prepare_collective_groups(
    process_group_so: ScriptObject,
) -> tuple[list[int], list[int], int]:
    process_group = ProcessGroup.unbox(process_group_so)

    ranks = torch.distributed.get_process_group_ranks(process_group)
    assert ranks
    # TODO: We can handle permutations but the layout inference algorithm will
    # lose the permutation so we will have to reapply it
    assert ranks == sorted(ranks), ranks
    offset = ranks[0]
    ranks = [r - offset for r in ranks]

    shape, strides = _indices_to_layout(ranks)
    layout = _MeshLayout(shape, strides)

    global_pg = _get_default_group()
    group_offsets = layout.complement(global_pg.size()).all_ranks_from_zero()

    return ranks, group_offsets, offset
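# Illustrative example (assuming a default world of 4 ranks): for a process
# group over global ranks [2, 3] this returns relative ranks [0, 1], offset 2,
# and group offsets such as [0, 2] -- one per SPMD replica of the group -- so
# callers can apply the same collective to ranks [0, 1] and [2, 3].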


def _local_broadcast_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    root_rank: int,
    root_tensor: int,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    # "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"
    from . import LocalTensor

    assert len(tensors) == 1
    assert root_tensor == 0
    tensor = tensors[0]

    ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

    # We're going to assume SPMD where for every rank group the root_rank is
    # the same relative to others
    relative_root_rank = root_rank - offset

    assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the broadcast on them
        group_ranks = [group_offset + r for r in ranks]
        source_rank = group_offset + relative_root_rank
        source_tensor = tensor._local_tensors[source_rank]

        # Broadcast the source tensor to all ranks in this group
        for rank in group_ranks:
            if source_rank != rank:
                tensor._local_tensors[rank].copy_(source_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return (tensors, work_so)


def _local_reduce(
    reduce_op: ReduceOp,
    tensors: list[torch.Tensor],
) -> torch.Tensor:
    if reduce_op == ReduceOp.SUM:
        op = operator.add
    elif reduce_op == ReduceOp.AVG:
        op = None
    elif reduce_op == ReduceOp.PRODUCT:
        op = operator.mul
    elif reduce_op == ReduceOp.MIN:
        op = torch.minimum
    elif reduce_op == ReduceOp.MAX:
        op = torch.maximum
    elif reduce_op == ReduceOp.BAND:
        op = torch.bitwise_and
    elif reduce_op == ReduceOp.BOR:
        op = torch.bitwise_or
    elif reduce_op == ReduceOp.BXOR:
        op = torch.bitwise_xor
    elif reduce_op == ReduceOp.PREMUL_SUM:
        raise NotImplementedError("PREMUL_SUM: need to add binding for scaling factor")
    else:
        raise NotImplementedError(f"ReduceOp {reduce_op} not implemented")

    if reduce_op == ReduceOp.AVG:
        return functools.reduce(operator.add, tensors) / len(tensors)
    else:
        assert op is not None
        return functools.reduce(op, tensors)
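# Quick illustration: _local_reduce(ReduceOp.SUM, [torch.tensor([1.0]),
# torch.tensor([2.0])]) evaluates to tensor([3.0]); ReduceOp.AVG over the same
# inputs evaluates to tensor([1.5]).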


def _local_all_reduce_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    reduce_op_so: ScriptObject,
    sparse_indices: torch.Tensor | None = None,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    # "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "__torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, "
    # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
    from . import LocalTensor

    assert len(tensors) == 1
    tensor = tensors[0]
    reduce_op = reduce_op_so.op()  # type: ignore[attr-defined]

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the allreduce on them
        group_ranks = [group_offset + r for r in ranks]

        # Collect tensors from the specified ranks in this group
        group_tensors = []
        for rank in group_ranks:
            group_tensors.append(tensor._local_tensors[rank])

        # Perform the reduction operation
        reduced_tensor = _local_reduce(reduce_op, group_tensors)

        # Update all tensors in the group with the reduced result
        for rank in group_ranks:
            tensor._local_tensors[rank].copy_(reduced_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return (tensors, work_so)
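# Behavioral sketch: if a LocalTensor holds 1.0 on rank 0 and 2.0 on rank 1 and
# the group spans both ranks, a SUM allreduce leaves every shard holding 3.0;
# the FakeWork returned here only mimics the async Work handle of a real backend.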


def _local_allreduce_coalesced_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    reduce_op_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    # "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
    from . import LocalTensor

    reduce_op = reduce_op_so.op()  # type: ignore[attr-defined]
    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the allreduce on all tensors together
        group_ranks = [group_offset + r for r in ranks]

        # For each tensor, perform the reduction operation
        for tensor in tensors:
            assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
            # Collect tensors from the specified ranks in this group
            group_tensors = []
            for rank in group_ranks:
                group_tensors.append(tensor._local_tensors[rank])

            # Perform the reduction operation
            reduced_tensor = _local_reduce(reduce_op, group_tensors)

            # Update all tensors in the group with the reduced result
            for rank in group_ranks:
                tensor._local_tensors[rank].copy_(reduced_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return work_so


def _local_reduce_scatter_tensor_coalesced_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    reduce_op_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    # "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, "
    # "int timeout=-1) -> __torch__.torch.classes.c10d.Work"

    from . import LocalTensor

    reduce_op = reduce_op_so.op()  # type: ignore[attr-defined]
    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the reduce_scatter on all tensors together
        group_ranks = [group_offset + r for r in ranks]

        # For each tensor, perform the reduction operation
        for input_tensor, output_tensor in zip(input_tensors, output_tensors):
            assert isinstance(input_tensor, LocalTensor), (
                "Input tensor must be a LocalTensor"
            )
            assert isinstance(output_tensor, LocalTensor), (
                "Output tensor must be a LocalTensor"
            )
            # Collect tensors from the specified ranks in this group
            group_inputs = []
            for rank in group_ranks:
                group_inputs.append(input_tensor._local_tensors[rank])

            # Perform the reduction operation
            reduced_input = _local_reduce(reduce_op, group_inputs)

            reduced_input_splits = torch.split(
                reduced_input, reduced_input.size(0) // len(group_ranks), dim=0
            )

            # Update all tensors in the group with the reduced result
            for rank in group_ranks:
                output_tensor._local_tensors[rank].copy_(reduced_input_splits[rank])

    work = FakeWork()
    work_so = Work.boxed(work)
    return work_so


def _local_all_gather_(
    output_tensors: list[list[torch.Tensor]],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[list[torch.Tensor]], ScriptObject]:
    # "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
    # "int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");

    from . import LocalTensor

    assert len(output_tensors) == 1
    assert len(input_tensors) == 1

    input_tensor = input_tensors[0]
    output_tensors = output_tensors[0]

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
    for i in range(len(output_tensors)):
        assert isinstance(output_tensors[i], LocalTensor), (
            "Output tensor must be a LocalTensor"
        )

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the all_gather on them
        group_ranks = [group_offset + r for r in ranks]

        # For each rank in the group, gather from their input tensor
        for i, rank_i in enumerate(group_ranks):
            output_tensors[i].copy_(input_tensor._local_tensors[rank_i])

    work = FakeWork()
    work_so = Work.boxed(work)
    return ([output_tensors], work_so)


def _local_allgather_into_tensor_coalesced_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    async_op: bool = True,
) -> ScriptObject:
    # "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) "
    # "-> __torch__.torch.classes.c10d.Work"
    from . import LocalTensor

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    # Each output tensor should be sized to hold all gathered inputs
    # outputs[i] will contain all inputs[i] from all ranks
    assert len(output_tensors) == len(input_tensors), (
        f"Number of outputs ({len(output_tensors)}) must match number of inputs ({len(input_tensors)})"
    )

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the allgather_into_tensor on them
        group_ranks = [group_offset + r for r in ranks]

        # For each input/output pair
        for input_tensor, output_tensor in zip(input_tensors, output_tensors):
            assert isinstance(input_tensor, LocalTensor), (
                "Input tensor must be a LocalTensor"
            )
            assert isinstance(output_tensor, LocalTensor), (
                "Output tensor must be a LocalTensor"
            )
            # Gather input_tensor from all ranks into output_tensor
            # The output should be a concatenation of all inputs along the first dimension
            gathered_tensors = []
            for rank in group_ranks:
                gathered_tensors.append(input_tensor._local_tensors[rank])

            # Concatenate along first dimension and copy to output
            if gathered_tensors:
                concatenated = torch.cat(gathered_tensors, dim=0)
                for rank in group_ranks:
                    output_tensor._local_tensors[rank].copy_(concatenated)

    work = FakeWork()
    work_so = Work.boxed(work)
    return work_so


def _local_gather_(
    output_tensors: list[list[torch.Tensor]],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    root_rank: int,
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    # "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
    # "bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
    raise NotImplementedError(
        "LocalTensor does not support MPMD operations like gather "
        "(only root rank receives data). Use SPMD collective operations like allgather instead."
    )


def _local_scatter_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[list[torch.Tensor]],
    process_group_so: ScriptObject,
    root_rank: int,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    # "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
    # "bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");

    from . import LocalTensor

    assert len(output_tensors) == 1
    assert len(input_tensors) == 1
    output_tensor = output_tensors[0]
    input_tensors = input_tensors[0]

    ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

    # We're going to assume SPMD where for every rank group the root_rank is
    # the same relative to others
    relative_root_rank = root_rank - offset

    assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"
    assert len(ranks) == len(input_tensors), (ranks, input_tensors)

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the scatter on them
        group_ranks = [group_offset + r for r in ranks]

        # Root rank scatters its input tensors to all ranks in this group
        for i, rank in enumerate(group_ranks):
            input_tensor = input_tensors[i]
            assert isinstance(input_tensor, LocalTensor)
            # Each rank i gets the i-th input tensor from the root
            source_tensor = input_tensor._local_tensors[
                group_offset + relative_root_rank
            ]
            output_tensor._local_tensors[rank].copy_(source_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return (output_tensors, work_so)


def _local_alltoall_(
    output_tensors: list[torch.Tensor],
    input_tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    async_op: bool = True,
    timeout: int = -1,
) -> tuple[list[torch.Tensor], ScriptObject]:
    # "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, "
    # "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
    # "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)";

    from . import LocalTensor

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert len(input_tensors) == len(output_tensors) == len(ranks), (
        f"Number of input tensors ({len(input_tensors)}), "
        f"output tensors ({len(output_tensors)}), and ranks ({len(ranks)}) must match"
    )

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the alltoall on them
        group_ranks = [group_offset + r for r in ranks]

        # In alltoall, rank i sends input_tensors[j] to rank j and receives into
        # output_tensors[j] from rank j
        for i, rank_i in enumerate(group_ranks):
            output_tensor = output_tensors[i]
            assert isinstance(output_tensor, LocalTensor), (
                "Output tensor must be a LocalTensor"
            )
            for j, rank_j in enumerate(group_ranks):
                input_tensor = input_tensors[j]
                assert isinstance(input_tensor, LocalTensor), (
                    "Input tensor must be a LocalTensor"
                )
                # Rank i's j-th input tensor goes to rank j's i-th output tensor
                source_tensor = input_tensor._local_tensors[rank_i]
                output_tensor._local_tensors[rank_j].copy_(source_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return (output_tensors, work_so)
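# Index-mapping sketch for two ranks: the loop above realizes
# "out[i] on rank j == in[j] on rank i", e.g. rank 1's out[0] receives
# rank 0's in[1].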


def _local_alltoall_base_(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    process_group_so: ScriptObject,
    output_split_sizes: list[int],
    input_split_sizes: list[int],
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    # "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";

    from . import LocalTensor

    ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

    assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
    assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"
    # Convert split sizes to lists if they aren't already
    if output_split_sizes is not None:
        output_split_sizes = list(output_split_sizes)
    if input_split_sizes is not None:
        input_split_sizes = list(input_split_sizes)

    for group_offset in group_offsets:
        # For the tensors in this group [group_offset + r for r in ranks]
        # perform the alltoall_base on them
        group_ranks = [group_offset + r for r in ranks]

        for i, rank_i in enumerate(group_ranks):
            # Split input tensor from rank_i according to input_split_sizes
            rank_tensor = input_tensor._local_tensors[rank_i]

            if input_split_sizes is not None and len(input_split_sizes) > 0:
                # Split the input tensor
                input_splits = torch.split(rank_tensor, input_split_sizes, dim=0)
            else:
                # No split sizes specified, split evenly
                split_size = rank_tensor.size(0) // len(group_ranks)
                input_splits = torch.split(rank_tensor, split_size, dim=0)

            # Send each split to the corresponding rank
            for j, rank_j in enumerate(group_ranks):
                if j < len(input_splits):
                    split_tensor = input_splits[j]

                    # Determine where to place this split in the output tensor
                    if output_split_sizes is not None and len(output_split_sizes) > 0:
                        # Calculate offset based on output split sizes
                        output_offset = sum(output_split_sizes[:i]) if i > 0 else 0
                        end_offset = (
                            output_offset + output_split_sizes[i]
                            if i < len(output_split_sizes)
                            else output_tensor._local_tensors[rank_j].size(0)
                        )
                    else:
                        # No output split sizes, use even splits
                        split_size = output_tensor._local_tensors[rank_j].size(
                            0
                        ) // len(group_ranks)
                        output_offset = i * split_size
                        end_offset = min(
                            (i + 1) * split_size,
                            output_tensor._local_tensors[rank_j].size(0),
                        )

                    # Copy the split to the appropriate section of the output tensor
                    output_section = output_tensor._local_tensors[rank_j][
                        output_offset:end_offset
                    ]
                    if output_section.numel() > 0:
                        # Reshape split_tensor to match output_section if necessary
                        if split_tensor.size() != output_section.size():
                            split_tensor = split_tensor.view(output_section.size())
                        output_section.copy_(split_tensor)

    work = FakeWork()
    work_so = Work.boxed(work)
    return work_so


def _local_barrier(
    tensor: torch.Tensor,
    process_group_so: ScriptObject,
    device_ids: list[int],
    async_op: bool = True,
    timeout: int = -1,
) -> ScriptObject:
    # "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";

    from . import LocalTensor

    # Barrier is a synchronization primitive - in local simulation,
    # we don't need to do any actual work since all "ranks" are in the same process
    # Just validate that the tensor is a LocalTensor
    assert isinstance(tensor, LocalTensor)

    # In a real distributed setting, barrier would synchronize all processes
    # In local simulation, this is essentially a no-op since all ranks are local
    work = FakeWork()
    work_so = Work.boxed(work)
    return work_so


def _local_monitored_barrier_(
    tensor: torch.Tensor,
    process_group_so: ScriptObject,
    device_ids: list[int],
    timeout: int,
    wait_all_ranks: bool,
) -> None:
    # "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int[] device_ids, int timeout, bool wait_all_ranks) -> ()";

    from . import LocalTensor

    # Monitored barrier is a synchronization primitive with monitoring - in local simulation,
    # we don't need to do any actual work since all "ranks" are in the same process
    # Just validate that the tensor is a LocalTensor
    assert isinstance(tensor, LocalTensor)

    # In a real distributed setting, monitored barrier would synchronize all processes
    # and provide monitoring capabilities. In local simulation, this is essentially a no-op
    # since all ranks are local and no actual synchronization is needed
    return


def _local_send(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    dst: int,
    tag: int,
) -> ScriptObject:
    # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int dst, int tag) -> __torch__.torch.classes.c10d.Work";

    raise NotImplementedError(
        "LocalTensor does not support MPMD operations like send. "
        "Use SPMD collective operations instead."
    )


def _local_recv_(
    tensors: list[torch.Tensor],
    process_group_so: ScriptObject,
    src: int,
    tag: int,
) -> ScriptObject:
    # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int src, int tag) -> __torch__.torch.classes.c10d.Work";

    raise NotImplementedError(
        "LocalTensor does not support MPMD operations like recv. "
        "Use SPMD collective operations instead."
    )


def _local_recv_any_source_(
    tensors: list[torch.Tensor], process_group_so: ScriptObject, tag: int
) -> ScriptObject:
    # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
    # "int tag) -> __torch__.torch.classes.c10d.Work";

    raise NotImplementedError(
        "LocalTensor does not support MPMD operations like recv_any_source. "
        "Use SPMD collective operations instead."
    )

@ -10,6 +10,7 @@ import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch._logging import warning_once
from torch.distributed._local_tensor import local_tensor_mode
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
    _get_group_size_by_name,
@ -40,7 +41,7 @@ def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):


def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
    if mesh.device_type == "cpu":
    if mesh.device_type == "cpu" and local_tensor_mode() is None:
        # Gloo does not support alltoall, so falling back to allgather + chunk
        warning_once(
            logger,

@ -165,7 +165,7 @@ class OpDispatcher:
                raise
            except Exception as e:
                raise RuntimeError(
                    f"Sharding propagation failed for {op_info.schema}"
                    f"{e}\n\nSharding propagation failed for {op_info.schema}"
                ) from e

        output_sharding = op_info.output_sharding

@ -319,6 +319,10 @@ LINEAR_REDUCTION_OP_MAP = {
    aten.all.dim: "sum",
    aten.sum.default: "sum",
    aten.sum.dim_IntList: "sum",
    aten.any.default: "sum",
    aten.any.dim: "sum",
    aten.any.out: "sum",
    # These are only valid when there is no padding
    aten.prod.default: "product",
    aten.prod.dim_int: "product",
    aten.prod.int_out: "product",
@ -332,9 +336,6 @@ LINEAR_REDUCTION_OP_MAP = {
    aten.min.default: "min",
    aten.min.dim: "min",
    aten.min.out: "min",
    aten.any.default: "sum",
    aten.any.dim: "sum",
    aten.any.out: "sum",
    aten.amax.default: "max",
    aten.amax.out: "max",
    aten.amin.default: "min",

@ -383,6 +383,7 @@ def redistribute_local_tensor(
                raise RuntimeError(
                    f"redistribute from {current} to {target} not supported yet"
                )

        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)

@ -149,7 +149,7 @@ def _compute_local_shape_and_global_offset(
    ordered_placements = _explicit_order_placements(mesh_shape, placements)

    local_shape = list(global_shape)
    # We'll compute the data for where the shard beings on a per-dim basis.
    # We'll compute the data for where the shard begins on a per-dim basis.
    # However, a single dim can be sharded multiple times, so we will end up
    # doing a Sum(size*stride) like computation to determine the location of our
    # shard for each of the shardings on that dim.
@ -170,6 +170,14 @@ def _compute_local_shape_and_global_offset(

            local_shape[shard_dim] = shard_size

            shard_global_offset = global_offset[shard_dim] + not_none(shard_offset)

            zero_global_offset = global_shape[shard_dim]
            if isinstance(shard_global_offset, torch.SymInt) and not isinstance(
                zero_global_offset, torch.SymInt
            ):
                zero_global_offset = torch.SymInt(zero_global_offset)

            global_offset[shard_dim] = torch.sym_ite(
                shard_size == 0,
                # Special case to fill in a standardized non-garbage value for
@ -179,11 +187,11 @@ def _compute_local_shape_and_global_offset(
                # Note that you can end up with zero-size shards that are
                # still otherwise in bounds for the tensor (TODO: give an
                # example).
                global_shape[shard_dim],
                zero_global_offset,
                # As we successively shard the same dimension, we keep
                # advancing our pointer beyond our original offset until we
                # get to the final chunk start.
                global_offset[shard_dim] + not_none(shard_offset),
                shard_global_offset,
            )

    # NOTE: the offset compute relies on the local shard index and it has no

@ -6,6 +6,7 @@ from typing import cast, Optional

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._local_tensor import maybe_run_for_local_tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._collective_utils import (
    fill_empty_tensor_to_shards,
@ -128,6 +129,7 @@ class Shard(Placement):
        )

    @staticmethod
    @maybe_run_for_local_tensor
    def local_shard_size_and_offset(
        curr_local_size: int,
        num_chunks: int,
@ -170,6 +172,20 @@ class Shard(Placement):
    ) -> tuple[int, Optional[int]]:
        return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)

    @staticmethod
    @maybe_run_for_local_tensor
    def _maybe_unpad_tensor_with_sizes(
        dim, local_tensor, pad_sizes, mesh_dim_local_rank, make_contiguous
    ) -> torch.Tensor:
        # Only unpad if the local_tensor was padded on the dimension.
        if pad_sizes[mesh_dim_local_rank] > 0:
            local_tensor = unpad_tensor(
                local_tensor, dim, pad_sizes[mesh_dim_local_rank]
            )
            if make_contiguous:
                local_tensor = local_tensor.contiguous()
        return local_tensor

    @staticmethod
    def _make_shard_tensor(
        dim: int,
@ -198,24 +214,28 @@ class Shard(Placement):
                dim, tensor, num_chunks, with_padding=False, contiguous=True
            )

            return scatter_list[mesh_dim_local_rank]
            return Shard._select_shard(scatter_list, mesh_dim_local_rank)

        scatter_list, pad_sizes = Shard._make_split_tensor(
            dim, tensor, num_chunks, with_padding=True, contiguous=True
        )
        output = torch.empty_like(scatter_list[mesh_dim_local_rank])

        it = iter(scatter_list)
        first = next(it)
        # Tensors in the scatter list are expected to have the same shape because
        # split is requested with padding.
        assert all(first.shape == v.shape for v in it)

        output = torch.empty_like(first)

        # perform scatter from the src_data_rank as data source when it is not None
        mesh_scatter(
            output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank
        )

        # Only unpad if the local_tensor was padded on the dimension.
        if pad_sizes[mesh_dim_local_rank] > 0:
            output = unpad_tensor(output, dim, pad_sizes[mesh_dim_local_rank])
            # Unpad might return a view, hence we need to remake it contiguous
            output = output.contiguous()
        return output
        return Shard._maybe_unpad_tensor_with_sizes(
            dim, output, pad_sizes, mesh_dim_local_rank, True
        )

    def _shard_tensor(
        self,
@ -245,6 +265,7 @@ class Shard(Placement):
            return tensor

        is_padded = tensor.size(self.dim) % num_chunks != 0
        pad_sizes = None
        if is_padded:
            scattered_list, pad_sizes = Shard._make_split_tensor(
                self.dim, tensor, num_chunks, with_padding=True, contiguous=True
@ -258,9 +279,47 @@ class Shard(Placement):
        )

        if is_padded:
            output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]])  # type: ignore[possibly-undefined]
            assert pad_sizes is not None
            output = Shard._maybe_unpad_tensor_with_sizes(
                self.dim, output, pad_sizes, my_coordinate[mesh_dim], False
            )
        return output

    @maybe_run_for_local_tensor
    def _maybe_pad_tensor(
        self,
        local_tensor: torch.Tensor,
        logical_dim_size: int,
        num_chunks: int,
    ) -> torch.Tensor:
        is_padded = logical_dim_size % num_chunks != 0

        if is_padded:
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
            pad_size = full_chunk_size - local_tensor.size(self.dim)
            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)

        if not local_tensor.is_contiguous():
            local_tensor = local_tensor.contiguous()

        return local_tensor

    @maybe_run_for_local_tensor
    def _maybe_unpad_tensor(
        self,
        local_tensor: torch.Tensor,
        logical_dim_size: int,
        num_chunks: int,
    ) -> torch.Tensor:
        is_padded = logical_dim_size % num_chunks != 0

        if is_padded:
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
            local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size)

        return local_tensor

    def _to_replicate_tensor(
        self,
        local_tensor: torch.Tensor,
@ -273,28 +332,27 @@ class Shard(Placement):
        is replicated on the previously sharded mesh dimension
        """
        num_chunks = mesh.size(mesh_dim=mesh_dim)

        logical_dim_size = current_logical_shape[self.dim]
        is_padded = logical_dim_size % num_chunks != 0

        if is_padded:
            full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
            pad_size = full_chunk_size - local_tensor.size(self.dim)
            local_tensor = pad_tensor(local_tensor, self.dim, pad_size)

        if not local_tensor.is_contiguous():
            local_tensor = local_tensor.contiguous()
        local_tensor = self._maybe_pad_tensor(
            local_tensor, logical_dim_size, num_chunks
        )

        result = funcol.all_gather_tensor(
            local_tensor,
            gather_dim=self.dim,
            group=(mesh, mesh_dim),
        )
        if is_padded:
            unpad_size = full_chunk_size * num_chunks - logical_dim_size  # type: ignore[possibly-undefined]
            result = unpad_tensor(result, self.dim, unpad_size)

        result = self._maybe_unpad_tensor(result, logical_dim_size, num_chunks)

        return result

    @staticmethod
    @maybe_run_for_local_tensor
    def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor:
        return shards[shard_index].clone()

    def _replicate_to_shard(
        self,
        local_tensor: torch.Tensor,
@ -313,7 +371,8 @@ class Shard(Placement):
            with_padding=False,
            contiguous=False,
        )
        return shards[shard_index].clone()

        return Shard._select_shard(shards, shard_index)

    def _to_new_shard_dim(
        self,

@ -41,6 +41,9 @@ class ConstantIntNode:
    def _graph_repr(self) -> str:
        return self._str()

    def add(self, other: Any) -> Any:
        return other.add(self)

    def mul(self, other: Any) -> Any:
        return other.mul(self)


@ -13,6 +13,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._local_tensor import LocalTensor
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_tensor,
@ -660,7 +661,7 @@ class DTensorConverter:
    def to_dist_tensor(
        self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
    ) -> torch.Tensor:
        if type(t) is torch.Tensor or type(t) is nn.Parameter:
        if type(t) is torch.Tensor or type(t) is nn.Parameter or type(t) is LocalTensor:
            if self.is_supported_tensor(t):
                self.hit += 1
                if t.ndim == 0:
@ -669,7 +670,7 @@ class DTensorConverter:
                else:
                    # distribute non-scalar tensors
                    r = distribute_tensor(t, mesh, placements)
                    if type(t) is nn.Parameter:
                    if isinstance(t, nn.Parameter):
                        r = nn.Parameter(  # type: ignore[assignment]
                            r, requires_grad=r.requires_grad
                        )

@ -1292,7 +1292,7 @@ SAC_IGNORED_OPS = {
    # With subclasses involved, these metadata ops become dispatchable, this
    # can result in incorrectness if these ops are selected cached.
    torch.ops.prim.device.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)  # type: ignore[has-type]


class _CachingTorchDispatchMode(TorchDispatchMode):
Block a user