Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
C++-accessible Placements via pybind11 (#163030)
This makes Placement data representation available in C++ via pybind11.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 349e9e922d
Commit: 3e03deab6f
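Not part of the diff below: a minimal usage sketch of what this change enables, assuming a build that includes the new torch._C._distributed module. The public placement classes keep their existing Python API, so the checks mirror the new test file:

# Hedged sketch; names come from the public placement_types API exercised by the test below.
from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

shard = Shard(0)           # shard along tensor dim 0
partial = Partial("sum")   # pending "sum" reduction
replica = Replicate()

assert shard.is_shard() and shard.is_shard(0)
assert partial.is_partial() and partial.is_partial("sum")
assert replica.is_replicate()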
@@ -913,6 +913,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/python_torch_functions_manual.cpp",
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
+    "torch/csrc/distributed/python_placement.cpp",
     "torch/csrc/dynamo/python_compiled_autograd.cpp",
     "torch/csrc/dynamo/cache_entry.cpp",
     "torch/csrc/dynamo/cpp_shim.cpp",
test/distributed/tensor/test_placement_types.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# Owner(s): ["oncall: distributed"]

import copy
import itertools
import sys
import unittest

from torch._dynamo.variables.distributed import PlacementClassVariable
from torch.distributed.tensor.placement_types import (
    _StridedShard,
    Partial,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# Basic functionality test for Placement types.
class PlacementTypesTestCase(TestCase):
    def test_type_identification(self):
        shard = Shard(3)
        strided_shard = _StridedShard(dim=3, split_factor=7)
        partial_sum = Partial("sum")
        partial_max = Partial("max")
        replicate = Replicate()

        ident_tests = (
            (shard, True, False, False),
            (strided_shard, True, False, False),
            (partial_sum, False, True, False),
            (partial_max, False, True, False),
            (replicate, False, False, True),
        )
        for do_deepcopy in (False, True):
            for placement, is_shard, is_partial, is_replicate in ident_tests:
                if do_deepcopy:
                    placement = copy.deepcopy(placement)
                self.assertEqual(placement.is_shard(), is_shard)
                self.assertEqual(placement.is_partial(), is_partial)
                self.assertEqual(placement.is_replicate(), is_replicate)

    def test_equality(self):
        equivalence_classes = (
            (Shard(3), _StridedShard(dim=3, split_factor=7)),
            (Shard(4), _StridedShard(dim=4, split_factor=9)),
            (Replicate(),),
            (Partial("sum"),),
            (Partial("max"),),
        )
        for eq_class in equivalence_classes:
            # Each item in the equivalence class should be equal to every other item in
            # its class.
            for lhs, rhs in itertools.product(eq_class, eq_class):
                self.assertEqual(lhs, rhs)

            # Each item in the equivalence class should not be equal to any item in any
            # other class.
            for other_class in equivalence_classes:
                if other_class is eq_class:
                    continue
                for lhs, rhs in itertools.product(eq_class, other_class):
                    self.assertNotEqual(lhs, rhs)

        # Testing this case doesn't seem to fit neatly into the above equivalence class
        # framework.
        self.assertNotEqual(
            _StridedShard(dim=3, split_factor=1), _StridedShard(dim=3, split_factor=2)
        )

    @unittest.skipIf(
        sys.version_info < (3, 10), "kw_only is only available in python >= 3.10"
    )
    def test_strided_shard_kwonly_argument(self):
        with self.assertRaises(TypeError):
            _StridedShard(3, 4)
        _StridedShard(3, split_factor=4)

    def test_strided_shard_isinstance_shard(self):
        assert isinstance(_StridedShard(dim=3, split_factor=7), Shard)

    def test_dynamo_can_identify_placement_classes(self):
        for cls in (Replicate, Shard, _StridedShard, Partial):
            self.assertTrue(
                PlacementClassVariable.is_placement_type(cls), msg=f"failed on {cls}"
            )


if __name__ == "__main__":
    run_tests()
torch/_C/_distributed.pyi (new file, 20 lines)
@@ -0,0 +1,20 @@
# This module is defined in torch/csrc/distributed/python_placement.cpp

class Placement:
    def is_partial(self, reduce_op: str | None = None) -> bool: ...
    def is_replicate(self) -> bool: ...
    def is_shard(self, dim: int | None = None) -> bool: ...

class Shard(Placement):
    dim: int
    def __init__(self, dim: int): ...

class StridedShard(Shard):
    split_factor: int
    def __init__(self, dim: int, *, split_factor: int): ...

class Replicate(Placement): ...

class Partial(Placement):
    reduce_op: str
    def __init__(self, reduce_op: str | None = None): ...
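As a rough illustration of the stub above (an assumption based on the bindings later in this commit, not code from the PR): StridedShard is a Shard, split_factor is keyword-only, and Partial defaults its reduce_op to "sum".

# Hedged sketch against torch._C._distributed; mirrors the .pyi signatures above.
from torch._C._distributed import Partial, Shard, StridedShard

ss = StridedShard(1, split_factor=2)   # positional split_factor would raise TypeError
assert isinstance(ss, Shard)           # StridedShard subclasses Shard
assert Partial().reduce_op == "sum"    # default reduce op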
@@ -142,7 +142,7 @@ class PlacementClassVariable(DistributedVariable):
 
         from torch.distributed.tensor.placement_types import Placement
 
-        return type(value) is type and issubclass(value, Placement)
+        return isinstance(value, type) and issubclass(value, Placement)
 
     def as_python_constant(self):
         return self.value
@@ -153,13 +153,10 @@ class PlacementClassVariable(DistributedVariable):
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
-        if (
-            inspect.getattr_static(self.value, "__new__", None) == object.__new__
-            and self.source
-        ):
+        if self.source:
             # NOTE: we don't need to track mutations to the placement class as they
-            # suppose to be immutable.
-            new_obj = object.__new__(self.value)
+            # are supposed to be immutable.
+            new_obj = self.value.__new__(self.value)
             var = PlacementVariable(new_obj)
             if inspect.getattr_static(self.value, "__init__", None):
                 var.call_method(tx, "__init__", args, kwargs)
@@ -71,6 +71,7 @@
 #include <torch/csrc/autograd/python_special_functions.h>
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/cpu/Module.h>
+#include <torch/csrc/distributed/python_placement.h>
 #include <torch/csrc/dynamo/init.h>
 #include <torch/csrc/export/pybind.h>
 #include <torch/csrc/functionalization/Module.h>
@@ -2119,6 +2120,8 @@ PyObject* initModule() {
   THXPEvent_init(module);
 #endif
 
+  torch::distributed::initPlacementBindings(module);
+
   auto set_module_attr =
       [&](const char* name, PyObject* v, bool incref = true) {
         // PyModule_AddObject steals reference
torch/csrc/distributed/Placement.h (new file, 121 lines)
@@ -0,0 +1,121 @@
#pragma once
/**
 * The implementations in this file are coupled with
 * torch/distributed/tensor/placement_types.py.
 */

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>

namespace torch::distributed {

class Placement {
 public:
  Placement() = default;
  virtual ~Placement() = default;

  Placement(const Placement&) = default;
  Placement& operator=(const Placement&) = default;
  Placement(Placement&&) noexcept = default;
  Placement& operator=(Placement&&) noexcept = default;

  virtual bool is_shard(std::optional<std::int64_t> dim) const {
    return false;
  }

  virtual bool is_replicate() const {
    return false;
  }

  virtual bool is_partial(
      std::optional<std::string_view> reduce_op = std::nullopt) const {
    return false;
  }
};

class Shard : public Placement {
 public:
  std::int64_t dim;
  explicit Shard(std::int64_t dim_) : dim(dim_) {}

  bool is_shard(std::optional<std::int64_t> dim_) const override {
    return !dim_.has_value() || *dim_ == dim;
  }

  bool operator==(const Shard& rhs) const {
    return dim == rhs.dim;
  }

  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};

class StridedShard : public Shard {
 public:
  std::int64_t split_factor;
  explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
      : Shard(dim), split_factor(split_factor_) {}

  bool operator==(const StridedShard& rhs) const {
    return dim == rhs.dim && split_factor == rhs.split_factor;
  }

  bool operator==(const Shard& rhs) const {
    if (auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
      return operator==(*rhs_strided);
    }
    // TODO: this is to avoid extra all-gather in dtensor op dispatch
    // note that sharding prop would not produce _StridedShard and a
    // placement inequality would introduce an all-gather for resharding
    return dim == rhs.dim;
  }

  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};

class Replicate : public Placement {
 public:
  bool is_replicate() const override {
    return true;
  }

  bool operator==(const Replicate& rhs) const {
    return true;
  }

  bool operator!=(const Replicate& rhs) const {
    return false;
  }
};

class Partial : public Placement {
 public:
  std::string reduce_op;

  Partial() : Partial("sum") {}

  explicit Partial(std::optional<std::string> reduce_op_)
      : reduce_op(
            reduce_op_.has_value() ? std::move(*reduce_op_)
                                   : std::string("sum")) {}

  bool is_partial(
      std::optional<std::string_view> op = std::nullopt) const override {
    return !op.has_value() || *op == reduce_op;
  }

  bool operator==(const Partial& rhs) const {
    return reduce_op == rhs.reduce_op;
  }

  bool operator!=(const Partial& rhs) const {
    return !operator==(rhs);
  }
};

} // namespace torch::distributed
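A short note on the header above: is_shard and is_partial take an optional argument and only compare it when one is provided. A hedged Python-level sketch of the same semantics, using the public classes that defer to these C++ implementations:

# Hedged sketch of the optional-argument checks defined in Placement.h.
from torch.distributed.tensor.placement_types import Partial, Shard

assert Shard(2).is_shard()                   # no dim given: any shard matches
assert Shard(2).is_shard(2) and not Shard(2).is_shard(3)
assert Partial("max").is_partial()           # no reduce_op given
assert not Partial("max").is_partial("sum")  # reduce_op mismatch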
torch/csrc/distributed/python_placement.cpp (new file, 104 lines)
@@ -0,0 +1,104 @@
#include <torch/csrc/distributed/python_placement.h>

#include <pybind11/pybind11.h>
#include <torch/csrc/distributed/Placement.h>
#include <torch/csrc/utils/pybind.h>

using namespace pybind11::literals;

namespace torch::distributed {
namespace {
const auto placement_class_docstring =
    R"(The base class for the Placement type, where it describes how a DTensor is placed onto the
``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
and ``Partial``.

This class is not meant to be used directly, mainly served as a typing stub.
)";
} // namespace

void initPlacementBindings(PyObject* module) {
  auto py_module = py::reinterpret_borrow<py::module>(module);
  auto distributed_module = py_module.def_submodule("_distributed");
  py::class_<Placement>(
      distributed_module, "Placement", placement_class_docstring)
      .def(py::init<>()) // Allow construction of Python subclasses.
      .def(
          "is_partial",
          &Placement::is_partial,
          py::arg("reduce_op") = py::none())
      .def("is_replicate", &Placement::is_replicate)
      .def("is_shard", &Placement::is_shard, py::arg("dim") = py::none());
  py::class_<Shard, Placement>(distributed_module, "Shard")
      .def(py::init<int64_t>(), py::arg("dim"))
      .def_readonly("dim", &Shard::dim)
      .def("is_shard", &Shard::is_shard, py::arg("dim") = py::none())
      .def(
          "__eq__",
          [](const Shard& lhs, const Shard& rhs) { return lhs == rhs; },
          py::is_operator())
      // Note: we need to use dicts for pickling to match the old
      // dataclasses.
      .def(py::pickle(
          [](const Shard& shard) { return py::dict("dim"_a = shard.dim); },
          [](const py::dict& d) {
            return Shard(py::cast<int64_t>(d["dim"]));
          }));
  py::class_<StridedShard, Shard>(distributed_module, "StridedShard")
      .def(
          py::init<int64_t, int64_t>(),
          py::arg("dim"),
          py::kw_only(),
          py::arg("split_factor"))
      .def_readonly("split_factor", &StridedShard::split_factor)
      .def("is_shard", &StridedShard::is_shard, py::arg("dim") = py::none())
      .def(
          "__eq__",
          [](const StridedShard& lhs, const Shard& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          [](const StridedShard& shard) {
            return py::dict(
                "dim"_a = shard.dim, "split_factor"_a = shard.split_factor);
          },
          [](const py::dict& d) {
            return StridedShard(
                py::cast<int64_t>(d["dim"]),
                py::cast<int64_t>(d["split_factor"]));
          }));
  py::class_<Replicate, Placement>(distributed_module, "Replicate")
      .def(py::init())
      .def("is_replicate", &Replicate::is_replicate)
      .def(
          "__eq__",
          [](const Replicate& lhs, const Replicate& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          // I observed SIGSEGV when trying to use None as the
          // pickled state, though AFAICT that matches the
          // behavior of object().__reduce__().
          // test_placement_types.test_type_identification will repro if an
          // enterprising reader wants to get this fixed.
          [](const Replicate& repl) { return py::dict(); },
          [](const py::dict&) { return Replicate(); }));
  py::class_<Partial, Placement>(distributed_module, "Partial")
      .def(py::init<>())
      .def(py::init<std::optional<std::string>>(), py::arg("reduce_op"))
      .def_readonly("reduce_op", &Partial::reduce_op)
      .def(
          "is_partial", &Partial::is_partial, py::arg("reduce_op") = py::none())
      .def(
          "__eq__",
          [](const Partial& lhs, const Partial& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          [](const Partial& part) {
            return py::dict("reduce_op"_a = part.reduce_op);
          },
          [](const py::dict& d) {
            return Partial(py::cast<std::string>(d["reduce_op"]));
          }));
}
} // namespace torch::distributed
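Because each placement pickles to a dict of its fields (kept dict-shaped to match the old dataclasses), deepcopy and pickle round-trips are expected to preserve equality. A hedged sketch, assuming the bindings above are built in:

# Hedged round-trip sketch relying on the py::pickle definitions above.
import copy
import pickle

from torch.distributed.tensor.placement_types import Partial, Shard

p = Partial("max")
assert pickle.loads(pickle.dumps(p)) == p
assert copy.deepcopy(Shard(3)) == Shard(3)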
torch/csrc/distributed/python_placement.h (new file, 7 lines)
@@ -0,0 +1,7 @@
#pragma once

#include <torch/csrc/utils/python_stub.h>

namespace torch::distributed {
void initPlacementBindings(PyObject* module);
} // namespace torch::distributed
@@ -83,6 +83,22 @@ class _MaskPartial(Partial):
     offset_shape: Optional[torch.Size] = None
     offset_dim: int = 0
 
+    def __init__(
+        self,
+        reduce_op=None,
+        mask_buffer=None,
+        offset_shape=None,
+        offset_dim=0,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(reduce_op)
+        if mask_buffer is None:
+            mask_buffer = MaskBuffer()
+        object.__setattr__(self, "mask_buffer", mask_buffer)
+        object.__setattr__(self, "offset_shape", offset_shape)
+        object.__setattr__(self, "offset_dim", offset_dim)
+
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
@@ -73,17 +73,19 @@ class _NormPartial(Partial):
 
     norm_type: Union[int, float, str] = 2
 
-    def __post_init__(self):
-        """Set the appropriate reduce op based on the norm type."""
-        # Use `object.__setattr__` to bypass frozen checks
-        if self.norm_type in (float("inf"), "inf"):
-            object.__setattr__(self, "reduce_op", "max")
-        elif self.norm_type in (float("-inf"), "-inf"):
-            object.__setattr__(self, "reduce_op", "min")
-        elif isinstance(self.norm_type, (int, float)):
-            object.__setattr__(self, "reduce_op", "sum")
+    def __init__(self, norm_type: Union[int, float, str] = 2):
+        reduce_op = None
+        if norm_type in (float("inf"), "inf"):
+            reduce_op = "max"
+        elif norm_type in (float("-inf"), "-inf"):
+            reduce_op = "min"
+        elif isinstance(norm_type, (int, float)):
+            reduce_op = "sum"
         else:
-            raise NotImplementedError(f"Unsupported norm type: {self.norm_type}")
+            raise NotImplementedError(f"Unsupported norm type: {norm_type}")
+
+        super().__init__(reduce_op)
+        object.__setattr__(self, "norm_type", norm_type)
 
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
-from dataclasses import dataclass
 from typing import cast, Optional
 
 import torch
+import torch._C
 import torch.distributed._functional_collectives as funcol
+from torch._C._distributed import Placement
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._collective_utils import (
     fill_empty_tensor_to_shards,
@@ -20,29 +21,11 @@ from torch.distributed.tensor._collective_utils import (
 __all__ = ["Placement", "Shard", "Replicate", "Partial"]
 
 
-class Placement:
-    """
-    The base class for the Placement type, where it describes how a DTensor is placed onto the
-    ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
-    It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
-    and ``Partial``.
-
-    This class is not meant to be used directly, mainly served as a typing stub.
-    """
-
-    # convenient utils to check for placement types
-    def is_shard(self, dim: Optional[int] = None) -> bool:
-        return False
-
-    def is_replicate(self) -> bool:
-        return False
-
-    def is_partial(self, reduce_op: Optional[str] = None) -> bool:
-        return False
+# Appease TestPublicBindings.test_correct_module_names
+Placement.__module__ = "torch.distributed.tensor.placement_types"
 
 
-@dataclass(frozen=True)
-class Shard(Placement):
+class Shard(torch._C._distributed.Shard):
     """
     The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
     ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
@@ -60,14 +43,6 @@ class Shard(Placement):
     evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
     """
 
-    dim: int
-
-    def is_shard(self, dim: Optional[int] = None) -> bool:
-        if dim is not None:
-            return self.dim == dim
-        else:
-            return True
-
     def _split_tensor(
         self,
         tensor: torch.Tensor,
@@ -349,11 +324,6 @@
 
         return new_tensor
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, Shard):
-            return False
-        return self.dim == other.dim
-
     def __hash__(self) -> int:
         return hash(self.dim)
 
@@ -368,8 +338,8 @@
         return f"S({self.dim})"
 
 
-@dataclass(frozen=True, kw_only=True)
-class _StridedShard(Shard):
+# Need to inherit from Shard here so that isinstance(some_strided_shard, Shard) will work.
+class _StridedShard(torch._C._distributed.StridedShard, Shard):
     """
     _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor
     is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
@@ -427,18 +397,6 @@
     TODO: we should remove _StridedShard placement once we can unify it with Shard
     """
 
-    split_factor: int
-
-    def __eq__(self, other: object) -> bool:
-        if isinstance(other, _StridedShard):
-            return self.dim == other.dim and self.split_factor == other.split_factor
-        elif isinstance(other, Shard):
-            # TODO: this is to avoid extra all-gather in dtensor op dispatch
-            # note that sharding prop would not produce _StridedShard and an
-            # placement inequality would introduce an all-gather for resharding
-            return self.dim == other.dim
-        return False
-
     def __hash__(self) -> int:
         return hash((self.dim, self.split_factor))
 
@@ -585,8 +543,7 @@
         return local_shard_size, None
 
 
-@dataclass(frozen=True)
-class Replicate(Placement):
+class Replicate(torch._C._distributed.Replicate):
     """
     The ``Replicate()`` placement describes the DTensor replicating on a corresponding
     ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a
@@ -594,9 +551,6 @@ class Replicate(Placement):
     DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.)
     """
 
-    def __eq__(self, other: object) -> bool:
-        return isinstance(other, Replicate)
-
     def __hash__(self) -> int:
         # every replicate placement is the same
        return -1
@@ -636,12 +590,8 @@
         mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank)
         return tensor
 
-    def is_replicate(self) -> bool:
-        return True
-
 
-@dataclass(frozen=True)
-class Partial(Placement):
+class Partial(torch._C._distributed.Partial):
     """
     The ``Partial(reduce_op)`` placement describes the DTensor that is pending
     reduction on a specified ``DeviceMesh`` dimension, where each rank on the
@@ -660,8 +610,6 @@
     and can only be used by the ``DTensor.from_local`` API.
     """
 
-    reduce_op: str = "sum"
-
     def _reduce_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
@@ -698,11 +646,6 @@
         num_chunks = mesh.size(mesh_dim=mesh_dim)
         return tensor / num_chunks
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, Partial):
-            return False
-        return self.reduce_op == other.reduce_op
-
     def __hash__(self) -> int:
         return 1 + hash(self.reduce_op)
 
@@ -718,11 +661,6 @@
         """
         return "P"
 
-    def is_partial(self, reduce_op: Optional[str] = None) -> bool:
-        if reduce_op is None:
-            return True
-        return self.reduce_op == reduce_op
-
 
 # We keep the old _Partial name for a while for BC reason
 _Partial = Partial
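One behavioral detail carried over from the removed Python __eq__ methods into the C++ operator== (see the TODO in Placement.h): a _StridedShard compares equal to a plain Shard on the same dim so that op dispatch avoids an extra all-gather, while differing split_factor values still compare unequal. A hedged sketch of what the new test asserts:

# Hedged illustration of the Shard/_StridedShard equality rule.
from torch.distributed.tensor.placement_types import _StridedShard, Shard

assert _StridedShard(dim=0, split_factor=4) == Shard(0)
assert _StridedShard(dim=0, split_factor=4) != _StridedShard(dim=0, split_factor=2)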