C++-accessible Placements via pybind11 (#163030)

This moves the Placement data representation into C++ and exposes it to Python via pybind11.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030
Approved by: https://github.com/ezyang
Author: Scott Wolchok
Authored: 2025-10-01 15:29:16 -07:00
Committed by: PyTorch MergeBot
Parent: 349e9e922d
Commit: 3e03deab6f
11 changed files with 385 additions and 88 deletions
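For orientation, a minimal sketch (not part of the diff) of how the C++-backed placements are expected to behave from Python after this change; the module paths follow the stubs and bindings added below, and the assertions mirror the new test file.

# Sketch only: exercising the new C++-backed placements from Python.
import torch
from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

shard = Shard(0)
assert shard.is_shard() and shard.is_shard(dim=0) and not shard.is_shard(dim=1)
assert Replicate().is_replicate()
assert Partial().is_partial() and Partial("max").is_partial("max")

# The Python classes now derive from the pybind11 types in torch._C._distributed.
assert isinstance(shard, torch._C._distributed.Shard)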


@ -913,6 +913,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/python_torch_functions_manual.cpp",
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
+    "torch/csrc/distributed/python_placement.cpp",
     "torch/csrc/dynamo/python_compiled_autograd.cpp",
     "torch/csrc/dynamo/cache_entry.cpp",
     "torch/csrc/dynamo/cpp_shim.cpp",


@ -0,0 +1,88 @@
# Owner(s): ["oncall: distributed"]
import copy
import itertools
import sys
import unittest
from torch._dynamo.variables.distributed import PlacementClassVariable
from torch.distributed.tensor.placement_types import (
_StridedShard,
Partial,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase
# Basic functionality test for Placement types.
class PlacementTypesTestCase(TestCase):
def test_type_identification(self):
shard = Shard(3)
strided_shard = _StridedShard(dim=3, split_factor=7)
partial_sum = Partial("sum")
partial_max = Partial("max")
replicate = Replicate()
ident_tests = (
(shard, True, False, False),
(strided_shard, True, False, False),
(partial_sum, False, True, False),
(partial_max, False, True, False),
(replicate, False, False, True),
)
for do_deepcopy in (False, True):
for placement, is_shard, is_partial, is_replicate in ident_tests:
if do_deepcopy:
placement = copy.deepcopy(placement)
self.assertEqual(placement.is_shard(), is_shard)
self.assertEqual(placement.is_partial(), is_partial)
self.assertEqual(placement.is_replicate(), is_replicate)
def test_equality(self):
equivalence_classes = (
(Shard(3), _StridedShard(dim=3, split_factor=7)),
(Shard(4), _StridedShard(dim=4, split_factor=9)),
(Replicate(),),
(Partial("sum"),),
(Partial("max"),),
)
for eq_class in equivalence_classes:
# Each item in the equivalence class should be equal to every other item in
# its class.
for lhs, rhs in itertools.product(eq_class, eq_class):
self.assertEqual(lhs, rhs)
# Each item in the equivalence class should not be equal to any item in any
# other class.
for other_class in equivalence_classes:
if other_class is eq_class:
continue
for lhs, rhs in itertools.product(eq_class, other_class):
self.assertNotEqual(lhs, rhs)
# Testing this case doesn't seem to fit neatly into the above equivalence class
# framework.
self.assertNotEqual(
_StridedShard(dim=3, split_factor=1), _StridedShard(dim=3, split_factor=2)
)
@unittest.skipIf(
sys.version_info < (3, 10), "kw_only is only available in python >= 3.10"
)
def test_strided_shard_kwonly_argument(self):
with self.assertRaises(TypeError):
_StridedShard(3, 4)
_StridedShard(3, split_factor=4)
def test_strided_shard_isinstance_shard(self):
assert isinstance(_StridedShard(dim=3, split_factor=7), Shard)
def test_dynamo_can_identify_placement_classes(self):
for cls in (Replicate, Shard, _StridedShard, Partial):
self.assertTrue(
PlacementClassVariable.is_placement_type(cls), msg=f"failed on {cls}"
)
if __name__ == "__main__":
run_tests()

torch/_C/_distributed.pyi (new file, 20 lines)

@ -0,0 +1,20 @@
# This module is defined in torch/csrc/distributed/python_placement.cpp
class Placement:
def is_partial(self, reduce_op: str | None = None) -> bool: ...
def is_replicate(self) -> bool: ...
def is_shard(self, dim: int | None = None) -> bool: ...
class Shard(Placement):
dim: int
def __init__(self, dim: int): ...
class StridedShard(Shard):
split_factor: int
def __init__(self, dim: int, *, split_factor: int): ...
class Replicate(Placement): ...
class Partial(Placement):
reduce_op: str
def __init__(self, reduce_op: str | None = None): ...


@ -142,7 +142,7 @@ class PlacementClassVariable(DistributedVariable):
         from torch.distributed.tensor.placement_types import Placement

-        return type(value) is type and issubclass(value, Placement)
+        return isinstance(value, type) and issubclass(value, Placement)

     def as_python_constant(self):
         return self.value
@ -153,13 +153,10 @@ class PlacementClassVariable(DistributedVariable):
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
-        if (
-            inspect.getattr_static(self.value, "__new__", None) == object.__new__
-            and self.source
-        ):
+        if self.source:
             # NOTE: we don't need to track mutations to the placement class as they
-            # suppose to be immutable.
-            new_obj = object.__new__(self.value)
+            # are supposed to be immutable.
+            new_obj = self.value.__new__(self.value)
             var = PlacementVariable(new_obj)
             if inspect.getattr_static(self.value, "__init__", None):
                 var.call_method(tx, "__init__", args, kwargs)


@ -71,6 +71,7 @@
 #include <torch/csrc/autograd/python_special_functions.h>
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/cpu/Module.h>
+#include <torch/csrc/distributed/python_placement.h>
 #include <torch/csrc/dynamo/init.h>
 #include <torch/csrc/export/pybind.h>
 #include <torch/csrc/functionalization/Module.h>
@ -2119,6 +2120,8 @@ PyObject* initModule() {
   THXPEvent_init(module);
 #endif

+  torch::distributed::initPlacementBindings(module);
+
   auto set_module_attr =
       [&](const char* name, PyObject* v, bool incref = true) {
         // PyModule_AddObject steals reference


@ -0,0 +1,121 @@
#pragma once
/**
* The implementations in this file are coupled with
* torch/distributed/tensor/placement_types.py.
*/
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
namespace torch::distributed {
class Placement {
public:
Placement() = default;
virtual ~Placement() = default;
Placement(const Placement&) = default;
Placement& operator=(const Placement&) = default;
Placement(Placement&&) noexcept = default;
Placement& operator=(Placement&&) noexcept = default;
virtual bool is_shard(std::optional<std::int64_t> dim) const {
return false;
}
virtual bool is_replicate() const {
return false;
}
virtual bool is_partial(
std::optional<std::string_view> reduce_op = std::nullopt) const {
return false;
}
};
class Shard : public Placement {
public:
std::int64_t dim;
explicit Shard(std::int64_t dim_) : dim(dim_) {}
bool is_shard(std::optional<std::int64_t> dim_) const override {
return !dim_.has_value() || *dim_ == dim;
}
bool operator==(const Shard& rhs) const {
return dim == rhs.dim;
}
bool operator!=(const Shard& rhs) const {
return !operator==(rhs);
}
};
class StridedShard : public Shard {
public:
std::int64_t split_factor;
explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
: Shard(dim), split_factor(split_factor_) {}
bool operator==(const StridedShard& rhs) const {
return dim == rhs.dim && split_factor == rhs.split_factor;
}
bool operator==(const Shard& rhs) const {
if (auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
return operator==(*rhs_strided);
}
// TODO: this is to avoid extra all-gather in dtensor op dispatch
// note that sharding prop would not produce _StridedShard and a
// placement inequality would introduce an all-gather for resharding
return dim == rhs.dim;
}
bool operator!=(const Shard& rhs) const {
return !operator==(rhs);
}
};
class Replicate : public Placement {
public:
bool is_replicate() const override {
return true;
}
bool operator==(const Replicate& rhs) const {
return true;
}
bool operator!=(const Replicate& rhs) const {
return false;
}
};
class Partial : public Placement {
public:
std::string reduce_op;
Partial() : Partial("sum") {}
explicit Partial(std::optional<std::string> reduce_op_)
: reduce_op(
reduce_op_.has_value() ? std::move(*reduce_op_)
: std::string("sum")) {}
bool is_partial(
std::optional<std::string_view> op = std::nullopt) const override {
return !op.has_value() || *op == reduce_op;
}
bool operator==(const Partial& rhs) const {
return reduce_op == rhs.reduce_op;
}
bool operator!=(const Partial& rhs) const {
return !operator==(rhs);
}
};
} // namespace torch::distributed
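The asymmetric equality above is intentional: a StridedShard compares equal to a plain Shard with the same dim (to avoid an extra all-gather during resharding), while two StridedShards also compare split_factor. A small sketch of the resulting Python-level behavior, assuming the bindings and subclasses added later in this diff:

# Sketch of the equality semantics implemented by Shard/StridedShard above.
from torch.distributed.tensor.placement_types import _StridedShard, Shard

assert _StridedShard(dim=3, split_factor=7) == Shard(3)   # dim-only comparison against a plain Shard
assert _StridedShard(dim=3, split_factor=7) != Shard(4)
assert _StridedShard(dim=3, split_factor=7) != _StridedShard(dim=3, split_factor=2)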


@ -0,0 +1,104 @@
#include <torch/csrc/distributed/python_placement.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/distributed/Placement.h>
#include <torch/csrc/utils/pybind.h>
using namespace pybind11::literals;
namespace torch::distributed {
namespace {
const auto placement_class_docstring =
R"(The base class for the Placement type, where it describes how a DTensor is placed onto the
``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
and ``Partial``.
This class is not meant to be used directly, mainly served as a typing stub.
)";
} // namespace
void initPlacementBindings(PyObject* module) {
auto py_module = py::reinterpret_borrow<py::module>(module);
auto distributed_module = py_module.def_submodule("_distributed");
py::class_<Placement>(
distributed_module, "Placement", placement_class_docstring)
.def(py::init<>()) // Allow construction of Python subclasses.
.def(
"is_partial",
&Placement::is_partial,
py::arg("reduce_op") = py::none())
.def("is_replicate", &Placement::is_replicate)
.def("is_shard", &Placement::is_shard, py::arg("dim") = py::none());
py::class_<Shard, Placement>(distributed_module, "Shard")
.def(py::init<int64_t>(), py::arg("dim"))
.def_readonly("dim", &Shard::dim)
.def("is_shard", &Shard::is_shard, py::arg("dim") = py::none())
.def(
"__eq__",
[](const Shard& lhs, const Shard& rhs) { return lhs == rhs; },
py::is_operator())
// Note: we need to use dicts for pickling to match the old
// dataclasses.
.def(py::pickle(
[](const Shard& shard) { return py::dict("dim"_a = shard.dim); },
[](const py::dict& d) {
return Shard(py::cast<int64_t>(d["dim"]));
}));
py::class_<StridedShard, Shard>(distributed_module, "StridedShard")
.def(
py::init<int64_t, int64_t>(),
py::arg("dim"),
py::kw_only(),
py::arg("split_factor"))
.def_readonly("split_factor", &StridedShard::split_factor)
.def("is_shard", &StridedShard::is_shard, py::arg("dim") = py::none())
.def(
"__eq__",
[](const StridedShard& lhs, const Shard& rhs) { return lhs == rhs; },
py::is_operator())
.def(py::pickle(
[](const StridedShard& shard) {
return py::dict(
"dim"_a = shard.dim, "split_factor"_a = shard.split_factor);
},
[](const py::dict& d) {
return StridedShard(
py::cast<int64_t>(d["dim"]),
py::cast<int64_t>(d["split_factor"]));
}));
py::class_<Replicate, Placement>(distributed_module, "Replicate")
.def(py::init())
.def("is_replicate", &Replicate::is_replicate)
.def(
"__eq__",
[](const Replicate& lhs, const Replicate& rhs) { return lhs == rhs; },
py::is_operator())
.def(py::pickle(
// I observed SIGSEGV when trying to use None as the
// pickled state, though AFAICT that matches the
// behavior of
// object().__reduce__().
// test_placement_types.test_type_identification will repro if an
// enterprising reader wants to get this fixed.
[](const Replicate& repl) { return py::dict(); },
[](const py::dict&) { return Replicate(); }));
py::class_<Partial, Placement>(distributed_module, "Partial")
.def(py::init<>())
.def(py::init<std::optional<std::string>>(), py::arg("reduce_op"))
.def_readonly("reduce_op", &Partial::reduce_op)
.def(
"is_partial", &Partial::is_partial, py::arg("reduce_op") = py::none())
.def(
"__eq__",
[](const Partial& lhs, const Partial& rhs) { return lhs == rhs; },
py::is_operator())
.def(py::pickle(
[](const Partial& part) {
return py::dict("reduce_op"_a = part.reduce_op);
},
[](const py::dict& d) {
return Partial(py::cast<std::string>(d["reduce_op"]));
}));
}
} // namespace torch::distributed
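The py::pickle definitions above serialize state as dicts so that pickling stays compatible with the placements' previous dataclass form. An illustrative round-trip (not part of the diff) of what this is meant to preserve:

# Illustrative pickle/deepcopy round-trip through the dict-based state above.
import copy
import pickle

from torch.distributed.tensor.placement_types import _StridedShard, Partial, Shard

for placement in (Shard(2), _StridedShard(dim=1, split_factor=4), Partial("max")):
    assert pickle.loads(pickle.dumps(placement)) == placement
    assert copy.deepcopy(placement) == placement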


@ -0,0 +1,7 @@
#pragma once
#include <torch/csrc/utils/python_stub.h>
namespace torch::distributed {
void initPlacementBindings(PyObject* module);
} // namespace torch::distributed


@ -83,6 +83,22 @@ class _MaskPartial(Partial):
     offset_shape: Optional[torch.Size] = None
     offset_dim: int = 0

+    def __init__(
+        self,
+        reduce_op=None,
+        mask_buffer=None,
+        offset_shape=None,
+        offset_dim=0,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(reduce_op)
+        if mask_buffer is None:
+            mask_buffer = MaskBuffer()
+        object.__setattr__(self, "mask_buffer", mask_buffer)
+        object.__setattr__(self, "offset_shape", offset_shape)
+        object.__setattr__(self, "offset_dim", offset_dim)
+
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
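Because the C++ Partial base now owns reduce_op, _MaskPartial gains an explicit __init__ that forwards reduce_op to super() and sets its own fields with object.__setattr__. A rough sketch of the intended construction behavior (the module path torch.distributed.tensor._ops._embedding_ops is assumed here):

# Sketch: constructing _MaskPartial after the change (module path assumed).
import torch
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

mp = _MaskPartial(offset_shape=torch.Size([16, 4]), offset_dim=0)
assert mp.reduce_op == "sum"                    # super().__init__(None) falls back to "sum"
assert mp.offset_shape == torch.Size([16, 4])   # set via object.__setattr__
assert mp.mask_buffer is not None               # each instance gets its own MaskBuffer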


@ -73,17 +73,19 @@ class _NormPartial(Partial):
     norm_type: Union[int, float, str] = 2

-    def __post_init__(self):
-        """Set the appropriate reduce op based on the norm type."""
-        # Use `object.__setattr__` to bypass frozen checks
-        if self.norm_type in (float("inf"), "inf"):
-            object.__setattr__(self, "reduce_op", "max")
-        elif self.norm_type in (float("-inf"), "-inf"):
-            object.__setattr__(self, "reduce_op", "min")
-        elif isinstance(self.norm_type, (int, float)):
-            object.__setattr__(self, "reduce_op", "sum")
+    def __init__(self, norm_type: Union[int, float, str] = 2):
+        reduce_op = None
+        if norm_type in (float("inf"), "inf"):
+            reduce_op = "max"
+        elif norm_type in (float("-inf"), "-inf"):
+            reduce_op = "min"
+        elif isinstance(norm_type, (int, float)):
+            reduce_op = "sum"
         else:
-            raise NotImplementedError(f"Unsupported norm type: {self.norm_type}")
+            raise NotImplementedError(f"Unsupported norm type: {norm_type}")
+        super().__init__(reduce_op)
+        object.__setattr__(self, "norm_type", norm_type)

     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
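The rewritten constructor preserves the old norm_type to reduce_op mapping while routing reduce_op through the C++ Partial base. Roughly (module path torch.distributed.tensor._ops._math_ops assumed):

# Sketch: the norm_type -> reduce_op mapping kept by the new __init__ (path assumed).
from torch.distributed.tensor._ops._math_ops import _NormPartial

assert _NormPartial(norm_type=float("inf")).reduce_op == "max"
assert _NormPartial(norm_type="-inf").reduce_op == "min"
assert _NormPartial(norm_type=2).reduce_op == "sum"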


@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from dataclasses import dataclass
 from typing import cast, Optional

 import torch
+import torch._C
 import torch.distributed._functional_collectives as funcol
+from torch._C._distributed import Placement
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._collective_utils import (
     fill_empty_tensor_to_shards,
@ -20,29 +21,11 @@ from torch.distributed.tensor._collective_utils import (
 __all__ = ["Placement", "Shard", "Replicate", "Partial"]


-class Placement:
-    """
-    The base class for the Placement type, where it describes how a DTensor is placed onto the
-    ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
-
-    It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
-    and ``Partial``.
-
-    This class is not meant to be used directly, mainly served as a typing stub.
-    """
-
-    # convenient utils to check for placement types
-    def is_shard(self, dim: Optional[int] = None) -> bool:
-        return False
-
-    def is_replicate(self) -> bool:
-        return False
-
-    def is_partial(self, reduce_op: Optional[str] = None) -> bool:
-        return False
+# Appease TestPublicBindings.test_correct_module_names
+Placement.__module__ = "torch.distributed.tensor.placement_types"


-@dataclass(frozen=True)
-class Shard(Placement):
+class Shard(torch._C._distributed.Shard):
     """
     The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
     ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
@ -60,14 +43,6 @@ class Shard(Placement):
     evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
     """

-    dim: int
-
-    def is_shard(self, dim: Optional[int] = None) -> bool:
-        if dim is not None:
-            return self.dim == dim
-        else:
-            return True
-
     def _split_tensor(
         self,
         tensor: torch.Tensor,
@ -349,11 +324,6 @@ class Shard(Placement):

         return new_tensor

-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, Shard):
-            return False
-        return self.dim == other.dim
-
     def __hash__(self) -> int:
         return hash(self.dim)
@ -368,8 +338,8 @@ class Shard(Placement):
         return f"S({self.dim})"


-@dataclass(frozen=True, kw_only=True)
-class _StridedShard(Shard):
+# Need to inherit from Shard here so that isinstance(some_strided_shard, Shard) will work.
+class _StridedShard(torch._C._distributed.StridedShard, Shard):
     """
     _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor
     is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
@ -427,18 +397,6 @@ class _StridedShard(Shard):
     TODO: we should remove _StridedShard placement once we can unify it with Shard
     """

-    split_factor: int
-
-    def __eq__(self, other: object) -> bool:
-        if isinstance(other, _StridedShard):
-            return self.dim == other.dim and self.split_factor == other.split_factor
-        elif isinstance(other, Shard):
-            # TODO: this is to avoid extra all-gather in dtensor op dispatch
-            # note that sharding prop would not produce _StridedShard and an
-            # placement inequality would introduce an all-gather for resharding
-            return self.dim == other.dim
-        return False
-
     def __hash__(self) -> int:
         return hash((self.dim, self.split_factor))
@ -585,8 +543,7 @@ class _StridedShard(Shard):
         return local_shard_size, None


-@dataclass(frozen=True)
-class Replicate(Placement):
+class Replicate(torch._C._distributed.Replicate):
     """
     The ``Replicate()`` placement describes the DTensor replicating on a corresponding
     ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a
@ -594,9 +551,6 @@ class Replicate(Placement):
     DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.)
     """

-    def __eq__(self, other: object) -> bool:
-        return isinstance(other, Replicate)
-
     def __hash__(self) -> int:
         # every replicate placement is the same
         return -1
@ -636,12 +590,8 @@ class Replicate(Placement):
         mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank)
         return tensor

-    def is_replicate(self) -> bool:
-        return True
-

-@dataclass(frozen=True)
-class Partial(Placement):
+class Partial(torch._C._distributed.Partial):
     """
     The ``Partial(reduce_op)`` placement describes the DTensor that is pending
     reduction on a specified ``DeviceMesh`` dimension, where each rank on the
@ -660,8 +610,6 @@ class Partial(Placement):
     and can only be used by the ``DTensor.from_local`` API.
     """

-    reduce_op: str = "sum"
-
     def _reduce_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
@ -698,11 +646,6 @@ class Partial(Placement):
         num_chunks = mesh.size(mesh_dim=mesh_dim)
         return tensor / num_chunks

-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, Partial):
-            return False
-        return self.reduce_op == other.reduce_op
-
     def __hash__(self) -> int:
         return 1 + hash(self.reduce_op)
@ -718,11 +661,6 @@ class Partial(Placement):
         """
         return "P"

-    def is_partial(self, reduce_op: Optional[str] = None) -> bool:
-        if reduce_op is None:
-            return True
-        return self.reduce_op == reduce_op
-

 # We keep the old _Partial name for a while for BC reason
 _Partial = Partial
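Net effect of this file: the data (dim, split_factor, reduce_op) and the is_shard/is_replicate/is_partial checks now live on the C++ bases, while the Python subclasses keep their docstrings, __hash__, and the tensor/collective helpers. A brief sketch of what that split looks like to callers:

# Sketch: placement data is C++-backed; the Python-side __hash__ is unchanged.
from torch.distributed.tensor.placement_types import Partial, Shard

s = Shard(1)
assert s.dim == 1                                # exposed via pybind11 def_readonly (read-only)
assert hash(Partial("max")) == 1 + hash("max")   # __hash__ above is still defined in Python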