Revert "C++-accessible Placements via pybind11 (#163030)"

This reverts commit 3e03deab6f3c268c85c8efd9546e28cdda0fa4cc.

Reverted https://github.com/pytorch/pytorch/pull/163030 on behalf of https://github.com/swolchok because it doesn't pass pyre ([comment](https://github.com/pytorch/pytorch/pull/163030#issuecomment-3362450379))
PyTorch MergeBot
2025-10-02 18:25:24 +00:00
parent e6d4b26776
commit f6f7676756
11 changed files with 88 additions and 385 deletions

View File

@@ -913,7 +913,6 @@ libtorch_python_core_sources = [
"torch/csrc/autograd/python_torch_functions_manual.cpp",
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_variable_indexing.cpp",
"torch/csrc/distributed/python_placement.cpp",
"torch/csrc/dynamo/python_compiled_autograd.cpp",
"torch/csrc/dynamo/cache_entry.cpp",
"torch/csrc/dynamo/cpp_shim.cpp",

View File

@@ -1,88 +0,0 @@
# Owner(s): ["oncall: distributed"]
import copy
import itertools
import sys
import unittest
from torch._dynamo.variables.distributed import PlacementClassVariable
from torch.distributed.tensor.placement_types import (
_StridedShard,
Partial,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase
# Basic functionality test for Placement types.
class PlacementTypesTestCase(TestCase):
def test_type_identification(self):
shard = Shard(3)
strided_shard = _StridedShard(dim=3, split_factor=7)
partial_sum = Partial("sum")
partial_max = Partial("max")
replicate = Replicate()
ident_tests = (
(shard, True, False, False),
(strided_shard, True, False, False),
(partial_sum, False, True, False),
(partial_max, False, True, False),
(replicate, False, False, True),
)
for do_deepcopy in (False, True):
for placement, is_shard, is_partial, is_replicate in ident_tests:
if do_deepcopy:
placement = copy.deepcopy(placement)
self.assertEqual(placement.is_shard(), is_shard)
self.assertEqual(placement.is_partial(), is_partial)
self.assertEqual(placement.is_replicate(), is_replicate)
def test_equality(self):
equivalence_classes = (
(Shard(3), _StridedShard(dim=3, split_factor=7)),
(Shard(4), _StridedShard(dim=4, split_factor=9)),
(Replicate(),),
(Partial("sum"),),
(Partial("max"),),
)
for eq_class in equivalence_classes:
# Each item in the equivalence class should be equal to every other item in
# its class.
for lhs, rhs in itertools.product(eq_class, eq_class):
self.assertEqual(lhs, rhs)
# Each item in the equivalence class should not be equal to any item in any
# other class.
for other_class in equivalence_classes:
if other_class is eq_class:
continue
for lhs, rhs in itertools.product(eq_class, other_class):
self.assertNotEqual(lhs, rhs)
# Testing this case doesn't seem to fit neatly into the above equivalence class
# framework.
self.assertNotEqual(
_StridedShard(dim=3, split_factor=1), _StridedShard(dim=3, split_factor=2)
)
@unittest.skipIf(
sys.version_info < (3, 10), "kw_only is only available in python >= 3.10"
)
def test_strided_shard_kwonly_argument(self):
with self.assertRaises(TypeError):
_StridedShard(3, 4)
_StridedShard(3, split_factor=4)
def test_strided_shard_isinstance_shard(self):
assert isinstance(_StridedShard(dim=3, split_factor=7), Shard)
def test_dynamo_can_identify_placement_classes(self):
for cls in (Replicate, Shard, _StridedShard, Partial):
self.assertTrue(
PlacementClassVariable.is_placement_type(cls), msg=f"failed on {cls}"
)
if __name__ == "__main__":
run_tests()

View File

@@ -1,20 +0,0 @@
# This module is defined in torch/csrc/distributed/python_placement.cpp
class Placement:
def is_partial(self, reduce_op: str | None = None) -> bool: ...
def is_replicate(self) -> bool: ...
def is_shard(self, dim: int | None = None) -> bool: ...
class Shard(Placement):
dim: int
def __init__(self, dim: int): ...
class StridedShard(Shard):
split_factor: int
def __init__(self, dim: int, *, split_factor: int): ...
class Replicate(Placement): ...
class Partial(Placement):
reduce_op: str
def __init__(self, reduce_op: str | None = None): ...
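
For orientation, the stubbed classes are re-exported through `torch.distributed.tensor.placement_types`, so the reverted C++ implementation and the restored Python one share the same surface. A minimal usage sketch (assuming that public module, as in the deleted test above):

from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

# Placement subclasses report their kind via is_shard / is_partial / is_replicate.
shard, partial, replicate = Shard(0), Partial("max"), Replicate()
assert shard.is_shard() and shard.is_shard(dim=0) and not shard.is_shard(dim=1)
assert partial.is_partial() and partial.is_partial("max") and not partial.is_partial("sum")
assert replicate.is_replicate() and not replicate.is_shard()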

View File

@@ -142,7 +142,7 @@ class PlacementClassVariable(DistributedVariable):
from torch.distributed.tensor.placement_types import Placement
return isinstance(value, type) and issubclass(value, Placement)
return type(value) is type and issubclass(value, Placement)
def as_python_constant(self):
return self.value
@@ -153,10 +153,13 @@ class PlacementClassVariable(DistributedVariable):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if self.source:
if (
inspect.getattr_static(self.value, "__new__", None) == object.__new__
and self.source
):
# NOTE: we don't need to track mutations to the placement class as they
# are supposed to be immutable.
new_obj = self.value.__new__(self.value)
# suppose to be immutable.
new_obj = object.__new__(self.value)
var = PlacementVariable(new_obj)
if inspect.getattr_static(self.value, "__init__", None):
var.call_method(tx, "__init__", args, kwargs)
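
Context for the hunk above: the PR's check `isinstance(value, type)` also accepts classes built by a custom metaclass (as pybind11-generated classes typically are), while the restored `type(value) is type` only accepts plain Python classes. A stdlib-only sketch of that distinction, using `abc.ABCMeta` as a stand-in metaclass:

import abc

class PlainPlacement:              # ordinary class, its metaclass is `type`
    pass

class MetaPlacement(abc.ABC):      # class produced by a custom metaclass (ABCMeta)
    pass

assert isinstance(PlainPlacement, type) and type(PlainPlacement) is type
assert isinstance(MetaPlacement, type)     # still a class ...
assert type(MetaPlacement) is not type     # ... but `type(value) is type` rejects it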

View File

@@ -71,7 +71,6 @@
#include <torch/csrc/autograd/python_special_functions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/distributed/python_placement.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/export/pybind.h>
#include <torch/csrc/functionalization/Module.h>
@@ -2112,8 +2111,6 @@ PyObject* initModule() {
THXPEvent_init(module);
#endif
torch::distributed::initPlacementBindings(module);
auto set_module_attr =
[&](const char* name, PyObject* v, bool incref = true) {
// PyModule_AddObject steals reference

View File

@@ -1,121 +0,0 @@
#pragma once
/**
* The implementations in this file are coupled with
* torch/distributed/tensor/placement_types.py.
*/
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
namespace torch::distributed {
class Placement {
public:
Placement() = default;
virtual ~Placement() = default;
Placement(const Placement&) = default;
Placement& operator=(const Placement&) = default;
Placement(Placement&&) noexcept = default;
Placement& operator=(Placement&&) noexcept = default;
virtual bool is_shard(std::optional<std::int64_t> dim) const {
return false;
}
virtual bool is_replicate() const {
return false;
}
virtual bool is_partial(
std::optional<std::string_view> reduce_op = std::nullopt) const {
return false;
}
};
class Shard : public Placement {
public:
std::int64_t dim;
explicit Shard(std::int64_t dim_) : dim(dim_) {}
bool is_shard(std::optional<std::int64_t> dim_) const override {
return !dim_.has_value() || *dim_ == dim;
}
bool operator==(const Shard& rhs) const {
return dim == rhs.dim;
}
bool operator!=(const Shard& rhs) const {
return !operator==(rhs);
}
};
class StridedShard : public Shard {
public:
std::int64_t split_factor;
explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
: Shard(dim), split_factor(split_factor_) {}
bool operator==(const StridedShard& rhs) const {
return dim == rhs.dim && split_factor == rhs.split_factor;
}
bool operator==(const Shard& rhs) const {
if (auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
return operator==(*rhs_strided);
}
// TODO: this is to avoid extra all-gather in dtensor op dispatch
// note that sharding prop would not produce _StridedShard and a
// placement inequality would introduce an all-gather for resharding
return dim == rhs.dim;
}
bool operator!=(const Shard& rhs) const {
return !operator==(rhs);
}
};
class Replicate : public Placement {
public:
bool is_replicate() const override {
return true;
}
bool operator==(const Replicate& rhs) const {
return true;
}
bool operator!=(const Replicate& rhs) const {
return false;
}
};
class Partial : public Placement {
public:
std::string reduce_op;
Partial() : Partial("sum") {}
explicit Partial(std::optional<std::string> reduce_op_)
: reduce_op(
reduce_op_.has_value() ? std::move(*reduce_op_)
: std::string("sum")) {}
bool is_partial(
std::optional<std::string_view> op = std::nullopt) const override {
return !op.has_value() || *op == reduce_op;
}
bool operator==(const Partial& rhs) const {
return reduce_op == rhs.reduce_op;
}
bool operator!=(const Partial& rhs) const {
return !operator==(rhs);
}
};
} // namespace torch::distributed
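
The asymmetric equality encoded here (a `StridedShard` compares equal to a plain `Shard` on the same dim, ignoring `split_factor`) mirrors the Python `_StridedShard.__eq__` restored later in this commit. A minimal sketch of the intended behavior, written against the public Python classes and matching the deleted test above:

from torch.distributed.tensor.placement_types import _StridedShard, Shard

# split_factor is ignored when comparing against a plain Shard, so sharding
# propagation does not see a spurious inequality (which would force an all-gather).
assert _StridedShard(dim=3, split_factor=7) == Shard(3)
assert _StridedShard(dim=3, split_factor=7) != Shard(4)
# Two _StridedShards must agree on both dim and split_factor.
assert _StridedShard(dim=3, split_factor=1) != _StridedShard(dim=3, split_factor=2)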

View File

@@ -1,104 +0,0 @@
#include <torch/csrc/distributed/python_placement.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/distributed/Placement.h>
#include <torch/csrc/utils/pybind.h>
using namespace pybind11::literals;
namespace torch::distributed {
namespace {
const auto placement_class_docstring =
R"(The base class for the Placement type, where it describes how a DTensor is placed onto the
``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
and ``Partial``.
This class is not meant to be used directly, mainly served as a typing stub.
)";
} // namespace
void initPlacementBindings(PyObject* module) {
auto py_module = py::reinterpret_borrow<py::module>(module);
auto distributed_module = py_module.def_submodule("_distributed");
py::class_<Placement>(
distributed_module, "Placement", placement_class_docstring)
.def(py::init<>()) // Allow construction of Python subclasses.
.def(
"is_partial",
&Placement::is_partial,
py::arg("reduce_op") = py::none())
.def("is_replicate", &Placement::is_replicate)
.def("is_shard", &Placement::is_shard, py::arg("dim") = py::none());
py::class_<Shard, Placement>(distributed_module, "Shard")
.def(py::init<int64_t>(), py::arg("dim"))
.def_readonly("dim", &Shard::dim)
.def("is_shard", &Shard::is_shard, py::arg("dim") = py::none())
.def(
"__eq__",
[](const Shard& lhs, const Shard& rhs) { return lhs == rhs; },
py::is_operator())
// Note: we need to use dicts for pickling to match the old
// dataclasses.
.def(py::pickle(
[](const Shard& shard) { return py::dict("dim"_a = shard.dim); },
[](const py::dict& d) {
return Shard(py::cast<int64_t>(d["dim"]));
}));
py::class_<StridedShard, Shard>(distributed_module, "StridedShard")
.def(
py::init<int64_t, int64_t>(),
py::arg("dim"),
py::kw_only(),
py::arg("split_factor"))
.def_readonly("split_factor", &StridedShard::split_factor)
.def("is_shard", &StridedShard::is_shard, py::arg("dim") = py::none())
.def(
"__eq__",
[](const StridedShard& lhs, const Shard& rhs) { return lhs == rhs; },
py::is_operator())
.def(py::pickle(
[](const StridedShard& shard) {
return py::dict(
"dim"_a = shard.dim, "split_factor"_a = shard.split_factor);
},
[](const py::dict& d) {
return StridedShard(
py::cast<int64_t>(d["dim"]),
py::cast<int64_t>(d["split_factor"]));
}));
py::class_<Replicate, Placement>(distributed_module, "Replicate")
.def(py::init())
.def("is_replicate", &Replicate::is_replicate)
.def(
"__eq__",
[](const Replicate& lhs, const Replicate& rhs) { return lhs == rhs; },
py::is_operator())
.def(py::pickle(
// I observed SIGSEGV when trying to use None as the
// pickled state, though AFAICT that matches the
// behavior of
// object().__reduce__().
// test_placement_types.test_type_identification will repro if an
// enterprising reader wants to get this fixed.
[](const Replicate& repl) { return py::dict(); },
[](const py::dict&) { return Replicate(); }));
py::class_<Partial, Placement>(distributed_module, "Partial")
.def(py::init<>())
.def(py::init<std::optional<std::string>>(), py::arg("reduce_op"))
.def_readonly("reduce_op", &Partial::reduce_op)
.def(
"is_partial", &Partial::is_partial, py::arg("reduce_op") = py::none())
.def(
"__eq__",
[](const Partial& lhs, const Partial& rhs) { return lhs == rhs; },
py::is_operator())
.def(py::pickle(
[](const Partial& part) {
return py::dict("reduce_op"_a = part.reduce_op);
},
[](const py::dict& d) {
return Partial(py::cast<std::string>(d["reduce_op"]));
}));
}
} // namespace torch::distributed
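
The `py::pickle` definitions above use dicts as the pickled state so the bindings serialize the same way the old dataclasses did. A minimal round-trip sketch (expected to hold for both implementations, using the public Python classes):

import copy
import pickle

from torch.distributed.tensor.placement_types import Partial, Shard

# Placements are expected to survive pickle and deepcopy with equality preserved.
shard = Shard(3)
assert pickle.loads(pickle.dumps(shard)) == shard
assert copy.deepcopy(Partial("max")) == Partial("max")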

View File

@@ -1,7 +0,0 @@
#pragma once
#include <torch/csrc/utils/python_stub.h>
namespace torch::distributed {
void initPlacementBindings(PyObject* module);
} // namespace torch::distributed

View File

@@ -83,22 +83,6 @@ class _MaskPartial(Partial):
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def __init__(
self,
reduce_op=None,
mask_buffer=None,
offset_shape=None,
offset_dim=0,
*args,
**kwargs,
):
super().__init__(reduce_op)
if mask_buffer is None:
mask_buffer = MaskBuffer()
object.__setattr__(self, "mask_buffer", mask_buffer)
object.__setattr__(self, "offset_shape", offset_shape)
object.__setattr__(self, "offset_dim", offset_dim)
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:

View File

@@ -73,19 +73,17 @@ class _NormPartial(Partial):
norm_type: Union[int, float, str] = 2
def __init__(self, norm_type: Union[int, float, str] = 2):
reduce_op = None
if norm_type in (float("inf"), "inf"):
reduce_op = "max"
elif norm_type in (float("-inf"), "-inf"):
reduce_op = "min"
elif isinstance(norm_type, (int, float)):
reduce_op = "sum"
def __post_init__(self):
"""Set the appropriate reduce op based on the norm type."""
# Use `object.__setattr__` to bypass frozen checks
if self.norm_type in (float("inf"), "inf"):
object.__setattr__(self, "reduce_op", "max")
elif self.norm_type in (float("-inf"), "-inf"):
object.__setattr__(self, "reduce_op", "min")
elif isinstance(self.norm_type, (int, float)):
object.__setattr__(self, "reduce_op", "sum")
else:
raise NotImplementedError(f"Unsupported norm type: {norm_type}")
super().__init__(reduce_op)
object.__setattr__(self, "norm_type", norm_type)
raise NotImplementedError(f"Unsupported norm type: {self.norm_type}")
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int

View File

@@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import cast, Optional
import torch
import torch._C
import torch.distributed._functional_collectives as funcol
from torch._C._distributed import Placement
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._collective_utils import (
fill_empty_tensor_to_shards,
@@ -21,11 +20,29 @@ from torch.distributed.tensor._collective_utils import (
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
# Appease TestPublicBindings.test_correct_module_names
Placement.__module__ = "torch.distributed.tensor.placement_types"
class Placement:
"""
The base class for the Placement type, where it describes how a DTensor is placed onto the
``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
and ``Partial``.
This class is not meant to be used directly, mainly served as a typing stub.
"""
# convenient utils to check for placement types
def is_shard(self, dim: Optional[int] = None) -> bool:
return False
def is_replicate(self) -> bool:
return False
def is_partial(self, reduce_op: Optional[str] = None) -> bool:
return False
class Shard(torch._C._distributed.Shard):
@dataclass(frozen=True)
class Shard(Placement):
"""
The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
@@ -43,6 +60,14 @@ class Shard(torch._C._distributed.Shard):
evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
"""
dim: int
def is_shard(self, dim: Optional[int] = None) -> bool:
if dim is not None:
return self.dim == dim
else:
return True
def _split_tensor(
self,
tensor: torch.Tensor,
@@ -324,6 +349,11 @@ class Shard(torch._C._distributed.Shard):
return new_tensor
def __eq__(self, other: object) -> bool:
if not isinstance(other, Shard):
return False
return self.dim == other.dim
def __hash__(self) -> int:
return hash(self.dim)
@@ -338,8 +368,8 @@ class Shard(torch._C._distributed.Shard):
return f"S({self.dim})"
# Need to inherit from Shard here so that isinstance(some_strided_shard, Shard) will work.
class _StridedShard(torch._C._distributed.StridedShard, Shard):
@dataclass(frozen=True, kw_only=True)
class _StridedShard(Shard):
"""
_StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor
is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
@@ -397,6 +427,18 @@ class _StridedShard(torch._C._distributed.StridedShard, Shard):
TODO: we should remove _StridedShard placement once we can unify it with Shard
"""
split_factor: int
def __eq__(self, other: object) -> bool:
if isinstance(other, _StridedShard):
return self.dim == other.dim and self.split_factor == other.split_factor
elif isinstance(other, Shard):
# TODO: this is to avoid extra all-gather in dtensor op dispatch
# note that sharding prop would not produce _StridedShard and an
# placement inequality would introduce an all-gather for resharding
return self.dim == other.dim
return False
def __hash__(self) -> int:
return hash((self.dim, self.split_factor))
@@ -543,7 +585,8 @@ class _StridedShard(torch._C._distributed.StridedShard, Shard):
return local_shard_size, None
class Replicate(torch._C._distributed.Replicate):
@dataclass(frozen=True)
class Replicate(Placement):
"""
The ``Replicate()`` placement describes the DTensor replicating on a corresponding
``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a
@@ -551,6 +594,9 @@ class Replicate(torch._C._distributed.Replicate):
DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.)
"""
def __eq__(self, other: object) -> bool:
return isinstance(other, Replicate)
def __hash__(self) -> int:
# every replicate placement is the same
return -1
@@ -590,8 +636,12 @@ class Replicate(torch._C._distributed.Replicate):
mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank)
return tensor
def is_replicate(self) -> bool:
return True
class Partial(torch._C._distributed.Partial):
@dataclass(frozen=True)
class Partial(Placement):
"""
The ``Partial(reduce_op)`` placement describes the DTensor that is pending
reduction on a specified ``DeviceMesh`` dimension, where each rank on the
@@ -610,6 +660,8 @@ class Partial(torch._C._distributed.Partial):
and can only be used by the ``DTensor.from_local`` API.
"""
reduce_op: str = "sum"
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
@@ -646,6 +698,11 @@ class Partial(torch._C._distributed.Partial):
num_chunks = mesh.size(mesh_dim=mesh_dim)
return tensor / num_chunks
def __eq__(self, other: object) -> bool:
if not isinstance(other, Partial):
return False
return self.reduce_op == other.reduce_op
def __hash__(self) -> int:
return 1 + hash(self.reduce_op)
@@ -661,6 +718,11 @@ class Partial(torch._C._distributed.Partial):
"""
return "P"
def is_partial(self, reduce_op: Optional[str] = None) -> bool:
if reduce_op is None:
return True
return self.reduce_op == reduce_op
# We keep the old _Partial name for a while for BC reason
_Partial = Partial
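
With the revert applied, the placements are frozen dataclasses again, so instances reject mutation and `_StridedShard` keeps its keyword-only `split_factor` (on Python >= 3.10, per `kw_only=True`), as the deleted test exercised. A minimal sketch:

import dataclasses

from torch.distributed.tensor.placement_types import _StridedShard, Shard

shard = Shard(3)
try:
    shard.dim = 4                 # frozen=True: attribute assignment is rejected
except dataclasses.FrozenInstanceError:
    pass

try:
    _StridedShard(3, 4)           # split_factor must be passed by keyword
except TypeError:
    pass
strided = _StridedShard(3, split_factor=4)
assert strided.split_factor == 4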