From 3e03deab6f3c268c85c8efd9546e28cdda0fa4cc Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 1 Oct 2025 15:29:16 -0700 Subject: [PATCH] C++-accessible Placements via pybind11 (#163030) This makes Placement data representation available in C++ via pybind11. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030 Approved by: https://github.com/ezyang --- build_variables.bzl | 1 + .../tensor/test_placement_types.py | 88 +++++++++++++ torch/_C/_distributed.pyi | 20 +++ torch/_dynamo/variables/distributed.py | 11 +- torch/csrc/Module.cpp | 3 + torch/csrc/distributed/Placement.h | 121 ++++++++++++++++++ torch/csrc/distributed/python_placement.cpp | 104 +++++++++++++++ torch/csrc/distributed/python_placement.h | 7 + .../distributed/tensor/_ops/_embedding_ops.py | 16 +++ torch/distributed/tensor/_ops/_math_ops.py | 22 ++-- torch/distributed/tensor/placement_types.py | 80 ++---------- 11 files changed, 385 insertions(+), 88 deletions(-) create mode 100644 test/distributed/tensor/test_placement_types.py create mode 100644 torch/_C/_distributed.pyi create mode 100644 torch/csrc/distributed/Placement.h create mode 100644 torch/csrc/distributed/python_placement.cpp create mode 100644 torch/csrc/distributed/python_placement.h diff --git a/build_variables.bzl b/build_variables.bzl index e4dd849be4fe..570ba46c4eb3 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -913,6 +913,7 @@ libtorch_python_core_sources = [ "torch/csrc/autograd/python_torch_functions_manual.cpp", "torch/csrc/autograd/python_variable.cpp", "torch/csrc/autograd/python_variable_indexing.cpp", + "torch/csrc/distributed/python_placement.cpp", "torch/csrc/dynamo/python_compiled_autograd.cpp", "torch/csrc/dynamo/cache_entry.cpp", "torch/csrc/dynamo/cpp_shim.cpp", diff --git a/test/distributed/tensor/test_placement_types.py b/test/distributed/tensor/test_placement_types.py new file mode 100644 index 000000000000..4ed043757124 --- /dev/null +++ b/test/distributed/tensor/test_placement_types.py @@ -0,0 +1,88 @@ +# Owner(s): ["oncall: distributed"] +import copy +import itertools +import sys +import unittest + +from torch._dynamo.variables.distributed import PlacementClassVariable +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Replicate, + Shard, +) +from torch.testing._internal.common_utils import run_tests, TestCase + + +# Basic functionality test for Placement types. 
+class PlacementTypesTestCase(TestCase): + def test_type_identification(self): + shard = Shard(3) + strided_shard = _StridedShard(dim=3, split_factor=7) + partial_sum = Partial("sum") + partial_max = Partial("max") + replicate = Replicate() + + ident_tests = ( + (shard, True, False, False), + (strided_shard, True, False, False), + (partial_sum, False, True, False), + (partial_max, False, True, False), + (replicate, False, False, True), + ) + for do_deepcopy in (False, True): + for placement, is_shard, is_partial, is_replicate in ident_tests: + if do_deepcopy: + placement = copy.deepcopy(placement) + self.assertEqual(placement.is_shard(), is_shard) + self.assertEqual(placement.is_partial(), is_partial) + self.assertEqual(placement.is_replicate(), is_replicate) + + def test_equality(self): + equivalence_classes = ( + (Shard(3), _StridedShard(dim=3, split_factor=7)), + (Shard(4), _StridedShard(dim=4, split_factor=9)), + (Replicate(),), + (Partial("sum"),), + (Partial("max"),), + ) + for eq_class in equivalence_classes: + # Each item in the equivalence class should be equal to every other item in + # its class. + for lhs, rhs in itertools.product(eq_class, eq_class): + self.assertEqual(lhs, rhs) + + # Each item in the equivalence class should not be equal to any item in any + # other class. + for other_class in equivalence_classes: + if other_class is eq_class: + continue + for lhs, rhs in itertools.product(eq_class, other_class): + self.assertNotEqual(lhs, rhs) + + # Testing this case doesn't seem to fit neatly into the above equivalence class + # framework. + self.assertNotEqual( + _StridedShard(dim=3, split_factor=1), _StridedShard(dim=3, split_factor=2) + ) + + @unittest.skipIf( + sys.version_info < (3, 10), "kw_only is only available in python >= 3.10" + ) + def test_strided_shard_kwonly_argument(self): + with self.assertRaises(TypeError): + _StridedShard(3, 4) + _StridedShard(3, split_factor=4) + + def test_strided_shard_isinstance_shard(self): + assert isinstance(_StridedShard(dim=3, split_factor=7), Shard) + + def test_dynamo_can_identify_placement_classes(self): + for cls in (Replicate, Shard, _StridedShard, Partial): + self.assertTrue( + PlacementClassVariable.is_placement_type(cls), msg=f"failed on {cls}" + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_distributed.pyi b/torch/_C/_distributed.pyi new file mode 100644 index 000000000000..ab797c5d15d5 --- /dev/null +++ b/torch/_C/_distributed.pyi @@ -0,0 +1,20 @@ +# This module is defined in torch/csrc/distributed/python_placement.cpp + +class Placement: + def is_partial(self, reduce_op: str | None = None) -> bool: ... + def is_replicate(self) -> bool: ... + def is_shard(self, dim: int | None = None) -> bool: ... + +class Shard(Placement): + dim: int + def __init__(self, dim: int): ... + +class StridedShard(Shard): + split_factor: int + def __init__(self, dim: int, *, split_factor: int): ... + +class Replicate(Placement): ... + +class Partial(Placement): + reduce_op: str + def __init__(self, reduce_op: str | None = None): ... 
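The stubs above correspond one-to-one to the pybind11 bindings added in torch/csrc/distributed/python_placement.cpp below. As a minimal sketch (not part of the patch), this is how the bound types are expected to behave from Python, assuming a build that compiles the new bindings; every name is taken from the .pyi stub above:

    from torch._C._distributed import Partial, Replicate, Shard, StridedShard

    s = Shard(0)
    assert s.is_shard() and s.is_shard(0) and not s.is_shard(1)

    ss = StridedShard(0, split_factor=2)  # split_factor is keyword-only
    assert ss.is_shard(0)
    assert ss == s  # StridedShard deliberately compares equal to Shard on the same dim

    p = Partial()  # reduce_op defaults to "sum"
    assert p.is_partial() and p.is_partial("sum") and not p.is_partial("max")

    r = Replicate()
    assert r.is_replicate() and r == Replicate()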
diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 31aed55da1f2..7cf630798313 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -142,7 +142,7 @@ class PlacementClassVariable(DistributedVariable): from torch.distributed.tensor.placement_types import Placement - return type(value) is type and issubclass(value, Placement) + return isinstance(value, type) and issubclass(value, Placement) def as_python_constant(self): return self.value @@ -153,13 +153,10 @@ class PlacementClassVariable(DistributedVariable): args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - if ( - inspect.getattr_static(self.value, "__new__", None) == object.__new__ - and self.source - ): + if self.source: # NOTE: we don't need to track mutations to the placement class as they - # suppose to be immutable. - new_obj = object.__new__(self.value) + # are supposed to be immutable. + new_obj = self.value.__new__(self.value) var = PlacementVariable(new_obj) if inspect.getattr_static(self.value, "__init__", None): var.call_method(tx, "__init__", args, kwargs) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 126333fdcc4d..4ea9981577a3 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -71,6 +71,7 @@ #include #include #include +#include #include #include #include @@ -2119,6 +2120,8 @@ PyObject* initModule() { THXPEvent_init(module); #endif + torch::distributed::initPlacementBindings(module); + auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) { // PyModule_AddObject steals reference diff --git a/torch/csrc/distributed/Placement.h b/torch/csrc/distributed/Placement.h new file mode 100644 index 000000000000..3b36f04ee646 --- /dev/null +++ b/torch/csrc/distributed/Placement.h @@ -0,0 +1,121 @@ +#pragma once +/** + * The implementations in this file are coupled with + * torch/distributed/tensor/placement_types.py. 
+ */
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <utility>
+
+namespace torch::distributed {
+
+class Placement {
+ public:
+  Placement() = default;
+  virtual ~Placement() = default;
+
+  Placement(const Placement&) = default;
+  Placement& operator=(const Placement&) = default;
+  Placement(Placement&&) noexcept = default;
+  Placement& operator=(Placement&&) noexcept = default;
+
+  virtual bool is_shard(std::optional<std::int64_t> dim) const {
+    return false;
+  }
+
+  virtual bool is_replicate() const {
+    return false;
+  }
+
+  virtual bool is_partial(
+      std::optional<std::string> reduce_op = std::nullopt) const {
+    return false;
+  }
+};
+
+class Shard : public Placement {
+ public:
+  std::int64_t dim;
+  explicit Shard(std::int64_t dim_) : dim(dim_) {}
+
+  bool is_shard(std::optional<std::int64_t> dim_) const override {
+    return !dim_.has_value() || *dim_ == dim;
+  }
+
+  bool operator==(const Shard& rhs) const {
+    return dim == rhs.dim;
+  }
+
+  bool operator!=(const Shard& rhs) const {
+    return !operator==(rhs);
+  }
+};
+
+class StridedShard : public Shard {
+ public:
+  std::int64_t split_factor;
+  explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
+      : Shard(dim), split_factor(split_factor_) {}
+
+  bool operator==(const StridedShard& rhs) const {
+    return dim == rhs.dim && split_factor == rhs.split_factor;
+  }
+
+  bool operator==(const Shard& rhs) const {
+    if (auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
+      return operator==(*rhs_strided);
+    }
+    // TODO: this is to avoid extra all-gather in dtensor op dispatch
+    // note that sharding prop would not produce _StridedShard and a
+    // placement inequality would introduce an all-gather for resharding
+    return dim == rhs.dim;
+  }
+
+  bool operator!=(const Shard& rhs) const {
+    return !operator==(rhs);
+  }
+};
+
+class Replicate : public Placement {
+ public:
+  bool is_replicate() const override {
+    return true;
+  }
+
+  bool operator==(const Replicate& rhs) const {
+    return true;
+  }
+
+  bool operator!=(const Replicate& rhs) const {
+    return false;
+  }
+};
+
+class Partial : public Placement {
+ public:
+  std::string reduce_op;
+
+  Partial() : Partial("sum") {}
+
+  explicit Partial(std::optional<std::string> reduce_op_)
+      : reduce_op(
+            reduce_op_.has_value() ? std::move(*reduce_op_)
+                                   : std::string("sum")) {}
+
+  bool is_partial(
+      std::optional<std::string> op = std::nullopt) const override {
+    return !op.has_value() || *op == reduce_op;
+  }
+
+  bool operator==(const Partial& rhs) const {
+    return reduce_op == rhs.reduce_op;
+  }
+
+  bool operator!=(const Partial& rhs) const {
+    return !operator==(rhs);
+  }
+};
+
+} // namespace torch::distributed
diff --git a/torch/csrc/distributed/python_placement.cpp b/torch/csrc/distributed/python_placement.cpp
new file mode 100644
index 000000000000..df5825d8d8e4
--- /dev/null
+++ b/torch/csrc/distributed/python_placement.cpp
@@ -0,0 +1,104 @@
+#include <torch/csrc/distributed/python_placement.h>
+
+#include <pybind11/stl.h>
+#include <torch/csrc/distributed/Placement.h>
+#include <torch/csrc/utils/pybind.h>
+
+using namespace pybind11::literals;
+
+namespace torch::distributed {
+namespace {
+const auto placement_class_docstring =
+    R"(The base class for the Placement type, which describes how a DTensor is placed onto the
+``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together describe the DTensor layout.
+It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
+and ``Partial``.
+
+This class is not meant to be used directly; it mainly serves as a typing stub.
+)";
+} // namespace
+
+void initPlacementBindings(PyObject* module) {
+  auto py_module = py::reinterpret_borrow<py::module>(module);
+  auto distributed_module = py_module.def_submodule("_distributed");
+  py::class_<Placement>(
+      distributed_module, "Placement", placement_class_docstring)
+      .def(py::init<>()) // Allow construction of Python subclasses.
+      .def(
+          "is_partial",
+          &Placement::is_partial,
+          py::arg("reduce_op") = py::none())
+      .def("is_replicate", &Placement::is_replicate)
+      .def("is_shard", &Placement::is_shard, py::arg("dim") = py::none());
+  py::class_<Shard, Placement>(distributed_module, "Shard")
+      .def(py::init<std::int64_t>(), py::arg("dim"))
+      .def_readonly("dim", &Shard::dim)
+      .def("is_shard", &Shard::is_shard, py::arg("dim") = py::none())
+      .def(
+          "__eq__",
+          [](const Shard& lhs, const Shard& rhs) { return lhs == rhs; },
+          py::is_operator())
+      // Note: we need to use dicts for pickling to match the old
+      // dataclasses.
+      .def(py::pickle(
+          [](const Shard& shard) { return py::dict("dim"_a = shard.dim); },
+          [](const py::dict& d) {
+            return Shard(py::cast<std::int64_t>(d["dim"]));
+          }));
+  py::class_<StridedShard, Shard>(distributed_module, "StridedShard")
+      .def(
+          py::init<std::int64_t, std::int64_t>(),
+          py::arg("dim"),
+          py::kw_only(),
+          py::arg("split_factor"))
+      .def_readonly("split_factor", &StridedShard::split_factor)
+      .def("is_shard", &StridedShard::is_shard, py::arg("dim") = py::none())
+      .def(
+          "__eq__",
+          [](const StridedShard& lhs, const Shard& rhs) { return lhs == rhs; },
+          py::is_operator())
+      .def(py::pickle(
+          [](const StridedShard& shard) {
+            return py::dict(
+                "dim"_a = shard.dim, "split_factor"_a = shard.split_factor);
+          },
+          [](const py::dict& d) {
+            return StridedShard(
+                py::cast<std::int64_t>(d["dim"]),
+                py::cast<std::int64_t>(d["split_factor"]));
+          }));
+  py::class_<Replicate, Placement>(distributed_module, "Replicate")
+      .def(py::init())
+      .def("is_replicate", &Replicate::is_replicate)
+      .def(
+          "__eq__",
+          [](const Replicate& lhs, const Replicate& rhs) { return lhs == rhs; },
+          py::is_operator())
+      .def(py::pickle(
+          // I observed SIGSEGV when trying to use None as the
+          // pickled state, though AFAICT that matches the
+          // behavior of
+          // object().__reduce__().
+          // test_placement_types.test_type_identification will repro if an
+          // enterprising reader wants to get this fixed.
+          [](const Replicate& repl) { return py::dict(); },
+          [](const py::dict&) { return Replicate(); }));
+  py::class_<Partial, Placement>(distributed_module, "Partial")
+      .def(py::init<>())
+      .def(py::init<std::optional<std::string>>(), py::arg("reduce_op"))
+      .def_readonly("reduce_op", &Partial::reduce_op)
+      .def(
+          "is_partial", &Partial::is_partial, py::arg("reduce_op") = py::none())
+      .def(
+          "__eq__",
+          [](const Partial& lhs, const Partial& rhs) { return lhs == rhs; },
+          py::is_operator())
+      .def(py::pickle(
+          [](const Partial& part) {
+            return py::dict("reduce_op"_a = part.reduce_op);
+          },
+          [](const py::dict& d) {
+            return Partial(py::cast<std::string>(d["reduce_op"]));
+          }));
+}
+} // namespace torch::distributed
diff --git a/torch/csrc/distributed/python_placement.h b/torch/csrc/distributed/python_placement.h
new file mode 100644
index 000000000000..dc88341db4e4
--- /dev/null
+++ b/torch/csrc/distributed/python_placement.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <torch/csrc/utils/python_stub.h>
+
+namespace torch::distributed {
+void initPlacementBindings(PyObject* module);
+} // namespace torch::distributed
diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py
index 445b1830defe..6730a740ba99 100644
--- a/torch/distributed/tensor/_ops/_embedding_ops.py
+++ b/torch/distributed/tensor/_ops/_embedding_ops.py
@@ -83,6 +83,22 @@ class _MaskPartial(Partial):
     offset_shape: Optional[torch.Size] = None
     offset_dim: int = 0
 
+    def __init__(
+        self,
+        reduce_op=None,
+        mask_buffer=None,
+        offset_shape=None,
+        offset_dim=0,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(reduce_op)
+        if mask_buffer is None:
+            mask_buffer = MaskBuffer()
+        object.__setattr__(self, "mask_buffer", mask_buffer)
+        object.__setattr__(self, "offset_shape", offset_shape)
+        object.__setattr__(self, "offset_dim", offset_dim)
+
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py
index 2daf9298637f..bb19a26d9378 100644
--- a/torch/distributed/tensor/_ops/_math_ops.py
+++ b/torch/distributed/tensor/_ops/_math_ops.py
@@ -73,17 +73,19 @@ class _NormPartial(Partial):
 
     norm_type: Union[int, float, str] = 2
 
-    def __post_init__(self):
-        """Set the appropriate reduce op based on the norm type."""
-        # Use `object.__setattr__` to bypass frozen checks
-        if self.norm_type in (float("inf"), "inf"):
-            object.__setattr__(self, "reduce_op", "max")
-        elif self.norm_type in (float("-inf"), "-inf"):
-            object.__setattr__(self, "reduce_op", "min")
-        elif isinstance(self.norm_type, (int, float)):
-            object.__setattr__(self, "reduce_op", "sum")
+    def __init__(self, norm_type: Union[int, float, str] = 2):
+        reduce_op = None
+        if norm_type in (float("inf"), "inf"):
+            reduce_op = "max"
+        elif norm_type in (float("-inf"), "-inf"):
+            reduce_op = "min"
+        elif isinstance(norm_type, (int, float)):
+            reduce_op = "sum"
         else:
-            raise NotImplementedError(f"Unsupported norm type: {self.norm_type}")
+            raise NotImplementedError(f"Unsupported norm type: {norm_type}")
+
+        super().__init__(reduce_op)
+        object.__setattr__(self, "norm_type", norm_type)
 
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py
index ad304229a278..1d6827a95724 100644
--- a/torch/distributed/tensor/placement_types.py
+++ b/torch/distributed/tensor/placement_types.py
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
-from dataclasses import dataclass
 from typing import cast, Optional
 
 import torch
+import torch._C
 import torch.distributed._functional_collectives as funcol
+from torch._C._distributed import Placement
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor._collective_utils import (
     fill_empty_tensor_to_shards,
@@ -20,29 +21,11 @@ from torch.distributed.tensor._collective_utils import (
 __all__ = ["Placement", "Shard", "Replicate", "Partial"]
 
 
-class Placement:
-    """
-    The base class for the Placement type, where it describes how a DTensor is placed onto the
-    ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
-    It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
-    and ``Partial``.
-
-    This class is not meant to be used directly, mainly served as a typing stub.
-    """
-
-    # convenient utils to check for placement types
-    def is_shard(self, dim: Optional[int] = None) -> bool:
-        return False
-
-    def is_replicate(self) -> bool:
-        return False
-
-    def is_partial(self, reduce_op: Optional[str] = None) -> bool:
-        return False
+# Appease TestPublicBindings.test_correct_module_names
+Placement.__module__ = "torch.distributed.tensor.placement_types"
 
 
-@dataclass(frozen=True)
-class Shard(Placement):
+class Shard(torch._C._distributed.Shard):
     """
     The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
     ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
@@ -60,14 +43,6 @@ class Shard(torch._C._distributed.Shard):
     evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
     """
 
-    dim: int
-
-    def is_shard(self, dim: Optional[int] = None) -> bool:
-        if dim is not None:
-            return self.dim == dim
-        else:
-            return True
-
     def _split_tensor(
         self,
         tensor: torch.Tensor,
@@ -349,11 +324,6 @@ class Shard(torch._C._distributed.Shard):
 
         return new_tensor
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, Shard):
-            return False
-        return self.dim == other.dim
-
     def __hash__(self) -> int:
         return hash(self.dim)
 
@@ -368,8 +338,8 @@ class Shard(Placement):
         return f"S({self.dim})"
 
 
-@dataclass(frozen=True, kw_only=True)
-class _StridedShard(Shard):
+# Need to inherit from Shard here so that isinstance(some_strided_shard, Shard) will work.
+class _StridedShard(torch._C._distributed.StridedShard, Shard):
     """
     _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the
     tensor is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
@@ -427,18 +397,6 @@ class _StridedShard(Shard): TODO: we should remove _StridedShard placement once we can unify it with Shard """ - split_factor: int - - def __eq__(self, other: object) -> bool: - if isinstance(other, _StridedShard): - return self.dim == other.dim and self.split_factor == other.split_factor - elif isinstance(other, Shard): - # TODO: this is to avoid extra all-gather in dtensor op dispatch - # note that sharding prop would not produce _StridedShard and an - # placement inequality would introduce an all-gather for resharding - return self.dim == other.dim - return False - def __hash__(self) -> int: return hash((self.dim, self.split_factor)) @@ -585,8 +543,7 @@ class _StridedShard(Shard): return local_shard_size, None -@dataclass(frozen=True) -class Replicate(Placement): +class Replicate(torch._C._distributed.Replicate): """ The ``Replicate()`` placement describes the DTensor replicating on a corresponding ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a @@ -594,9 +551,6 @@ class Replicate(Placement): DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) """ - def __eq__(self, other: object) -> bool: - return isinstance(other, Replicate) - def __hash__(self) -> int: # every replicate placement is the same return -1 @@ -636,12 +590,8 @@ class Replicate(Placement): mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank) return tensor - def is_replicate(self) -> bool: - return True - -@dataclass(frozen=True) -class Partial(Placement): +class Partial(torch._C._distributed.Partial): """ The ``Partial(reduce_op)`` placement describes the DTensor that is pending reduction on a specified ``DeviceMesh`` dimension, where each rank on the @@ -660,8 +610,6 @@ class Partial(Placement): and can only be used by the ``DTensor.from_local`` API. """ - reduce_op: str = "sum" - def _reduce_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: @@ -698,11 +646,6 @@ class Partial(Placement): num_chunks = mesh.size(mesh_dim=mesh_dim) return tensor / num_chunks - def __eq__(self, other: object) -> bool: - if not isinstance(other, Partial): - return False - return self.reduce_op == other.reduce_op - def __hash__(self) -> int: return 1 + hash(self.reduce_op) @@ -718,11 +661,6 @@ class Partial(Placement): """ return "P" - def is_partial(self, reduce_op: Optional[str] = None) -> bool: - if reduce_op is None: - return True - return self.reduce_op == reduce_op - # We keep the old _Partial name for a while for BC reason _Partial = Partial
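With the Python classes above reduced to thin wrappers over the C++ placements, the user-visible behavior of the old frozen dataclasses is meant to be preserved. A short sketch (not part of the patch) of the invariants that test/distributed/tensor/test_placement_types.py exercises:

    import copy

    from torch.distributed.tensor.placement_types import (
        _StridedShard,
        Partial,
        Replicate,
        Shard,
    )

    assert Shard(1).is_shard(1) and not Shard(1).is_replicate()
    assert isinstance(_StridedShard(dim=1, split_factor=2), Shard)
    # Dim-only fallback for _StridedShard vs. Shard equality; see the TODO in Placement.h.
    assert _StridedShard(dim=1, split_factor=2) == Shard(1)
    assert Partial().reduce_op == "sum" and Partial("max").is_partial("max")
    assert Replicate() == Replicate() and Replicate().is_replicate()
    # The dict-based pickle support in python_placement.cpp keeps copy.deepcopy
    # (and pickling) working the way the dataclasses did.
    assert copy.deepcopy(Shard(1)) == Shard(1)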
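The _MaskPartial and _NormPartial subclasses lose their generated dataclass __init__ in this patch, so they now pass an explicit reduce_op to the C++-backed Partial constructor. A small illustration of the reduce_op selection that _NormPartial's new __init__ performs (private API, shown only to restate the logic in the _math_ops.py hunk above):

    from torch.distributed.tensor._ops._math_ops import _NormPartial

    assert _NormPartial().reduce_op == "sum"  # norm_type defaults to 2
    assert _NormPartial(norm_type=float("inf")).reduce_op == "max"
    assert _NormPartial(norm_type="-inf").reduce_op == "min"
    assert _NormPartial(norm_type=3).norm_type == 3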