mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
LocalTensor (#164537)
A LocalTensor is a tensor subclass which simulates a tensor that is distributed across SPMD ranks. A LocalTensor might be size N, but in fact there are world_size shards/replicas of it stored internally. When you do a plain PyTorch operation on it, we apply the operation to each shard; when you do a collective, we do the mathematically equivalent operation on the local shards. A LocalTensor is associated with a list of ranks which specify which ranks it holds local tensors for. NB, this is NOT a DataParallel like abstraction where you can run operations on multiple different GPUs. It is intended purely for *debugging* purposes, the overhead is almost certainly too high to keep eight GPUs (even the C++ autograd needs multithreading to keep up!) (It might potentially be possible to trace through this with torch.compile and then compile it with CUDA graphs but this is currently a non-goal.) In order to handle MPMD, we provide a helper decorator that allows you to run a function with no side effects for each LocalTensor shard and combine results back into LocalTensor or LocalIntNode. Note: This PR convert all DTensor ops and some DTensor tests to illustrate intended usage and ensure conrrectness. In subsequent PR more tests will be converted. DUring test conversion we aim to share as much as possible of test logic between multi-process / multi-threaded and local tensor tests. We would like to developers to be able to run both flavors of the tests. Note: This work is based on the original proposal by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753). Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a2601630cd
commit
5e58420dff
@ -3,13 +3,23 @@
|
||||
|
||||
import pathlib
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.distributed_c10d as c10d
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
from torch.distributed._local_tensor import (
|
||||
LocalIntNode,
|
||||
LocalTensorMode,
|
||||
maybe_run_for_local_tensor,
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
@ -44,6 +54,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
c10d_functional = torch.ops.c10d_functional
|
||||
|
||||
|
||||
@maybe_run_for_local_tensor
|
||||
def map_tensor_for_rank(tensor, rank, func):
|
||||
return func(tensor, rank)
|
||||
|
||||
|
||||
class DummyMLP(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
@ -592,7 +607,12 @@ class DTensorTest(DTensorTestBase):
|
||||
self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
|
||||
self.assertEqual(sharded_tensor.placements, placements)
|
||||
local_tensor = sharded_tensor.to_local()
|
||||
self.assertEqual(local_tensor, full_tensor[range(self.rank, self.rank + 1), :])
|
||||
self.assertEqual(
|
||||
local_tensor,
|
||||
map_tensor_for_rank(
|
||||
full_tensor, self.rank, lambda ft, r: ft[range(r, r + 1), :]
|
||||
),
|
||||
)
|
||||
|
||||
# Shard by column
|
||||
placements = [Shard(1)]
|
||||
@ -600,7 +620,12 @@ class DTensorTest(DTensorTestBase):
|
||||
self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
|
||||
self.assertEqual(sharded_tensor.placements, placements)
|
||||
local_tensor = sharded_tensor.to_local()
|
||||
self.assertEqual(local_tensor, full_tensor[:, range(self.rank, self.rank + 1)])
|
||||
self.assertEqual(
|
||||
local_tensor,
|
||||
map_tensor_for_rank(
|
||||
full_tensor, self.rank, lambda ft, r: ft[:, range(r, r + 1)]
|
||||
),
|
||||
)
|
||||
|
||||
# assert full tensor is not changed
|
||||
self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
|
||||
@ -620,6 +645,105 @@ class DTensorTest(DTensorTestBase):
|
||||
self.assertEqual(local_tensor.item(), self.rank)
|
||||
|
||||
|
||||
class LocalDTensorTest(DTensorTest):
|
||||
def get_local_tensor_mode(self):
|
||||
return LocalTensorMode(frozenset(range(0, self.world_size)))
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return torch.SymInt(LocalIntNode({r: r for r in range(self.world_size)}))
|
||||
|
||||
@rank.setter
|
||||
def rank(self, rank):
|
||||
pass
|
||||
|
||||
def join_or_run(self, fn):
|
||||
@wraps(fn)
|
||||
def wrapper(self):
|
||||
fn()
|
||||
|
||||
return types.MethodType(wrapper, self)
|
||||
|
||||
def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
|
||||
dist.init_process_group("fake", rank=0, world_size=self.world_size)
|
||||
self._pg = c10d._get_default_group()
|
||||
|
||||
def destroy_pg(self, device_id: Optional[int] = None) -> None:
|
||||
dist.destroy_process_group(self._pg)
|
||||
self._pg = None
|
||||
|
||||
def _spawn_processes(self) -> None:
|
||||
pass
|
||||
|
||||
def test_dtensor_constructor(self):
|
||||
pass
|
||||
|
||||
def test_meta_dtensor(self):
|
||||
pass
|
||||
|
||||
def test_modules_w_meta_dtensor(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_stride(self):
|
||||
pass
|
||||
|
||||
def test_from_local(self):
|
||||
pass
|
||||
|
||||
def test_from_local_uneven_sharding(self):
|
||||
pass
|
||||
|
||||
def test_from_local_uneven_sharding_raise_error(self):
|
||||
pass
|
||||
|
||||
def test_from_local_negative_dim(self):
|
||||
pass
|
||||
|
||||
def test_to_local(self):
|
||||
pass
|
||||
|
||||
def test_to_local_grad_hint(self):
|
||||
pass
|
||||
|
||||
def test_full_tensor_sync(self):
|
||||
pass
|
||||
|
||||
def test_full_tensor_grad_hint(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_new_empty_strided(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_async_output(self):
|
||||
pass
|
||||
|
||||
def test_from_local_then_to_local(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_spec_read_only_after_set(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_spec_hash(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_properties(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_save_load(self):
|
||||
pass
|
||||
|
||||
def test_dtensor_save_load_import(self):
|
||||
pass
|
||||
|
||||
def test_shard_tensor_2d(self):
|
||||
with self.get_local_tensor_mode():
|
||||
super().test_shard_tensor_2d()
|
||||
|
||||
def test_shard_tensor(self):
|
||||
with self.get_local_tensor_mode():
|
||||
super().test_shard_tensor()
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import copy
|
||||
import re
|
||||
import unittest
|
||||
import warnings
|
||||
@ -8,6 +9,7 @@ import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.testing._internal.common_methods_invocations as common_ops
|
||||
from torch.distributed._local_tensor import LocalTensorMode, reconcile_args
|
||||
from torch.distributed.tensor import (
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
@ -21,7 +23,7 @@ from torch.testing._internal.common_device_type import (
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
|
||||
from torch.testing._internal.common_utils import run_tests, suppress_warnings
|
||||
from torch.testing._internal.common_utils import run_tests, suppress_warnings, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorConverter,
|
||||
DTensorOpTestBase,
|
||||
@ -49,7 +51,7 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
|
||||
return (op_name, variant_name, device_type, dtypes, False)
|
||||
|
||||
|
||||
def skipOps(test_case_name, base_test_name, to_skip):
|
||||
def skipOps(op_db, test_case_name, base_test_name, to_skip):
|
||||
all_opinfos = op_db
|
||||
for xfail in to_skip:
|
||||
op_name, variant_name, device_type, dtypes, expected_failure = xfail
|
||||
@ -88,6 +90,34 @@ def skipOps(test_case_name, base_test_name, to_skip):
|
||||
return wrapped
|
||||
|
||||
|
||||
def repurpose_ops(op_db, base_test_name, derived_test_name):
|
||||
"""
|
||||
Copies op info database and for the decorators that applied to base test class updates
|
||||
them to apply to derived test class. The class update is required because decorators are applied
|
||||
only if the class name matches (it doesn't consider base classes).
|
||||
|
||||
Specifically we use this function to create two test classes (one for multi-threaded and one for
|
||||
local tensor flavors) that share common test body but different rules for skip or fail.
|
||||
|
||||
Args:
|
||||
op_db: List of OpInfo objects to be repurposed.
|
||||
base_test_name: The original test class name to be replaced.
|
||||
derived_test_name: The new test class name to set in decorators.
|
||||
|
||||
Returns:
|
||||
list: A new list of OpInfo objects with updated target class names for the
|
||||
decorator.
|
||||
"""
|
||||
repurposed_ops = []
|
||||
for opinfo in op_db:
|
||||
opinfo_copy = copy.deepcopy(opinfo)
|
||||
for decorator in list(opinfo_copy.decorators):
|
||||
if hasattr(decorator, "cls_name") and decorator.cls_name == base_test_name:
|
||||
decorator.cls_name = derived_test_name
|
||||
repurposed_ops.append(opinfo_copy)
|
||||
return repurposed_ops
|
||||
|
||||
|
||||
# Re-generate this failed list, turn on dry_run of the below func
|
||||
# check_dtensor_func(self, test, op, dry_run=True), then run sth
|
||||
# like python test/distributed/tensor/test_dtensor_ops.py > failed.expect
|
||||
@ -162,7 +192,6 @@ dtensor_fails = {
|
||||
xfail("fmin"),
|
||||
xfail("frexp"),
|
||||
xfail("full"),
|
||||
xfail("full_like"),
|
||||
xfail("geometric"),
|
||||
xfail("geqrf"),
|
||||
xfail("grid_sampler_2d"),
|
||||
@ -226,7 +255,6 @@ dtensor_fails = {
|
||||
xfail("masked_select"),
|
||||
xfail("masked.argmax"),
|
||||
xfail("masked.argmin"),
|
||||
xfail("masked.cumprod"),
|
||||
xfail("masked.logsumexp"),
|
||||
xfail("masked.median"),
|
||||
xfail("matrix_exp"),
|
||||
@ -244,8 +272,6 @@ dtensor_fails = {
|
||||
xfail("native_batch_norm"),
|
||||
xfail("narrow_copy"),
|
||||
xfail("ne"),
|
||||
xfail("new_empty"),
|
||||
xfail("new_empty_strided"),
|
||||
xfail("transpose"),
|
||||
xfail("nn.functional.adaptive_avg_pool1d"),
|
||||
xfail("nn.functional.adaptive_avg_pool2d"),
|
||||
@ -272,8 +298,6 @@ dtensor_fails = {
|
||||
xfail("nn.functional.cosine_similarity"),
|
||||
xfail("nn.functional.ctc_loss"),
|
||||
xfail("nn.functional.dropout"),
|
||||
xfail("nn.functional.dropout2d"),
|
||||
xfail("nn.functional.dropout3d"),
|
||||
xfail("nn.functional.elu"),
|
||||
xfail("nn.functional.fractional_max_pool2d"),
|
||||
xfail("nn.functional.fractional_max_pool3d"),
|
||||
@ -307,7 +331,6 @@ dtensor_fails = {
|
||||
xfail("nn.functional.multi_margin_loss"),
|
||||
xfail("nn.functional.multilabel_margin_loss"),
|
||||
xfail("nn.functional.multilabel_soft_margin_loss"),
|
||||
xfail("nn.functional.multi_head_attention_forward"),
|
||||
xfail("nn.functional.pad", "reflect"),
|
||||
xfail("nn.functional.pad", "replicate"),
|
||||
xfail("nn.functional.pad", "replicate_negative"),
|
||||
@ -482,13 +505,21 @@ dtensor_fails = {
|
||||
skip("_segment_reduce", "offsets"),
|
||||
# TODO: fix the following ops
|
||||
skip("squeeze"),
|
||||
# These must be skipped as their contents are nondeterministic
|
||||
skip("empty"),
|
||||
skip("empty_strided"),
|
||||
skip("empty_like"),
|
||||
skip("empty_permuted"),
|
||||
skip("new_empty"),
|
||||
skip("new_empty_strided"),
|
||||
}
|
||||
|
||||
dtensor_multi_threaded_fails = {
|
||||
xfail("full_like"),
|
||||
xfail("nn.functional.dropout2d"),
|
||||
xfail("nn.functional.dropout3d"),
|
||||
xfail("masked.cumprod"),
|
||||
skip("nn.functional.multi_head_attention_forward"),
|
||||
}
|
||||
|
||||
# Add a list of ops that are currently failing BW pass
|
||||
skip_bw = [
|
||||
@ -507,7 +538,13 @@ OP_DB_WORLD_SIZE = 4
|
||||
DEVICE_TYPE = "cpu"
|
||||
|
||||
|
||||
class TestDTensorOps(DTensorOpTestBase):
|
||||
class TestDTensorOps(TestCase):
|
||||
__test__ = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls.__test__ = True
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return OP_DB_WORLD_SIZE
|
||||
@ -535,14 +572,6 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
|
||||
self.check_dtensor_func(test, op)
|
||||
|
||||
# only allow float dytpe for now, we can relax this constraint
|
||||
# when feel necessary later (i.e when adding quantization support).
|
||||
@suppress_warnings
|
||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
|
||||
def test_dtensor_op_db(self, dtype, op):
|
||||
self.run_opinfo_test(dtype, op)
|
||||
|
||||
def assert_ref_dtensor_equal(self, dtensor_rs, rs):
|
||||
flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
|
||||
flat_rs = pytree.tree_leaves(rs)
|
||||
@ -567,6 +596,9 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
|
||||
self.assertEqualOnRank(dtensor_r, r)
|
||||
|
||||
def assertEqualOnRank(self, x, y, msg=None, *, rank=0) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def run_dtensor_crossref(self, func, args, kwargs):
|
||||
to_dtensor = DTensorConverter(self.mesh, args, kwargs)
|
||||
|
||||
@ -580,7 +612,8 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
return res
|
||||
|
||||
# TODO: also handle cases where func raise an exception
|
||||
rs = func(*args, **kwargs)
|
||||
op_args, op_kwargs = reconcile_args(args, kwargs)
|
||||
rs = func(*op_args, **op_kwargs)
|
||||
rs = concat_res_if_necessary(func, rs)
|
||||
|
||||
def to_replicate(e: object) -> object:
|
||||
@ -635,12 +668,12 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
self.assert_ref_dtensor_equal(dtensor_rs, rs)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"failed to convert args to DTensor; "
|
||||
f"Failed to convert args to DTensor; "
|
||||
f"originally (*{args}, **{kwargs})"
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
|
||||
f"{str(e)}\n\nFailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
|
||||
) from e
|
||||
return rs
|
||||
|
||||
@ -656,7 +689,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
else:
|
||||
print(f"xfail('{opinfo.name}'),")
|
||||
|
||||
def test_one_hot(self):
|
||||
def run_one_hot(self):
|
||||
ops = [op for op in op_db if op.name == "nn.functional.one_hot"]
|
||||
assert len(ops) == 1
|
||||
op = ops[0]
|
||||
@ -668,7 +701,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
sample_inputs_filter=lambda s: s.kwargs["num_classes"] != -1,
|
||||
)
|
||||
|
||||
def test_mean(self):
|
||||
def run_mean(self):
|
||||
self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
|
||||
|
||||
shape = [2 * self.world_size + 1, 2 * self.world_size]
|
||||
@ -692,6 +725,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
full_tensor = mean.full_tensor()
|
||||
|
||||
self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim))
|
||||
|
||||
if is_evenly_shardable:
|
||||
self.assertTrue("P->R" in debug_mode.debug_string())
|
||||
else:
|
||||
@ -720,9 +754,76 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
_ = torch.ops.aten.embedding.default(weight_dtensor, input_dtensor)
|
||||
|
||||
|
||||
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
|
||||
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))
|
||||
class TestMultiThreadedDTensorOps(DTensorOpTestBase, TestDTensorOps):
|
||||
_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestMultiThreadedDTensorOps")
|
||||
|
||||
@suppress_warnings
|
||||
@ops(_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps(
|
||||
_op_db,
|
||||
"TestMultiThreadedDTensorOps",
|
||||
"test_dtensor_op_db",
|
||||
dtensor_fails | dtensor_multi_threaded_fails,
|
||||
)
|
||||
def test_dtensor_op_db(self, dtype, op):
|
||||
self.run_opinfo_test(dtype, op)
|
||||
|
||||
def test_mean(self):
|
||||
self.run_mean()
|
||||
|
||||
def test_one_hot(self):
|
||||
self.run_one_hot()
|
||||
|
||||
|
||||
class TestLocalDTensorOps(TestDTensorOps):
|
||||
_op_db = repurpose_ops(op_db, "TestDTensorOps", "TestLocalDTensorOps")
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.distributed.init_process_group("fake", rank=0, world_size=self.world_size)
|
||||
self.fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
try:
|
||||
dist.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
@suppress_warnings
|
||||
@ops(_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps(
|
||||
_op_db,
|
||||
"TestLocalDTensorOps",
|
||||
"test_dtensor_op_db",
|
||||
dtensor_fails,
|
||||
)
|
||||
def test_dtensor_op_db(self, dtype, op):
|
||||
self.run_opinfo_test(dtype, op)
|
||||
|
||||
def test_mean(self):
|
||||
with LocalTensorMode(frozenset(range(0, self.world_size))):
|
||||
self.run_mean()
|
||||
|
||||
def test_one_hot(self):
|
||||
self.run_one_hot()
|
||||
|
||||
def run_opinfo_test(
|
||||
self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True
|
||||
):
|
||||
with LocalTensorMode(frozenset(range(0, self.world_size))):
|
||||
super().run_opinfo_test(dtype, op, requires_grad, sample_inputs_filter)
|
||||
|
||||
def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
|
||||
self.assertEqual(x, y, msg)
|
||||
|
||||
|
||||
# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
|
||||
instantiate_device_type_tests(
|
||||
TestMultiThreadedDTensorOps, globals(), only_for=(DEVICE_TYPE,)
|
||||
)
|
||||
|
||||
instantiate_device_type_tests(TestLocalDTensorOps, globals(), only_for=(DEVICE_TYPE,))
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
415
test/distributed/test_local_tensor.py
Normal file
415
test/distributed/test_local_tensor.py
Normal file
@ -0,0 +1,415 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._local_tensor import (
|
||||
local_tensor_mode,
|
||||
LocalTensor,
|
||||
LocalTensorMode,
|
||||
)
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
init_device_mesh,
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class LocalTensorTestBase(TestCase):
|
||||
def assertEqual(self, lhs, rhs, **kwargs):
|
||||
mode = local_tensor_mode()
|
||||
with nullcontext() if mode is None else mode.disable():
|
||||
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
|
||||
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
|
||||
super().assertEqual(lhs._ranks, rhs._ranks)
|
||||
for r in lhs._ranks:
|
||||
super().assertEqual(
|
||||
lhs._local_tensors[r],
|
||||
rhs._local_tensors[r],
|
||||
lambda m: f"rank {r}: {m}",
|
||||
)
|
||||
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
|
||||
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
|
||||
for r in lhs._ranks:
|
||||
super().assertEqual(
|
||||
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
|
||||
)
|
||||
else:
|
||||
return super().assertEqual(lhs, rhs, **kwargs)
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
raise NotImplementedError("override world-size in your subclass")
|
||||
|
||||
def build_device_mesh(self) -> DeviceMesh:
|
||||
return init_device_mesh("cpu", (self.world_size,))
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch.distributed.init_process_group(
|
||||
# TODO: test other ranks too
|
||||
"fake",
|
||||
rank=0,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
try:
|
||||
dist.destroy_process_group()
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
|
||||
class TestLocalTensorWorld2(LocalTensorTestBase):
|
||||
world_size = 2
|
||||
|
||||
def test_local_tensor_dtype_consistency(self):
|
||||
"""Test that LocalTensor enforces dtype consistency."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
|
||||
inconsistent_tensors = {
|
||||
0: torch.randn(shape, dtype=torch.float32, device=device),
|
||||
1: torch.randn(
|
||||
shape, dtype=torch.float64, device=device
|
||||
), # Different dtype
|
||||
}
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
LocalTensor(inconsistent_tensors)
|
||||
|
||||
def test_local_tensor_creation_fails_with_grad_tensors(self):
|
||||
"""Test that LocalTensor creation fails when local tensors have requires_grad=True."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create sample local tensors for different ranks
|
||||
local_tensors = {
|
||||
0: torch.randn(shape, dtype=dtype, device=device, requires_grad=True),
|
||||
1: torch.randn(shape, dtype=dtype, device=device, requires_grad=True),
|
||||
}
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
LocalTensor(local_tensors)
|
||||
|
||||
# TODO: test flatten/unflatten
|
||||
|
||||
def test_basic_arithmetic_operations(self):
|
||||
"""Test basic arithmetic operations on LocalTensors."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create identical local tensors for consistency tests
|
||||
base_tensor = torch.randn(shape, dtype=dtype, device=device)
|
||||
identical_local_tensors = {
|
||||
0: base_tensor.clone(),
|
||||
1: base_tensor.clone(),
|
||||
}
|
||||
|
||||
lt1 = LocalTensor(identical_local_tensors)
|
||||
lt2 = LocalTensor(identical_local_tensors)
|
||||
|
||||
# Test addition
|
||||
result_add = lt1 + lt2
|
||||
self.assertIsInstance(result_add, LocalTensor)
|
||||
self.assertEqual(len(result_add._local_tensors), 2)
|
||||
|
||||
# Verify the operation was applied to each local tensor
|
||||
for rank in identical_local_tensors.keys():
|
||||
expected = identical_local_tensors[rank] + identical_local_tensors[rank]
|
||||
self.assertEqual(result_add._local_tensors[rank], expected)
|
||||
|
||||
# Test multiplication
|
||||
result_mul = lt1 * 2.0
|
||||
self.assertIsInstance(result_mul, LocalTensor)
|
||||
for rank in identical_local_tensors.keys():
|
||||
expected = identical_local_tensors[rank] * 2.0
|
||||
self.assertEqual(result_mul._local_tensors[rank], expected)
|
||||
|
||||
# TODO: consider an op-info test; we don't actually need to cover all ops
|
||||
# but it will help make sure views and more exotic things are done
|
||||
# correctly (in standard subclass style)
|
||||
|
||||
def test_mixed_operations_with_regular_tensors(self):
|
||||
"""Test operations between LocalTensors and regular tensors."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create identical local tensors for consistency tests
|
||||
base_tensor = torch.randn(shape, dtype=dtype, device=device)
|
||||
identical_local_tensors = {
|
||||
0: base_tensor.clone(),
|
||||
1: base_tensor.clone(),
|
||||
}
|
||||
|
||||
lt = LocalTensor(identical_local_tensors)
|
||||
regular_tensor = torch.ones_like(identical_local_tensors[0])
|
||||
|
||||
# Test LocalTensor + regular tensor
|
||||
result = lt + regular_tensor
|
||||
self.assertIsInstance(result, LocalTensor)
|
||||
|
||||
for rank in identical_local_tensors.keys():
|
||||
expected = identical_local_tensors[rank] + regular_tensor
|
||||
self.assertEqual(result._local_tensors[rank], expected)
|
||||
|
||||
def test_local_tensor_mode(self):
|
||||
"""Test LocalTensorMode functionality."""
|
||||
device = torch.device("cpu")
|
||||
shape = (2, 3)
|
||||
dtype = torch.float32
|
||||
|
||||
# Create identical local tensors for consistency tests
|
||||
base_tensor = torch.randn(shape, dtype=dtype, device=device)
|
||||
identical_local_tensors = {
|
||||
0: base_tensor.clone(),
|
||||
1: base_tensor.clone(),
|
||||
}
|
||||
|
||||
lt = LocalTensor(identical_local_tensors)
|
||||
|
||||
with LocalTensorMode(lt._ranks):
|
||||
result = lt + 1.0
|
||||
self.assertIsInstance(result, LocalTensor)
|
||||
|
||||
regular = torch.ones(2, 2)
|
||||
regular_result = regular + 1.0
|
||||
self.assertIsInstance(regular, LocalTensor)
|
||||
self.assertIsInstance(regular_result, LocalTensor)
|
||||
|
||||
def test_empty_local_tensors(self):
|
||||
"""Test behavior with empty local tensors dict."""
|
||||
# TODO: raise a better error here
|
||||
with self.assertRaises(StopIteration): # next() on empty iterator
|
||||
LocalTensor({})
|
||||
|
||||
def test_collectives_within_local_tensor_mode(self):
|
||||
"""Test that collective operations work within LocalTensorMode context."""
|
||||
test_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
|
||||
1: torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
|
||||
}
|
||||
lt = LocalTensor(test_tensors)
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
with LocalTensorMode(lt._ranks):
|
||||
# Test all_reduce within mode
|
||||
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_sum, group=fake_pg)
|
||||
|
||||
expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]])
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
|
||||
|
||||
# Test broadcast within mode
|
||||
lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.broadcast(lt_broadcast, src=0, group=fake_pg)
|
||||
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0])
|
||||
|
||||
# Test that regular operations still work
|
||||
result = lt + 1.0
|
||||
self.assertIsInstance(result, LocalTensor)
|
||||
|
||||
def test_scalar_mul_reduction_bug(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
|
||||
tensor = torch.tensor([10, 10]).float()
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
y = dt.sum() * 1 # noqa: F841
|
||||
|
||||
tensor = torch.arange(10).reshape(10, 1).float().requires_grad_()
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
print(dt.sum() * 1, dt.sum() * 2, dt.sum() * 3)
|
||||
|
||||
def test_uneven_sharding_mean_bug(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
tensor = torch.arange(12).reshape(-1, 4).float()
|
||||
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
mean = dt.mean()
|
||||
self.assertEqual(mean.placements, [Replicate()])
|
||||
full = mean.full_tensor()
|
||||
self.assertEqual(tensor.mean(), full)
|
||||
|
||||
def test_uneven_sharding_prod(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
tensor = (torch.arange(12) + 1).reshape(-1, 4).float()
|
||||
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
x = dt.prod()
|
||||
full = x.full_tensor()
|
||||
self.assertEqual(tensor.prod(), full)
|
||||
|
||||
def test_even_sharding_mean_is_partial(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
mesh = self.build_device_mesh()
|
||||
tensor = torch.arange(16).reshape(4, 4).float()
|
||||
|
||||
dt = distribute_tensor(tensor, device_mesh=mesh, placements=[Shard(0)])
|
||||
|
||||
mean = dt.mean()
|
||||
full = mean.full_tensor()
|
||||
self.assertEqual(tensor.mean(), full)
|
||||
self.assertEqual(mean.placements, [Partial("avg")])
|
||||
|
||||
|
||||
class TestLocalTensorWorld3(LocalTensorTestBase):
|
||||
world_size = 3
|
||||
|
||||
def test_collective_reduction_operations(self):
|
||||
"""Test different reduction operations for all_reduce."""
|
||||
# Create different tensors for each rank with simple values for testing
|
||||
test_tensors = {
|
||||
0: torch.tensor([[1.0, 4.0], [2.0, 5.0]]),
|
||||
1: torch.tensor([[2.0, 1.0], [3.0, 6.0]]),
|
||||
2: torch.tensor([[3.0, 2.0], [1.0, 4.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test SUM reduction
|
||||
lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg)
|
||||
expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
|
||||
|
||||
# Test MAX reduction
|
||||
lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg)
|
||||
expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) # Max across all tensors
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_max._local_tensors[rank], expected_max)
|
||||
|
||||
# Test MIN reduction
|
||||
lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()})
|
||||
dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg)
|
||||
expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors
|
||||
for rank in test_tensors.keys():
|
||||
self.assertEqual(lt_min._local_tensors[rank], expected_min)
|
||||
|
||||
def test_all_reduce_collective(self):
|
||||
"""Test that all_reduce collective operation works correctly with LocalTensor."""
|
||||
# Create different tensors for each rank
|
||||
different_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
|
||||
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test all_reduce with SUM (default)
|
||||
lt_sum = LocalTensor({k: v.clone() for k, v in different_tensors.items()})
|
||||
lt_sum = lt_sum + 1
|
||||
dist.all_reduce(lt_sum, group=fake_pg)
|
||||
|
||||
# Verify all ranks have the sum of all tensors (after adding 1 to each)
|
||||
expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]])
|
||||
for rank in different_tensors.keys():
|
||||
self.assertEqual(lt_sum._local_tensors[rank], expected_sum)
|
||||
|
||||
def test_broadcast_collective(self):
|
||||
"""Test that broadcast collective operation works correctly with LocalTensor."""
|
||||
# Create different tensors for each rank
|
||||
different_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
|
||||
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test broadcast from rank 1
|
||||
lt_broadcast = LocalTensor({k: v.clone() for k, v in different_tensors.items()})
|
||||
dist.broadcast(lt_broadcast, src=1, group=fake_pg)
|
||||
|
||||
# Verify all ranks have rank 1's original tensor
|
||||
expected_broadcast = different_tensors[1]
|
||||
for rank in different_tensors.keys():
|
||||
self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast)
|
||||
|
||||
def test_all_gather_collective(self):
|
||||
"""Test that all_gather collective operation works correctly with LocalTensor."""
|
||||
# Create different tensors for each rank
|
||||
different_tensors = {
|
||||
0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
|
||||
1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]),
|
||||
2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]),
|
||||
}
|
||||
|
||||
fake_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
|
||||
# Test all_gather
|
||||
lt_gather = LocalTensor(different_tensors)
|
||||
tensor_list = [torch.zeros_like(lt_gather) for _ in range(3)]
|
||||
|
||||
dist.all_gather(tensor_list, lt_gather, group=fake_pg)
|
||||
|
||||
# Verify each position in tensor_list contains the corresponding rank's tensor
|
||||
self.assertEqual(tensor_list[0], different_tensors[0])
|
||||
self.assertEqual(tensor_list[1], different_tensors[1])
|
||||
self.assertEqual(tensor_list[2], different_tensors[2])
|
||||
|
||||
|
||||
class TestLocalTensorWorld4(LocalTensorTestBase):
|
||||
world_size = 4
|
||||
|
||||
def test_dtensor_cat(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
t1 = torch.arange(16).view(4, 4).float()
|
||||
d1 = distribute_tensor(t1, device_mesh, [Replicate()])
|
||||
t2 = (torch.arange(16) + 16).view(4, 4).float()
|
||||
d2 = distribute_tensor(t2, device_mesh, [Shard(0)])
|
||||
|
||||
local_res = torch.cat([t1, t2], dim=-1)
|
||||
dist_res = torch.cat([d1, d2], dim=-1)
|
||||
full_tensor = dist_res.full_tensor()
|
||||
self.assertEqual(full_tensor, local_res)
|
||||
|
||||
|
||||
class TestLocalTensorWorld8(LocalTensorTestBase):
|
||||
world_size = 8
|
||||
|
||||
def test_dtensor_addmm(self):
|
||||
with LocalTensorMode(self.world_size):
|
||||
device_mesh = self.build_device_mesh()
|
||||
|
||||
shard_spec = [Shard(0)]
|
||||
replica_spec = [Replicate()]
|
||||
|
||||
tensor_to_shard = torch.randn(12, 8)
|
||||
mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
|
||||
tensor_to_replicate = torch.randn(8, 4)
|
||||
mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
|
||||
input_tensor = torch.randn(4)
|
||||
input = distribute_tensor(input_tensor, device_mesh, replica_spec)
|
||||
|
||||
dist_res = torch.addmm(input, mat1, mat2)
|
||||
local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
|
||||
full_tensor = dist_res.full_tensor()
|
||||
self.assertEqual(full_tensor, local_res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -7915,9 +7915,13 @@ torch.cuda.synchronize()
|
||||
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
|
||||
(
|
||||
torch.randint(
|
||||
2, (n, *post_seq_len_shape), device=device, dtype=dtype
|
||||
)
|
||||
if dtype is torch.bool
|
||||
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
||||
)
|
||||
for n in range(2, 9)
|
||||
],
|
||||
layout=torch.jagged,
|
||||
@ -7966,9 +7970,13 @@ torch.cuda.synchronize()
|
||||
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
|
||||
(
|
||||
torch.randint(
|
||||
2, (n, *post_seq_len_shape), device=device, dtype=dtype
|
||||
)
|
||||
if dtype is torch.bool
|
||||
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
||||
)
|
||||
for n in range(2, 9)
|
||||
],
|
||||
layout=torch.jagged,
|
||||
@ -8713,7 +8721,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
|
||||
# min() / max(): weird bug
|
||||
XFailRule(
|
||||
error_type=AttributeError,
|
||||
error_msg="'ConstantIntNode' object has no attribute 'add'",
|
||||
error_msg="'NestedIntNode' object has no attribute 'add'",
|
||||
op_match_fn=lambda device, op: (
|
||||
op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
|
||||
),
|
||||
@ -8730,7 +8738,7 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
|
||||
# copysign(): formula is broken for (T, NT) broadcasting
|
||||
XFailRule(
|
||||
error_type=AttributeError,
|
||||
error_msg="'ConstantIntNode' object has no attribute 'add'",
|
||||
error_msg="'NestedIntNode' object has no attribute 'add'",
|
||||
op_match_fn=lambda device, op: (op.full_name == "copysign"),
|
||||
sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
|
||||
name="broken_copysign_compile_backward",
|
||||
|
@ -15,7 +15,11 @@ TORCH_LIBRARY(c10d, m) {
|
||||
m.class_<Work>("Work")
|
||||
.def(torch::init<>())
|
||||
.def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
|
||||
m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
|
||||
m.class_<ReduceOp>("ReduceOp")
|
||||
.def(torch::init<>())
|
||||
.def("op", [](const c10::intrusive_ptr<ReduceOp>& self) -> int64_t {
|
||||
return self->op_;
|
||||
});
|
||||
m.def(
|
||||
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
|
747
torch/distributed/_local_tensor/__init__.py
Normal file
747
torch/distributed/_local_tensor/__init__.py
Normal file
@ -0,0 +1,747 @@
|
||||
from ast import Call
|
||||
|
||||
|
||||
"""
|
||||
A LocalTensor is a tensor subclass which simulates a tensor that is
|
||||
distributed across SPMD ranks. A LocalTensor might be size N, but in fact
|
||||
there are world_size shards/replicas of it stored internally. When you do a
|
||||
plain PyTorch operation on it, we apply the operation to each shard; when you
|
||||
do a collective, we do the mathematically equivalent operation on the local
|
||||
shards. A LocalTensor is associated with a list of ranks which specify
|
||||
which ranks it holds local tensors for.
|
||||
|
||||
NB, this is NOT a DataParallel like abstraction where you can run operations
|
||||
on multiple different GPUs. It is intended purely for *debugging* purposes,
|
||||
the overhead is almost certainly too high to keep eight GPUs (even the C++
|
||||
autograd needs multithreading to keep up!) (It might potentially be possible
|
||||
to trace through this with torch.compile and then compile it with CUDA graphs
|
||||
but this is currently a non-goal.)
|
||||
|
||||
We do not directly handling MPMD. However in practice even in SPMD you may
|
||||
encounter divergence in behavior per rank (for example, uneven sharding
|
||||
across ranks). To support scenarios like this, we provide a helper decorator
|
||||
that allows you to run a function with no side effects for each LocalTensor
|
||||
shard and combine results back into LocalTensor or LocalIntNode.
|
||||
|
||||
NB: This is a torch dispatch Tensor subclass, as we want to assume that autograd
|
||||
is SPMD, so we run it once, and dispatch the inner autograd calls to the individual
|
||||
local shards.
|
||||
|
||||
NOTE ABOUT MESH: This subclass requires collectives that are issued to it to
|
||||
respect a DeviceMesh like abstraction. The reason for this is that when
|
||||
DTensor issues us a collective for a particular rank, you will be asked to do
|
||||
this on a specific process group which involves some ranks. However, this
|
||||
will only be for the LOCAL PG that this particular rank is participating in;
|
||||
there will be a bunch of other PGs for other nodes that you don't get to see.
|
||||
We need to be able to reverse engineer all of the collectives that don't
|
||||
involve the current local rank here to actually issue them. This can be done
|
||||
two ways: (1) looking at the participating local ranks in the PG and computing
|
||||
the complement which specifies all the other collectives you have to run, or
|
||||
(2) retrieving the device mesh axis corresponding to the PG for this rank, and
|
||||
then running all the fibers for this.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Size, SymBool, SymInt, Tensor
|
||||
from torch._C import DispatchKey, DispatchKeySet
|
||||
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
|
||||
from torch.distributed import DeviceMesh
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
from torch.fx.experimental._constant_symnode import ConstantIntNode
|
||||
from torch.nested._internal.nested_int import NestedIntNode
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode
|
||||
from torch.utils.checkpoint import get_device_states, set_device_states
|
||||
|
||||
|
||||
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
|
||||
|
||||
|
||||
from . import _c10d
|
||||
|
||||
|
||||
def _int_on_rank(i: "LocalIntNode | ConstantIntNode", r: int) -> int:
|
||||
if isinstance(i, LocalIntNode):
|
||||
return i._local_ints[r]
|
||||
elif isinstance(i, ConstantIntNode):
|
||||
return i.val
|
||||
else:
|
||||
raise AssertionError(type(i))
|
||||
|
||||
|
||||
def _check_for_subclass(flat_args: Sequence[object]) -> bool:
|
||||
return any(_check_for_subclass_arg(x) for x in flat_args)
|
||||
|
||||
|
||||
def _check_for_subclass_arg(x: object) -> bool:
|
||||
return (
|
||||
not isinstance(x, LocalTensor)
|
||||
and isinstance(x, Tensor)
|
||||
and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer)
|
||||
)
|
||||
|
||||
|
||||
def _map_to_rank_local_val(val: Any, rank: int) -> Any:
|
||||
if isinstance(val, LocalTensor):
|
||||
return val._local_tensors[rank]
|
||||
if isinstance(val, SymInt) and isinstance(val.node, LocalIntNode):
|
||||
return val.node._local_ints[rank]
|
||||
return val
|
||||
|
||||
|
||||
def _for_each_rank_run_func(
|
||||
func: Callable[..., Any],
|
||||
ranks: frozenset[int],
|
||||
args: Sequence[Any],
|
||||
kwargs: dict[str, Any],
|
||||
*,
|
||||
alias: bool = True,
|
||||
) -> Any:
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
cpu_state = torch.get_rng_state()
|
||||
devices, states = get_device_states((args, kwargs))
|
||||
|
||||
flat_rank_rets = {}
|
||||
|
||||
for r in sorted(ranks):
|
||||
torch.set_rng_state(cpu_state)
|
||||
set_device_states(devices, states)
|
||||
rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args]
|
||||
rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec)
|
||||
rank_ret = func(*rank_args, **rank_kwargs)
|
||||
flat_rank_rets[r] = rank_ret
|
||||
|
||||
rr_key = next(iter(flat_rank_rets.keys()))
|
||||
rr_val = flat_rank_rets[rr_key]
|
||||
|
||||
if isinstance(rr_val, Tensor):
|
||||
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
|
||||
elif isinstance(rr_val, (list, tuple)):
|
||||
ret_list = []
|
||||
for i in range(len(rr_val)):
|
||||
rets = {r: flat_rank_rets[r][i] for r in sorted(ranks)}
|
||||
v_it = iter(rets.values())
|
||||
v = next(v_it)
|
||||
if isinstance(v, Tensor):
|
||||
ret_list.append(LocalTensor(rets))
|
||||
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
|
||||
ret_list.append(torch.SymInt(LocalIntNode(rets)))
|
||||
else:
|
||||
assert all(v == v2 for v2 in v_it)
|
||||
ret_list.append(v)
|
||||
ret = type(rr_val)(ret_list)
|
||||
else:
|
||||
v_it = iter(flat_rank_rets.values())
|
||||
v = next(v_it)
|
||||
if all(v == v2 for v2 in v_it):
|
||||
return v
|
||||
if isinstance(v, int):
|
||||
return torch.SymInt(LocalIntNode(flat_rank_rets))
|
||||
raise AssertionError(f"Unexpected return type {type(v)}")
|
||||
|
||||
if alias:
|
||||
return return_and_correct_aliasing(func, args, kwargs, ret)
|
||||
else:
|
||||
return ret
|
||||
|
||||
|
||||
def _get_extra_dispatch_keys(t: torch.Tensor) -> DispatchKeySet:
|
||||
extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0)
|
||||
if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Conjugate):
|
||||
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Conjugate)
|
||||
if torch._C._dispatch_keys(t).has(torch._C.DispatchKey.Negative):
|
||||
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative)
|
||||
return extra_dispatch_keys
|
||||
|
||||
|
||||
class LocalIntNode:
|
||||
"""
|
||||
Like a LocalTensor, but for an int. We can't use a 0D tensor to represent this
|
||||
because often only a SymInt is accepted where we wish to use this.
|
||||
"""
|
||||
|
||||
def __new__(cls, local_ints: dict[int, int]) -> "ConstantIntNode | LocalIntNode": # type: ignore[misc]
|
||||
if len(set(local_ints.values())) == 1:
|
||||
return ConstantIntNode(next(iter(local_ints.values())))
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, local_ints: dict[int, int]):
|
||||
self._local_ints = local_ints
|
||||
|
||||
def maybe_as_int(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
def is_int(self) -> bool:
|
||||
return True
|
||||
|
||||
def is_float(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_bool(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_nested_int(self) -> bool:
|
||||
return False
|
||||
|
||||
def clone(self) -> "LocalIntNode":
|
||||
return self
|
||||
|
||||
def _str(self) -> str:
|
||||
return f"LocalIntNode({self._local_ints})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def _graph_repr(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def is_symbolic(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_constant(self) -> bool:
|
||||
return False
|
||||
|
||||
def sym_max(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{
|
||||
r: max(self._local_ints[r], _int_on_rank(other, r))
|
||||
for r in self._local_ints
|
||||
}
|
||||
)
|
||||
|
||||
def add(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{r: self._local_ints[r] + _int_on_rank(other, r) for r in self._local_ints}
|
||||
)
|
||||
|
||||
def sub(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{r: self._local_ints[r] - _int_on_rank(other, r) for r in self._local_ints}
|
||||
)
|
||||
|
||||
def mul(
|
||||
self, other: "LocalIntNode | ConstantIntNode"
|
||||
) -> "LocalIntNode | ConstantIntNode":
|
||||
return LocalIntNode(
|
||||
{r: self._local_ints[r] * _int_on_rank(other, r) for r in self._local_ints}
|
||||
)
|
||||
|
||||
def eq(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
|
||||
r = {self._local_ints[r] == _int_on_rank(other, r) for r in self._local_ints}
|
||||
return torch._C._get_constant_bool_symnode(len(r) == 1 and next(iter(r)))
|
||||
|
||||
def gt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
|
||||
r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints}
|
||||
assert len(r) == 1, (self, other)
|
||||
return torch._C._get_constant_bool_symnode(next(iter(r)))
|
||||
|
||||
def lt(self, other: "LocalIntNode | ConstantIntNode") -> bool | SymBool:
|
||||
r = {self._local_ints[r] < _int_on_rank(other, r) for r in self._local_ints}
|
||||
assert len(r) == 1, (self, other)
|
||||
return torch._C._get_constant_bool_symnode(next(iter(r)))
|
||||
|
||||
def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode":
|
||||
return ConstantIntNode(num)
|
||||
|
||||
|
||||
class LocalTensor(torch.Tensor):
|
||||
"""
|
||||
LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD
|
||||
(Single Program, Multiple Data) ranks. Each LocalTensor instance internally holds a mapping from
|
||||
global rank ids to their corresponding local Tensor shards.Operations performed on a LocalTensor
|
||||
are applied independently to each local shard, mimicking distributed computation. Collectives
|
||||
and other distributed operations are handled by mapping them to the local shards as appropriate.
|
||||
|
||||
Note:
|
||||
This class is primarily intended for debugging and simulating distributed tensor computations
|
||||
on a single process.
|
||||
|
||||
"""
|
||||
|
||||
# Map from global rank to the local tensor.
|
||||
_local_tensors: dict[int, torch.Tensor]
|
||||
# Precomputed for speed set of keys from the local tensor map.
|
||||
_ranks: frozenset[int]
|
||||
__slots__ = ["_local_tensors", "_ranks"]
|
||||
|
||||
@staticmethod
|
||||
@torch._disable_dynamo
|
||||
def __new__(
|
||||
cls,
|
||||
local_tensors: dict[int, torch.Tensor],
|
||||
) -> "LocalTensor":
|
||||
if any(t.requires_grad for t in local_tensors.values()):
|
||||
raise AssertionError(
|
||||
"Internal local_tensors require grad, but we will ignore those autograd graph. "
|
||||
"Make a custom autograd function and make sure you detach the inner tensors."
|
||||
)
|
||||
|
||||
it = iter(local_tensors.values())
|
||||
first_local_tensor = next(it)
|
||||
|
||||
first_shape = first_local_tensor.shape
|
||||
first_stride = first_local_tensor.stride()
|
||||
dtype = first_local_tensor.dtype
|
||||
device = first_local_tensor.device
|
||||
layout = first_local_tensor.layout
|
||||
|
||||
extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor)
|
||||
|
||||
# Assert that all tensors have the same dtype, layout and dispatch keys. Due
|
||||
# to uneven sharding, it is possible that tensors will have different shapes.
|
||||
for local_tensor in it:
|
||||
assert dtype == local_tensor.dtype, (
|
||||
"Tensors representing LocalTensor shards must have the same dtype"
|
||||
)
|
||||
assert layout == local_tensor.layout, (
|
||||
"Tensors representing LocalTensor shards must have the same layout"
|
||||
)
|
||||
assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), (
|
||||
"Tensors representing LocalTensor shards must have the same set of extra dispatch keys"
|
||||
)
|
||||
|
||||
# Compute shape/stride. We allow for non-SPMD'ness here
|
||||
local_shapes: dict[int, dict[int, int]] = defaultdict(
|
||||
dict
|
||||
) # dim => rank => size
|
||||
local_strides: dict[int, dict[int, int]] = defaultdict(
|
||||
dict
|
||||
) # dim => rank => size
|
||||
for r, local_tensor in local_tensors.items():
|
||||
for d, size in enumerate(local_tensor.shape):
|
||||
local_shapes[d][r] = size
|
||||
local_strides[d][r] = local_tensor.stride(d)
|
||||
shape = [
|
||||
(
|
||||
first_shape[d]
|
||||
if len(set(local_shapes[d])) == 1
|
||||
else torch.SymInt(LocalIntNode(local_shapes[d]))
|
||||
)
|
||||
for d in range(len(first_shape))
|
||||
]
|
||||
strides = [
|
||||
(
|
||||
first_stride[d]
|
||||
if len(set(local_strides[d])) == 1
|
||||
else torch.SymInt(LocalIntNode(local_strides[d]))
|
||||
)
|
||||
for d in range(len(first_shape))
|
||||
]
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
shape,
|
||||
strides=strides,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
layout=layout,
|
||||
requires_grad=False,
|
||||
_extra_dispatch_keys=extra_dispatch_keys,
|
||||
)
|
||||
|
||||
local_tensors = {
|
||||
r: v if not isinstance(v, AsyncCollectiveTensor) else v.wait()
|
||||
for r, v in local_tensors.items()
|
||||
}
|
||||
r._local_tensors = local_tensors
|
||||
r._ranks = frozenset(local_tensors.keys())
|
||||
return r
|
||||
|
||||
@torch._disable_dynamo
|
||||
@mark_subclass_constructor_exportable_experimental # type: ignore[misc]
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
parts = []
|
||||
for k, v in self._local_tensors.items():
|
||||
parts.append(f" {k}: {v}")
|
||||
tensors_str = ",\n".join(parts)
|
||||
return f"LocalTensor(\n{tensors_str}\n)"
|
||||
|
||||
def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]:
|
||||
"""
|
||||
protocol to inform how to flatten a DTensor to local tensor
|
||||
for PT2 tracing
|
||||
"""
|
||||
return ["_local_tensors"], ()
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(
|
||||
inner_tensors: dict[str, Any],
|
||||
flatten_spec: tuple[Any, ...],
|
||||
outer_size: torch.Size,
|
||||
outer_stride: tuple[int, ...],
|
||||
) -> "LocalTensor":
|
||||
assert flatten_spec is not None, (
|
||||
"Expecting spec to be not None from `__tensor_flatten__` return value!"
|
||||
)
|
||||
local_tensors = inner_tensors["_local_tensors"]
|
||||
return LocalTensor(local_tensors)
|
||||
|
||||
@classmethod
|
||||
@torch._disable_dynamo
|
||||
def __torch_dispatch__( # type: ignore[override]
|
||||
cls,
|
||||
func: Any,
|
||||
types: tuple[Any, ...],
|
||||
args: tuple[Any, ...] = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# This is horribly inefficient
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
local_tensor = None
|
||||
for arg in flat_args:
|
||||
if isinstance(arg, LocalTensor):
|
||||
local_tensor = arg
|
||||
break
|
||||
|
||||
assert local_tensor is not None, (
|
||||
"At least one of the arguments must be a LocalTensor"
|
||||
)
|
||||
|
||||
# Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
|
||||
has_unrecognized_types = _check_for_subclass(flat_args)
|
||||
if has_unrecognized_types:
|
||||
unrecognized_types = [
|
||||
type(x) for x in flat_args if _check_for_subclass_arg(x)
|
||||
]
|
||||
not_implemented_log.debug(
|
||||
"LocalTensor unrecognized subclass(es): %s", unrecognized_types
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
with LocalTensorMode(local_tensor._ranks):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def tolist(self) -> list[Any]:
|
||||
"""
|
||||
Reconcile and convert result to list.
|
||||
"""
|
||||
|
||||
return self.reconcile().tolist()
|
||||
|
||||
def reconcile(self) -> torch.Tensor:
|
||||
"""
|
||||
Reconciles the LocalTensor into a single torch.Tensor by ensuring all local
|
||||
shards are identical and returning a detached clone of one of them.
|
||||
|
||||
Note:
|
||||
This method is useful for extracting a representative tensor from a LocalTensor
|
||||
when all shards are expected to be the same, such as after a collective operation
|
||||
that synchronizes all ranks.
|
||||
"""
|
||||
|
||||
# Force all local tensor shards across ranks to be the same
|
||||
it = iter(self._local_tensors.values())
|
||||
t1 = next(it)
|
||||
for t2 in it:
|
||||
assert torch.equal(t1, t2), (
|
||||
"LocalTensor shards must be the same to reconcile"
|
||||
)
|
||||
cl = t1.clone().detach()
|
||||
cl.requires_grad_(self.requires_grad)
|
||||
return cl
|
||||
|
||||
|
||||
_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = []
|
||||
|
||||
|
||||
class LocalTensorMode(TorchDispatchMode):
|
||||
"""
|
||||
A TorchDispatchMode that simulates SPMD (Single Program, Multiple Data) execution
|
||||
for LocalTensor objects across a set of ranks.
|
||||
|
||||
LocalTensorMode enables PyTorch operations to be transparently applied to each
|
||||
local shard of a LocalTensor, as if they were distributed across multiple ranks.
|
||||
When active, this mode intercepts tensor operations and dispatches them to each
|
||||
rank's local tensor, collecting and wrapping the results as LocalTensors. It also
|
||||
handles collective operations by mapping them to local implementations.
|
||||
|
||||
This mode is primarily intended for debugging and simulating distributed tensor
|
||||
computations on a single process, rather than for high-performance distributed
|
||||
training. It maintains a stack of active modes, patches DeviceMesh coordinate
|
||||
resolution, and provides utilities for temporarily disabling the mode or mapping
|
||||
functions over ranks.
|
||||
"""
|
||||
|
||||
# What ranks this local tensor mode is operating over
|
||||
def __init__(self, ranks: Union[int, frozenset[int]]):
|
||||
if isinstance(ranks, int):
|
||||
# assume is world size
|
||||
self.ranks = frozenset(range(ranks))
|
||||
else:
|
||||
assert isinstance(ranks, frozenset)
|
||||
self.ranks = ranks
|
||||
self._disable = False
|
||||
self._old_get_coordinate = None
|
||||
|
||||
def __enter__(self) -> "LocalTensorMode":
|
||||
self._disable = False
|
||||
self._patch_device_mesh()
|
||||
_LOCAL_TENSOR_MODE.append(self)
|
||||
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self._disable = True
|
||||
self._unpatch_device_mesh()
|
||||
_LOCAL_TENSOR_MODE.pop()
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def __torch_dispatch__(
|
||||
self,
|
||||
func: Any,
|
||||
types: tuple[Any, ...],
|
||||
args: tuple[Any, ...] = (),
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
# Find all LocalTensor arguments to determine ranks
|
||||
local_tensors = [a for a in flat_args if isinstance(a, LocalTensor)]
|
||||
|
||||
# Check for unrecognized tensor subclasses (but allow regular tensors and scalars)
|
||||
has_unrecognized_types = _check_for_subclass(flat_args)
|
||||
if has_unrecognized_types:
|
||||
unrecognized_types = [
|
||||
type(x) for x in flat_args if _check_for_subclass_arg(x)
|
||||
]
|
||||
not_implemented_log.debug(
|
||||
"LocalTensorMode unrecognized subclass(es): %s", unrecognized_types
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
# Factory functions convert into LocalTensor, so we don't have to
|
||||
# transmute a Tensor into a LocalTensor if mutation happens...
|
||||
# But if you do an operation on a Tensor, do NOT wrap it into a
|
||||
# LocalTensor. This helps prevent accidents when you're doing Tensor
|
||||
# operations on the inner non-wrapped tensors.
|
||||
if not local_tensors:
|
||||
if self._disable or any(isinstance(a, Tensor) for a in flat_args):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# For LocalTensors, verify they have compatible ranks
|
||||
for a in flat_args:
|
||||
if isinstance(a, LocalTensor):
|
||||
assert a._ranks == self.ranks, (
|
||||
f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks"
|
||||
)
|
||||
|
||||
if func.namespace == "c10d":
|
||||
if func is torch.ops.c10d.allreduce_.default:
|
||||
return _c10d._local_all_reduce_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.allreduce_coalesced_.default:
|
||||
return _c10d._local_allreduce_coalesced_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.reduce_scatter_tensor_coalesced_.default:
|
||||
return _c10d._local_reduce_scatter_tensor_coalesced_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.scatter_.default:
|
||||
return _c10d._local_scatter_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.broadcast_.default:
|
||||
return _c10d._local_broadcast_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.allgather_.default:
|
||||
return _c10d._local_all_gather_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default:
|
||||
return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.gather_.default:
|
||||
return _c10d._local_gather_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.alltoall_.default:
|
||||
return _c10d._local_alltoall_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.alltoall_base_.default:
|
||||
return _c10d._local_alltoall_base_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.barrier.default:
|
||||
return _c10d._local_barrier(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.monitored_barrier_.default:
|
||||
return _c10d._local_monitored_barrier_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.send.default:
|
||||
return _c10d._local_send(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.recv_.default:
|
||||
return _c10d._local_recv_(*args, **kwargs)
|
||||
elif func is torch.ops.c10d.recv_any_source_.default:
|
||||
return _c10d._local_recv_any_source_(*args, **kwargs)
|
||||
raise NotImplementedError(f"{func} not implemented")
|
||||
|
||||
if func.namespace == "_c10d_functional" or func.namespace == "_dtensor":
|
||||
with LocalTensorMode(self.ranks):
|
||||
return func._op_dk(
|
||||
DispatchKey.CompositeExplicitAutograd, *args, **kwargs
|
||||
)
|
||||
|
||||
if func.namespace == "_c10d_functional_autograd":
|
||||
raise NotImplementedError(f"{func} not implemented")
|
||||
|
||||
if func.namespace == "symm_mem":
|
||||
raise NotImplementedError(f"{func} not implemented")
|
||||
|
||||
return _for_each_rank_run_func(func, self.ranks, args, kwargs, alias=True)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def disable(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Disables LocalTensorMode temporarily. Primarily is intended to be used to perform
|
||||
rank specific computations and merge results back before enabling LocalTensorMode back.
|
||||
"""
|
||||
|
||||
old = self._disable
|
||||
self._disable = True
|
||||
self._unpatch_device_mesh()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._disable = old
|
||||
self._patch_device_mesh()
|
||||
|
||||
def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor:
|
||||
"""
|
||||
Creates a LocalTensor instance by mapping rank id to ids local shard.
|
||||
"""
|
||||
|
||||
with self.disable():
|
||||
return LocalTensor({r: cb(r) for r in self.ranks})
|
||||
|
||||
def _patch_device_mesh(self) -> None:
|
||||
assert self._old_get_coordinate is None
|
||||
self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment]
|
||||
DeviceMesh.get_coordinate = _LocalDeviceMesh.get_coordinate # type: ignore[method-assign]
|
||||
|
||||
def _unpatch_device_mesh(self) -> None:
|
||||
assert self._old_get_coordinate is not None
|
||||
DeviceMesh.get_coordinate = self._old_get_coordinate
|
||||
self._old_get_coordinate = None
|
||||
|
||||
|
||||
class _LocalDeviceMesh:
|
||||
"""
|
||||
Holds implementations of DeviceMesh functionality that must be patched while running
|
||||
under LocalTensorMode.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:
|
||||
lm = local_tensor_mode()
|
||||
assert lm is not None, "Unexpectedly not in LocalTensorMode"
|
||||
|
||||
rank_coords = (self.mesh == lm.rank_map(lambda r: torch.tensor(r))).nonzero()
|
||||
# NB: unlike the regular mechanism, we don't allow for MPMD
|
||||
assert rank_coords.size(0) == 1
|
||||
assert isinstance(rank_coords[0], LocalTensor)
|
||||
|
||||
coords: list[dict[int, int]] = [{} for _ in range(rank_coords.size(1))]
|
||||
for r, v in rank_coords[0]._local_tensors.items():
|
||||
for i, x in enumerate(v.tolist()):
|
||||
coords[i][r] = x
|
||||
out = [torch.SymInt(LocalIntNode(c)) for c in coords]
|
||||
|
||||
return out # type: ignore[return-value]
|
||||
|
||||
|
||||
def reconcile_args(args: Any, kwargs: dict[str, Any] | None = None) -> Any:
|
||||
"""
|
||||
Reconciles arguments by converting any LocalTensor instances in the input
|
||||
arguments to their underlying torch.Tensor representation.
|
||||
|
||||
This function is typically used to prepare arguments for functions that
|
||||
expect standard torch.Tensor objects, by flattening the input arguments,
|
||||
replacing LocalTensor instances with their reconciled (standard tensor)
|
||||
versions, and then reconstructing the original argument structure.
|
||||
|
||||
Args:
|
||||
args: Positional arguments, possibly containing LocalTensor instances.
|
||||
kwargs: Keyword arguments, possibly containing LocalTensor instances.
|
||||
|
||||
Returns:
|
||||
Any: The arguments with all LocalTensor instances replaced by their reconciled torch.Tensor equivalents,
|
||||
preserving the original structure.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
|
||||
reconciled_args = [
|
||||
a.reconcile() if isinstance(a, LocalTensor) else a for a in flat_args
|
||||
]
|
||||
return pytree.tree_unflatten(reconciled_args, args_spec)
|
||||
|
||||
|
||||
def local_tensor_mode() -> Optional[LocalTensorMode]:
|
||||
"""
|
||||
Returns the current active LocalTensorMode if one exists.
|
||||
|
||||
This function checks the global stack of LocalTensorMode instance. If there
|
||||
is at least one LocalTensorMode active, it returns the most recently entered
|
||||
(top of the stack) LocalTensorMode. If no LocalTensorMode is active, it returns None.
|
||||
|
||||
Returns:
|
||||
Optional[LocalTensorMode]: The current LocalTensorMode if active, else None.
|
||||
"""
|
||||
if len(_LOCAL_TENSOR_MODE) > 0:
|
||||
return _LOCAL_TENSOR_MODE[-1]
|
||||
return None
|
||||
|
||||
|
||||
def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Decorator that ensures a function is executed for each local tensor shard
|
||||
when running under LocalTensorMode. If not in LocalTensorMode, the function
|
||||
is executed normally. When in LocalTensorMode, the function is run for each
|
||||
rank, and the results are collected appropriately.
|
||||
|
||||
This decorator is useful for functions that exhibit non-SPMD behavior, such
|
||||
as those requiring rank specific actions. For example, a function that computes
|
||||
offset into input tensor based on rank.
|
||||
|
||||
Note that the function being decorated must not have any side effects and
|
||||
contain operations for a single rank only. For example, wrapping a function
|
||||
that performs a collective operation will not work.
|
||||
|
||||
Args:
|
||||
func (Callable[..., Any]): The function to be decorated.
|
||||
|
||||
Returns:
|
||||
Callable[..., Any]: The wrapped function that handles LocalTensorMode logic.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
lm = local_tensor_mode()
|
||||
if lm is None:
|
||||
return func(*args, **kwargs)
|
||||
ret = None
|
||||
with lm.disable():
|
||||
ret = _for_each_rank_run_func(func, lm.ranks, args, kwargs, alias=False)
|
||||
|
||||
lm = local_tensor_mode()
|
||||
assert lm is not None
|
||||
return ret
|
||||
|
||||
return wrapper
|
669
torch/distributed/_local_tensor/_c10d.py
Normal file
669
torch/distributed/_local_tensor/_c10d.py
Normal file
@ -0,0 +1,669 @@
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from torch._C import ScriptObject
|
||||
from torch._C._distributed_c10d import FakeWork
|
||||
from torch.distributed._mesh_layout import _MeshLayout
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_get_default_group,
|
||||
ProcessGroup,
|
||||
ReduceOp,
|
||||
Work,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: Most of the c10d collectives often take a Tensor[] (or Tensor[][])
|
||||
# when you would expect Tensor (or Tensor[]). In fact, there will only ever
|
||||
# be one Tensor in this case; the old signature was to support dispatching a
|
||||
# collective on multiple devices (ala DataParallel) but we don't support that
|
||||
# API anymore. Note that we are not 100% consistent about this; some more
|
||||
# modern collectives like _allgather_base_ got rid of the unnecessary list.
|
||||
# When in doubt, consult the code that dispatches to the collective on the PG
|
||||
# in distributed_c10d.py e.g., work = group.allgather([tensor_list], [tensor],
|
||||
# opts) indicates its always a list.
|
||||
|
||||
|
||||
def _gcd_list(numbers: Sequence[int]) -> int:
|
||||
return 0 if not numbers else functools.reduce(math.gcd, numbers)
|
||||
|
||||
|
||||
def _indices_to_layout(indices: list[int]) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
# Base case: A single index represents a point, not a dimension.
|
||||
if len(indices) <= 1:
|
||||
return (), ()
|
||||
|
||||
# The smallest stride is likely the GCD of the differences between consecutive indices.
|
||||
# For a sorted, unique list, all differences will be positive.
|
||||
diffs = [indices[i] - indices[i - 1] for i in range(1, len(indices))]
|
||||
last_stride = _gcd_list(diffs)
|
||||
|
||||
assert last_stride != 0, (
|
||||
# This case should not be reached if indices are unique and sorted.
|
||||
"Cannot determine stride; indices may not be unique."
|
||||
)
|
||||
|
||||
# Identify the starting index of each "row" in the last dimension.
|
||||
# An index starts a new row if the preceding index (index - stride) is not present.
|
||||
indices_set = set(indices)
|
||||
higher_dim_indices = [indices[0]]
|
||||
for index in indices[1:]:
|
||||
if (index - last_stride) not in indices_set:
|
||||
higher_dim_indices.append(index)
|
||||
|
||||
# From the number of rows, we can deduce the shape of the last dimension.
|
||||
assert len(indices) % len(higher_dim_indices) == 0, (
|
||||
"Indices do not form a regular grid. "
|
||||
f"Found {len(higher_dim_indices)} subgroups for {len(indices)} total elements."
|
||||
)
|
||||
last_shape = len(indices) // len(higher_dim_indices)
|
||||
|
||||
# Recurse on the higher-dimensional indices (the start of each row).
|
||||
higher_shapes, higher_strides = _indices_to_layout(higher_dim_indices)
|
||||
|
||||
# Combine the results from the recursion with the current dimension's results.
|
||||
final_shapes = higher_shapes + (last_shape,)
|
||||
final_strides = higher_strides + (last_stride,)
|
||||
|
||||
return final_shapes, final_strides
|
||||
|
||||
|
||||
def _prepare_collective_groups(
|
||||
process_group_so: ScriptObject,
|
||||
) -> tuple[list[int], list[int], int]:
|
||||
process_group = ProcessGroup.unbox(process_group_so)
|
||||
|
||||
ranks = torch.distributed.get_process_group_ranks(process_group)
|
||||
assert ranks
|
||||
# TODO: We can handle permutations but the layout inference algorithm will
|
||||
# lose the permutation so we will have to reapply it
|
||||
assert ranks == sorted(ranks), ranks
|
||||
offset = ranks[0]
|
||||
ranks = [r - offset for r in ranks]
|
||||
|
||||
shape, strides = _indices_to_layout(ranks)
|
||||
layout = _MeshLayout(shape, strides)
|
||||
|
||||
global_pg = _get_default_group()
|
||||
group_offsets = layout.complement(global_pg.size()).all_ranks_from_zero()
|
||||
|
||||
return ranks, group_offsets, offset
|
||||
|
||||
|
||||
def _local_broadcast_(
|
||||
tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
root_rank: int,
|
||||
root_tensor: int,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> tuple[list[torch.Tensor], ScriptObject]:
|
||||
# "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"
|
||||
from . import LocalTensor
|
||||
|
||||
assert len(tensors) == 1
|
||||
assert root_tensor == 0
|
||||
tensor = tensors[0]
|
||||
|
||||
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
# We're going to assume SPMD where for every rank group the root_rank is
|
||||
# the same relative to others
|
||||
relative_root_rank = root_rank - offset
|
||||
|
||||
assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the broadcast on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
source_rank = group_offset + relative_root_rank
|
||||
source_tensor = tensor._local_tensors[source_rank]
|
||||
|
||||
# Broadcast the source tensor to all ranks in this group
|
||||
for rank in group_ranks:
|
||||
if source_rank != rank:
|
||||
tensor._local_tensors[rank].copy_(source_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return (tensors, work_so)
|
||||
|
||||
|
||||
def _local_reduce(
|
||||
reduce_op: ReduceOp,
|
||||
tensors: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if reduce_op == ReduceOp.SUM:
|
||||
op = operator.add
|
||||
elif reduce_op == ReduceOp.AVG:
|
||||
op = None
|
||||
elif reduce_op == ReduceOp.PRODUCT:
|
||||
op = operator.mul
|
||||
elif reduce_op == ReduceOp.MIN:
|
||||
op = torch.minimum
|
||||
elif reduce_op == ReduceOp.MAX:
|
||||
op = torch.maximum
|
||||
elif reduce_op == ReduceOp.BAND:
|
||||
op = torch.bitwise_and
|
||||
elif reduce_op == ReduceOp.BOR:
|
||||
op = torch.bitwise_or
|
||||
elif reduce_op == ReduceOp.BXOR:
|
||||
op = torch.bitwise_xor
|
||||
elif reduce_op == ReduceOp.PREMUL_SUM:
|
||||
raise NotImplementedError("PREMUL_SUM: need to add binding for scaling factor")
|
||||
else:
|
||||
raise NotImplementedError(f"ReduceOp {reduce_op} not implemented")
|
||||
|
||||
if reduce_op == ReduceOp.AVG:
|
||||
return functools.reduce(operator.add, tensors) / len(tensors)
|
||||
else:
|
||||
assert op is not None
|
||||
return functools.reduce(op, tensors)
|
||||
|
||||
|
||||
def _local_all_reduce_(
|
||||
tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
reduce_op_so: ScriptObject,
|
||||
sparse_indices: torch.Tensor | None = None,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> tuple[list[torch.Tensor], ScriptObject]:
|
||||
# "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "__torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, "
|
||||
# "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
from . import LocalTensor
|
||||
|
||||
assert len(tensors) == 1
|
||||
tensor = tensors[0]
|
||||
reduce_op = reduce_op_so.op() # type: ignore[attr-defined]
|
||||
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the allreduce on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# Collect tensors from the specified ranks in this group
|
||||
group_tensors = []
|
||||
for rank in group_ranks:
|
||||
group_tensors.append(tensor._local_tensors[rank])
|
||||
|
||||
# Perform the reduction operation
|
||||
reduced_tensor = _local_reduce(reduce_op, group_tensors)
|
||||
|
||||
# Update all tensors in the group with the reduced result
|
||||
for rank in group_ranks:
|
||||
tensor._local_tensors[rank].copy_(reduced_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return (tensors, work_so)
|
||||
|
||||
|
||||
def _local_allreduce_coalesced_(
|
||||
tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
reduce_op_so: ScriptObject,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> ScriptObject:
|
||||
# "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
|
||||
from . import LocalTensor
|
||||
|
||||
reduce_op = reduce_op_so.op() # type: ignore[attr-defined]
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the allreduce on all tensors together
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# For each tensor, perform the reduction operation
|
||||
for tensor in tensors:
|
||||
assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor"
|
||||
# Collect tensors from the specified ranks in this group
|
||||
group_tensors = []
|
||||
for rank in group_ranks:
|
||||
group_tensors.append(tensor._local_tensors[rank])
|
||||
|
||||
# Perform the reduction operation
|
||||
reduced_tensor = _local_reduce(reduce_op, group_tensors)
|
||||
|
||||
# Update all tensors in the group with the reduced result
|
||||
for rank in group_ranks:
|
||||
tensor._local_tensors[rank].copy_(reduced_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return work_so
|
||||
|
||||
|
||||
def _local_reduce_scatter_tensor_coalesced_(
|
||||
output_tensors: list[torch.Tensor],
|
||||
input_tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
reduce_op_so: ScriptObject,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> ScriptObject:
|
||||
# "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
|
||||
# "__torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "__torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, "
|
||||
# "int timeout=-1) -> __torch__.torch.classes.c10d.Work"
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
reduce_op = reduce_op_so.op() # type: ignore[attr-defined]
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the allreduce on all tensors together
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# For each tensor, perform the reduction operation
|
||||
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
|
||||
assert isinstance(input_tensor, LocalTensor), (
|
||||
"Input tensor must be a LocalTensor"
|
||||
)
|
||||
assert isinstance(output_tensor, LocalTensor), (
|
||||
"Output tensor must be a LocalTensor"
|
||||
)
|
||||
# Collect tensors from the specified ranks in this group
|
||||
group_inputs = []
|
||||
for rank in group_ranks:
|
||||
group_inputs.append(input_tensor._local_tensors[rank])
|
||||
|
||||
# Perform the reduction operation
|
||||
reduced_input = _local_reduce(reduce_op, group_inputs)
|
||||
|
||||
reduced_inpit_splits = torch.split(
|
||||
reduced_input, reduced_input.size(0) // len(group_ranks), dim=0
|
||||
)
|
||||
|
||||
# Update all tensors in the group with the reduced result
|
||||
for rank in group_ranks:
|
||||
output_tensor._local_tensors[rank].copy_(reduced_inpit_splits[rank])
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return work_so
|
||||
|
||||
|
||||
def _local_all_gather_(
|
||||
output_tensors: list[list[torch.Tensor]],
|
||||
input_tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> tuple[list[list[torch.Tensor]], ScriptObject]:
|
||||
# "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
|
||||
# "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
|
||||
# "int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
assert len(output_tensors) == 1
|
||||
assert len(input_tensors) == 1
|
||||
|
||||
input_tensor = input_tensors[0]
|
||||
output_tensors = output_tensors[0]
|
||||
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
|
||||
for i in range(len(output_tensors)):
|
||||
assert isinstance(output_tensors[i], LocalTensor), (
|
||||
"Output tensor must be a LocalTensor"
|
||||
)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the all_gather on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# For each rank in the group, gather from their input tensor
|
||||
for i, rank_i in enumerate(group_ranks):
|
||||
output_tensors[i].copy_(input_tensor._local_tensors[rank_i])
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return ([output_tensors], work_so)
|
||||
|
||||
|
||||
def _local_allgather_into_tensor_coalesced_(
|
||||
output_tensors: list[torch.Tensor],
|
||||
input_tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
async_op: bool = True,
|
||||
) -> ScriptObject:
|
||||
# "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, "
|
||||
# "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) "
|
||||
# "-> __torch__.torch.classes.c10d.Work"
|
||||
from . import LocalTensor
|
||||
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
# Each output tensor should be sized to hold all gathered inputs
|
||||
# outputs[i] will contain all inputs[i] from all ranks
|
||||
assert len(output_tensors) == len(input_tensors), (
|
||||
f"Number of outputs ({len(output_tensors)}) must match number of inputs ({len(input_tensors)})"
|
||||
)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the allgather_into_tensor on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# For each input/output pair
|
||||
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
|
||||
assert isinstance(input_tensor, LocalTensor), (
|
||||
"Input tensor must be a LocalTensor"
|
||||
)
|
||||
assert isinstance(output_tensor, LocalTensor), (
|
||||
"Output tensor must be a LocalTensor"
|
||||
)
|
||||
# Gather input_tensor from all ranks into output_tensor
|
||||
# The output should be a concatenation of all inputs along the first dimension
|
||||
gathered_tensors = []
|
||||
for rank in group_ranks:
|
||||
gathered_tensors.append(input_tensor._local_tensors[rank])
|
||||
|
||||
# Concatenate along first dimension and copy to output
|
||||
if gathered_tensors:
|
||||
concatenated = torch.cat(gathered_tensors, dim=0)
|
||||
for rank in group_ranks:
|
||||
output_tensor._local_tensors[rank].copy_(concatenated)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return work_so
|
||||
|
||||
|
||||
def _local_gather_(
|
||||
output_tensors: list[list[torch.Tensor]],
|
||||
input_tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
root_rank: int,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> ScriptObject:
|
||||
# "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, "
|
||||
# "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
|
||||
# "bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"
|
||||
raise NotImplementedError(
|
||||
"LocalTensor does not support MPMD operations like gather "
|
||||
"(only root rank receives data). Use SPMD collective operations like allgather instead."
|
||||
)
|
||||
|
||||
|
||||
def _local_scatter_(
|
||||
output_tensors: list[torch.Tensor],
|
||||
input_tensors: list[list[torch.Tensor]],
|
||||
process_group_so: ScriptObject,
|
||||
root_rank: int,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> tuple[list[torch.Tensor], ScriptObject]:
|
||||
# "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, "
|
||||
# "__torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, "
|
||||
# "bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
assert len(output_tensors) == 1
|
||||
assert len(input_tensors) == 1
|
||||
output_tensor = output_tensors[0]
|
||||
input_tensors = input_tensors[0]
|
||||
|
||||
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
# We're going to assume SPMD where for every rank group the root_rank is
|
||||
# the same relative to others
|
||||
relative_root_rank = root_rank - offset
|
||||
|
||||
assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"
|
||||
assert len(ranks) == len(input_tensors), (ranks, input_tensors)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the scatter on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# Root rank scatters its input tensors to all ranks in this group
|
||||
for i, rank in enumerate(group_ranks):
|
||||
input_tensor = input_tensors[i]
|
||||
assert isinstance(input_tensor, LocalTensor)
|
||||
# Each rank i gets the i-th input tensor from the root
|
||||
source_tensor = input_tensor._local_tensors[
|
||||
group_offset + relative_root_rank
|
||||
]
|
||||
output_tensor._local_tensors[rank].copy_(source_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return (output_tensors, work_so)
|
||||
|
||||
|
||||
def _local_alltoall_(
|
||||
output_tensors: list[torch.Tensor],
|
||||
input_tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> tuple[list[torch.Tensor], ScriptObject]:
|
||||
# "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, "
|
||||
# "__torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, "
|
||||
# "int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)";
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
assert len(input_tensors) == len(output_tensors) == len(ranks), (
|
||||
f"Number of input tensors ({len(input_tensors)}), "
|
||||
f"output tensors ({len(output_tensors)}), and ranks ({len(ranks)}) must match"
|
||||
)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the alltoall on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
# In alltoall, rank i sends input_tensors[j] to rank j and receives into output_tensors[i] from rank j
|
||||
for i, rank_i in enumerate(group_ranks):
|
||||
output_tensor = output_tensors[i]
|
||||
assert isinstance(output_tensor, LocalTensor), (
|
||||
"Output tensor must be a LocalTensor"
|
||||
)
|
||||
for j, rank_j in enumerate(group_ranks):
|
||||
input_tensor = input_tensors[j]
|
||||
assert isinstance(input_tensor, LocalTensor), (
|
||||
"Input tensor must be a LocalTensor"
|
||||
)
|
||||
# Rank i's j-th input tensor goes to rank j's i-th output tensor
|
||||
source_tensor = input_tensor._local_tensors[rank_i]
|
||||
output_tensor._local_tensors[rank_j].copy_(source_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return (output_tensors, work_so)
|
||||
|
||||
|
||||
def _local_alltoall_base_(
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
process_group_so: ScriptObject,
|
||||
output_split_sizes: list[int],
|
||||
input_split_sizes: list[int],
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> ScriptObject:
|
||||
# "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
|
||||
|
||||
assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor"
|
||||
assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor"
|
||||
# Convert split sizes to lists if they aren't already
|
||||
if output_split_sizes is not None:
|
||||
output_split_sizes = list(output_split_sizes)
|
||||
if input_split_sizes is not None:
|
||||
input_split_sizes = list(input_split_sizes)
|
||||
|
||||
for group_offset in group_offsets:
|
||||
# For the tensors in this group [group_offset + r for r in ranks]
|
||||
# perform the alltoall_base on them
|
||||
group_ranks = [group_offset + r for r in ranks]
|
||||
|
||||
for i, rank_i in enumerate(group_ranks):
|
||||
# Split input tensor from rank_i according to input_split_sizes
|
||||
rank_tensor = input_tensor._local_tensors[rank_i]
|
||||
|
||||
if input_split_sizes is not None and len(input_split_sizes) > 0:
|
||||
# Split the input tensor
|
||||
input_splits = torch.split(rank_tensor, input_split_sizes, dim=0)
|
||||
else:
|
||||
# No split sizes specified, split evenly
|
||||
split_size = rank_tensor.size(0) // len(group_ranks)
|
||||
input_splits = torch.split(rank_tensor, split_size, dim=0)
|
||||
|
||||
# Send each split to the corresponding rank
|
||||
for j, rank_j in enumerate(group_ranks):
|
||||
if j < len(input_splits):
|
||||
split_tensor = input_splits[j]
|
||||
|
||||
# Determine where to place this split in the output tensor
|
||||
if output_split_sizes is not None and len(output_split_sizes) > 0:
|
||||
# Calculate offset based on output split sizes
|
||||
output_offset = sum(output_split_sizes[:i]) if i > 0 else 0
|
||||
end_offset = (
|
||||
output_offset + output_split_sizes[i]
|
||||
if i < len(output_split_sizes)
|
||||
else output_tensor._local_tensors[rank_j].size(0)
|
||||
)
|
||||
else:
|
||||
# No output split sizes, use even splits
|
||||
split_size = output_tensor._local_tensors[rank_j].size(
|
||||
0
|
||||
) // len(group_ranks)
|
||||
output_offset = i * split_size
|
||||
end_offset = min(
|
||||
(i + 1) * split_size,
|
||||
output_tensor._local_tensors[rank_j].size(0),
|
||||
)
|
||||
|
||||
# Copy the split to the appropriate section of the output tensor
|
||||
output_section = output_tensor._local_tensors[rank_j][
|
||||
output_offset:end_offset
|
||||
]
|
||||
if output_section.numel() > 0:
|
||||
# Reshape split_tensor to match output_section if necessary
|
||||
if split_tensor.size() != output_section.size():
|
||||
split_tensor = split_tensor.view(output_section.size())
|
||||
output_section.copy_(split_tensor)
|
||||
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return work_so
|
||||
|
||||
|
||||
def _local_barrier(
|
||||
tensor: torch.Tensor,
|
||||
process_group_so: ScriptObject,
|
||||
device_ids: list[int],
|
||||
async_op: bool = True,
|
||||
timeout: int = -1,
|
||||
) -> ScriptObject:
|
||||
# "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work";
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
# Barrier is a synchronization primitive - in local simulation,
|
||||
# we don't need to do any actual work since all "ranks" are in the same process
|
||||
# Just validate that the tensor is a LocalTensor
|
||||
assert isinstance(tensor, LocalTensor)
|
||||
|
||||
# In a real distributed setting, barrier would synchronize all processes
|
||||
# In local simulation, this is essentially a no-op since all ranks are local
|
||||
work = FakeWork()
|
||||
work_so = Work.boxed(work)
|
||||
return work_so
|
||||
|
||||
|
||||
def _local_monitored_barrier_(
|
||||
tensor: torch.Tensor,
|
||||
process_group_so: ScriptObject,
|
||||
device_ids: list[int],
|
||||
timeout: int,
|
||||
wait_all_ranks: bool,
|
||||
) -> None:
|
||||
# "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int[] device_ids, int timeout, bool wait_all_ranks) -> ()";
|
||||
|
||||
from . import LocalTensor
|
||||
|
||||
# Monitored barrier is a synchronization primitive with monitoring - in local simulation,
|
||||
# we don't need to do any actual work since all "ranks" are in the same process
|
||||
# Just validate that the tensor is a LocalTensor
|
||||
assert isinstance(tensor, LocalTensor)
|
||||
|
||||
# In a real distributed setting, monitored barrier would synchronize all processes
|
||||
# and provide monitoring capabilities. In local simulation, this is essentially a no-op
|
||||
# since all ranks are local and no actual synchronization is needed
|
||||
return
|
||||
|
||||
|
||||
def _local_send(
|
||||
tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
dst: int,
|
||||
tag: int,
|
||||
) -> ScriptObject:
|
||||
# "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int dst, int tag) -> __torch__.torch.classes.c10d.Work";
|
||||
|
||||
raise NotImplementedError(
|
||||
"LocalTensor does not support MPMD operations like send. "
|
||||
"Use SPMD collective operations instead."
|
||||
)
|
||||
|
||||
|
||||
def _local_recv_(
|
||||
tensors: list[torch.Tensor],
|
||||
process_group_so: ScriptObject,
|
||||
src: int,
|
||||
tag: int,
|
||||
) -> ScriptObject:
|
||||
# "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int src, int tag) -> __torch__.torch.classes.c10d.Work";
|
||||
|
||||
raise NotImplementedError(
|
||||
"LocalTensor does not support MPMD operations like recv. "
|
||||
"Use SPMD collective operations instead."
|
||||
)
|
||||
|
||||
|
||||
def _local_recv_any_source_(
|
||||
tensors: list[torch.Tensor], process_group_so: ScriptObject, tag: int
|
||||
) -> ScriptObject:
|
||||
# "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
# "int tag) -> __torch__.torch.classes.c10d.Work";
|
||||
|
||||
raise NotImplementedError(
|
||||
"LocalTensor does not support MPMD operations like recv_any_source. "
|
||||
"Use SPMD collective operations instead."
|
||||
)
|
@ -10,6 +10,7 @@ import torch.distributed._functional_collectives as funcol
|
||||
import torch.distributed.tensor._dtensor_spec as dtensor_spec
|
||||
from torch._C._distributed_c10d import _resolve_process_group
|
||||
from torch._logging import warning_once
|
||||
from torch.distributed._local_tensor import local_tensor_mode
|
||||
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_get_group_size_by_name,
|
||||
@ -40,7 +41,7 @@ def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
|
||||
|
||||
|
||||
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
|
||||
if mesh.device_type == "cpu":
|
||||
if mesh.device_type == "cpu" and local_tensor_mode() is None:
|
||||
# Gloo does not support alltoall, so falling back to allgather + chunk
|
||||
warning_once(
|
||||
logger,
|
||||
|
@ -165,7 +165,7 @@ class OpDispatcher:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Sharding propagation failed for {op_info.schema}"
|
||||
f"{e}\n\nSharding propagation failed for {op_info.schema}"
|
||||
) from e
|
||||
|
||||
output_sharding = op_info.output_sharding
|
||||
|
@ -319,6 +319,10 @@ LINEAR_REDUCTION_OP_MAP = {
|
||||
aten.all.dim: "sum",
|
||||
aten.sum.default: "sum",
|
||||
aten.sum.dim_IntList: "sum",
|
||||
aten.any.default: "sum",
|
||||
aten.any.dim: "sum",
|
||||
aten.any.out: "sum",
|
||||
# These are only valid when there is no padding
|
||||
aten.prod.default: "product",
|
||||
aten.prod.dim_int: "product",
|
||||
aten.prod.int_out: "product",
|
||||
@ -332,9 +336,6 @@ LINEAR_REDUCTION_OP_MAP = {
|
||||
aten.min.default: "min",
|
||||
aten.min.dim: "min",
|
||||
aten.min.out: "min",
|
||||
aten.any.default: "sum",
|
||||
aten.any.dim: "sum",
|
||||
aten.any.out: "sum",
|
||||
aten.amax.default: "max",
|
||||
aten.amax.out: "max",
|
||||
aten.amin.default: "min",
|
||||
|
@ -383,6 +383,7 @@ def redistribute_local_tensor(
|
||||
raise RuntimeError(
|
||||
f"redistribute from {current} to {target} not supported yet"
|
||||
)
|
||||
|
||||
elif target.is_shard():
|
||||
# Case 2: target is Shard
|
||||
target_placement = cast(Shard, target)
|
||||
|
@ -149,7 +149,7 @@ def _compute_local_shape_and_global_offset(
|
||||
ordered_placements = _explicit_order_placements(mesh_shape, placements)
|
||||
|
||||
local_shape = list(global_shape)
|
||||
# We'll compute the data for where the shard beings on a per-dim basis.
|
||||
# We'll compute the data for where the shard begins on a per-dim basis.
|
||||
# However, a single dim can be sharded multiple times, so we will end up
|
||||
# doing a Sum(size*stride) like computation to determine the location of our
|
||||
# shard for each of the shardings on that dim.
|
||||
@ -170,6 +170,14 @@ def _compute_local_shape_and_global_offset(
|
||||
|
||||
local_shape[shard_dim] = shard_size
|
||||
|
||||
shard_global_offset = global_offset[shard_dim] + not_none(shard_offset)
|
||||
|
||||
zero_global_offset = global_shape[shard_dim]
|
||||
if isinstance(shard_global_offset, torch.SymInt) and not isinstance(
|
||||
zero_global_offset, torch.SymInt
|
||||
):
|
||||
zero_global_offset = torch.SymInt(zero_global_offset)
|
||||
|
||||
global_offset[shard_dim] = torch.sym_ite(
|
||||
shard_size == 0,
|
||||
# Special case to fill in a standardized non-garbage value for
|
||||
@ -179,11 +187,11 @@ def _compute_local_shape_and_global_offset(
|
||||
# Note that you can end up with zero-size shards that are
|
||||
# still otherwise in bounds for the tensor (TODO: give an
|
||||
# example).
|
||||
global_shape[shard_dim],
|
||||
zero_global_offset,
|
||||
# As we successively shard the same dimension, we keep
|
||||
# advancing our pointer beyond our original offset until we
|
||||
# get to the final chunk start.
|
||||
global_offset[shard_dim] + not_none(shard_offset),
|
||||
shard_global_offset,
|
||||
)
|
||||
|
||||
# NOTE: the offset compute relies on the local shard index and it has no
|
||||
|
@ -6,6 +6,7 @@ from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
from torch.distributed._local_tensor import maybe_run_for_local_tensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor._collective_utils import (
|
||||
fill_empty_tensor_to_shards,
|
||||
@ -128,6 +129,7 @@ class Shard(Placement):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def local_shard_size_and_offset(
|
||||
curr_local_size: int,
|
||||
num_chunks: int,
|
||||
@ -170,6 +172,20 @@ class Shard(Placement):
|
||||
) -> tuple[int, Optional[int]]:
|
||||
return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank)
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _maybe_unpad_tensor_with_sizes(
|
||||
dim, local_tensor, pad_sizes, mesh_dim_local_rank, make_contiguous
|
||||
) -> torch.Tensor:
|
||||
# Only unpad if the local_tensor was padded on the dimension.
|
||||
if pad_sizes[mesh_dim_local_rank] > 0:
|
||||
local_tensor = unpad_tensor(
|
||||
local_tensor, dim, pad_sizes[mesh_dim_local_rank]
|
||||
)
|
||||
if make_contiguous:
|
||||
local_tensor = local_tensor.contiguous()
|
||||
return local_tensor
|
||||
|
||||
@staticmethod
|
||||
def _make_shard_tensor(
|
||||
dim: int,
|
||||
@ -198,24 +214,28 @@ class Shard(Placement):
|
||||
dim, tensor, num_chunks, with_padding=False, contiguous=True
|
||||
)
|
||||
|
||||
return scatter_list[mesh_dim_local_rank]
|
||||
return Shard._select_shard(scatter_list, mesh_dim_local_rank)
|
||||
|
||||
scatter_list, pad_sizes = Shard._make_split_tensor(
|
||||
dim, tensor, num_chunks, with_padding=True, contiguous=True
|
||||
)
|
||||
output = torch.empty_like(scatter_list[mesh_dim_local_rank])
|
||||
|
||||
it = iter(scatter_list)
|
||||
first = next(it)
|
||||
# Tensors in the scatter list are expected to have the same shape because
|
||||
# split is requested with padding.
|
||||
assert all(first.shape == v.shape for v in it)
|
||||
|
||||
output = torch.empty_like(first)
|
||||
|
||||
# perform scatter from the src_data_rank as data source when it is not None
|
||||
mesh_scatter(
|
||||
output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank
|
||||
)
|
||||
|
||||
# Only unpad if the local_tensor was padded on the dimension.
|
||||
if pad_sizes[mesh_dim_local_rank] > 0:
|
||||
output = unpad_tensor(output, dim, pad_sizes[mesh_dim_local_rank])
|
||||
# Unpad might return a view, hence we need to remake it contiguous
|
||||
output = output.contiguous()
|
||||
return output
|
||||
return Shard._maybe_unpad_tensor_with_sizes(
|
||||
dim, output, pad_sizes, mesh_dim_local_rank, True
|
||||
)
|
||||
|
||||
def _shard_tensor(
|
||||
self,
|
||||
@ -245,6 +265,7 @@ class Shard(Placement):
|
||||
return tensor
|
||||
|
||||
is_padded = tensor.size(self.dim) % num_chunks != 0
|
||||
pad_sizes = None
|
||||
if is_padded:
|
||||
scattered_list, pad_sizes = Shard._make_split_tensor(
|
||||
self.dim, tensor, num_chunks, with_padding=True, contiguous=True
|
||||
@ -258,9 +279,47 @@ class Shard(Placement):
|
||||
)
|
||||
|
||||
if is_padded:
|
||||
output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined]
|
||||
assert pad_sizes is not None
|
||||
output = Shard._maybe_unpad_tensor_with_sizes(
|
||||
self.dim, output, pad_sizes, my_coordinate[mesh_dim], False
|
||||
)
|
||||
return output
|
||||
|
||||
@maybe_run_for_local_tensor
|
||||
def _maybe_pad_tensor(
|
||||
self,
|
||||
local_tensor: torch.Tensor,
|
||||
logical_dim_size: int,
|
||||
num_chunks: int,
|
||||
) -> torch.Tensor:
|
||||
is_padded = logical_dim_size % num_chunks != 0
|
||||
|
||||
if is_padded:
|
||||
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
|
||||
pad_size = full_chunk_size - local_tensor.size(self.dim)
|
||||
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
|
||||
|
||||
if not local_tensor.is_contiguous():
|
||||
local_tensor = local_tensor.contiguous()
|
||||
|
||||
return local_tensor
|
||||
|
||||
@maybe_run_for_local_tensor
|
||||
def _maybe_unpad_tensor(
|
||||
self,
|
||||
local_tensor: torch.Tensor,
|
||||
logical_dim_size: int,
|
||||
num_chunks: int,
|
||||
) -> torch.Tensor:
|
||||
is_padded = logical_dim_size % num_chunks != 0
|
||||
|
||||
if is_padded:
|
||||
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
|
||||
unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
|
||||
local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size)
|
||||
|
||||
return local_tensor
|
||||
|
||||
def _to_replicate_tensor(
|
||||
self,
|
||||
local_tensor: torch.Tensor,
|
||||
@ -273,28 +332,27 @@ class Shard(Placement):
|
||||
is replicated on the previously sharded mesh dimension
|
||||
"""
|
||||
num_chunks = mesh.size(mesh_dim=mesh_dim)
|
||||
|
||||
logical_dim_size = current_logical_shape[self.dim]
|
||||
is_padded = logical_dim_size % num_chunks != 0
|
||||
|
||||
if is_padded:
|
||||
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
|
||||
pad_size = full_chunk_size - local_tensor.size(self.dim)
|
||||
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
|
||||
|
||||
if not local_tensor.is_contiguous():
|
||||
local_tensor = local_tensor.contiguous()
|
||||
local_tensor = self._maybe_pad_tensor(
|
||||
local_tensor, logical_dim_size, num_chunks
|
||||
)
|
||||
|
||||
result = funcol.all_gather_tensor(
|
||||
local_tensor,
|
||||
gather_dim=self.dim,
|
||||
group=(mesh, mesh_dim),
|
||||
)
|
||||
if is_padded:
|
||||
unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
|
||||
result = unpad_tensor(result, self.dim, unpad_size)
|
||||
|
||||
result = self._maybe_unpad_tensor(result, logical_dim_size, num_chunks)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@maybe_run_for_local_tensor
|
||||
def _select_shard(shards: list[torch.Tensor], shard_index) -> torch.Tensor:
|
||||
return shards[shard_index].clone()
|
||||
|
||||
def _replicate_to_shard(
|
||||
self,
|
||||
local_tensor: torch.Tensor,
|
||||
@ -313,7 +371,8 @@ class Shard(Placement):
|
||||
with_padding=False,
|
||||
contiguous=False,
|
||||
)
|
||||
return shards[shard_index].clone()
|
||||
|
||||
return Shard._select_shard(shards, shard_index)
|
||||
|
||||
def _to_new_shard_dim(
|
||||
self,
|
||||
|
@ -41,6 +41,9 @@ class ConstantIntNode:
|
||||
def _graph_repr(self) -> str:
|
||||
return self._str()
|
||||
|
||||
def add(self, other: Any) -> Any:
|
||||
return other.add(self)
|
||||
|
||||
def mul(self, other: Any) -> Any:
|
||||
return other.mul(self)
|
||||
|
||||
|
@ -13,6 +13,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._local_tensor import LocalTensor
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
@ -660,7 +661,7 @@ class DTensorConverter:
|
||||
def to_dist_tensor(
|
||||
self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
|
||||
) -> torch.Tensor:
|
||||
if type(t) is torch.Tensor or type(t) is nn.Parameter:
|
||||
if type(t) is torch.Tensor or type(t) is nn.Parameter or type(t) is LocalTensor:
|
||||
if self.is_supported_tensor(t):
|
||||
self.hit += 1
|
||||
if t.ndim == 0:
|
||||
@ -669,7 +670,7 @@ class DTensorConverter:
|
||||
else:
|
||||
# distribute non-scalar tensors
|
||||
r = distribute_tensor(t, mesh, placements)
|
||||
if type(t) is nn.Parameter:
|
||||
if isinstance(t, nn.Parameter):
|
||||
r = nn.Parameter( # type: ignore[assignment]
|
||||
r, requires_grad=r.requires_grad
|
||||
)
|
||||
|
@ -1292,7 +1292,7 @@ SAC_IGNORED_OPS = {
|
||||
# With subclasses involved, these metadata ops become dispatchable, this
|
||||
# can result in incorrectness if these ops are selected cached.
|
||||
torch.ops.prim.device.default,
|
||||
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
|
||||
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) # type: ignore[has-type]
|
||||
|
||||
|
||||
class _CachingTorchDispatchMode(TorchDispatchMode):
|
||||
|
Reference in New Issue
Block a user