pytorch/torch/csrc/distributed/Placement.h
Commit 3e03deab6f by Scott Wolchok: C++-accessible Placements via pybind11 (#163030)
This makes the Placement data representation available in C++ via pybind11.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030
Approved by: https://github.com/ezyang
2025-10-02 02:38:23 +00:00
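
For context, here is a minimal, hypothetical sketch of how the types in the header below could be exposed through pybind11. It is illustrative only: the module name (placement_sketch), the include path, and the exact set of bindings are assumptions, not the registration code added by this PR.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // casters for std::optional and std::nullopt_t

#include <torch/csrc/distributed/Placement.h> // assumed include path for the header below

namespace py = pybind11;
using namespace torch::distributed;

// Hypothetical standalone module; shown only to illustrate the shape of the bindings.
PYBIND11_MODULE(placement_sketch, m) {
  py::class_<Placement>(m, "Placement")
      .def(py::init<>())
      .def("is_shard", &Placement::is_shard, py::arg("dim") = std::nullopt)
      .def("is_replicate", &Placement::is_replicate)
      .def("is_partial", &Placement::is_partial, py::arg("reduce_op") = std::nullopt);

  py::class_<Shard, Placement>(m, "Shard")
      .def(py::init<std::int64_t>(), py::arg("dim"))
      .def_readwrite("dim", &Shard::dim);

  py::class_<Replicate, Placement>(m, "Replicate").def(py::init<>());

  py::class_<Partial, Placement>(m, "Partial")
      .def(py::init<std::optional<std::string>>(), py::arg("reduce_op") = std::nullopt)
      .def_readwrite("reduce_op", &Partial::reduce_op);
}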

#pragma once

/**
 * The implementations in this file are coupled with
 * torch/distributed/tensor/placement_types.py.
 */

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
namespace torch::distributed {

// Base class for DTensor placement types; mirrors
// torch.distributed.tensor.placement_types.Placement.
class Placement {
 public:
  Placement() = default;
  virtual ~Placement() = default;
  Placement(const Placement&) = default;
  Placement& operator=(const Placement&) = default;
  Placement(Placement&&) noexcept = default;
  Placement& operator=(Placement&&) noexcept = default;

  virtual bool is_shard(std::optional<std::int64_t> dim) const {
    return false;
  }
  virtual bool is_replicate() const {
    return false;
  }
  virtual bool is_partial(
      std::optional<std::string_view> reduce_op = std::nullopt) const {
    return false;
  }
};

// Tensor is sharded along dim on the corresponding device mesh dimension.
class Shard : public Placement {
 public:
  std::int64_t dim;

  explicit Shard(std::int64_t dim_) : dim(dim_) {}

  bool is_shard(std::optional<std::int64_t> dim_) const override {
    return !dim_.has_value() || *dim_ == dim;
  }

  bool operator==(const Shard& rhs) const {
    return dim == rhs.dim;
  }
  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};

// Mirrors _StridedShard in placement_types.py: a Shard with a strided split
// order recorded in split_factor.
class StridedShard : public Shard {
 public:
  std::int64_t split_factor;

  explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
      : Shard(dim), split_factor(split_factor_) {}

  bool operator==(const StridedShard& rhs) const {
    return dim == rhs.dim && split_factor == rhs.split_factor;
  }
  bool operator==(const Shard& rhs) const {
    if (auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
      return operator==(*rhs_strided);
    }
    // TODO: this is to avoid extra all-gather in dtensor op dispatch
    // note that sharding prop would not produce _StridedShard and a
    // placement inequality would introduce an all-gather for resharding
    return dim == rhs.dim;
  }
  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};

// Tensor is replicated on the corresponding device mesh dimension.
class Replicate : public Placement {
 public:
  bool is_replicate() const override {
    return true;
  }

  bool operator==(const Replicate& rhs) const {
    return true;
  }
  bool operator!=(const Replicate& rhs) const {
    return false;
  }
};

// Tensor holds partial values pending a reduction with reduce_op
// (defaults to "sum").
class Partial : public Placement {
 public:
  std::string reduce_op;

  Partial() : Partial("sum") {}
  explicit Partial(std::optional<std::string> reduce_op_)
      : reduce_op(
            reduce_op_.has_value() ? std::move(*reduce_op_)
                                   : std::string("sum")) {}

  bool is_partial(
      std::optional<std::string_view> op = std::nullopt) const override {
    return !op.has_value() || *op == reduce_op;
  }

  bool operator==(const Partial& rhs) const {
    return reduce_op == rhs.reduce_op;
  }
  bool operator!=(const Partial& rhs) const {
    return !operator==(rhs);
  }
};

} // namespace torch::distributed
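
A small usage sketch exercising the classes above from C++ (not part of the header; the include path is assumed):

#include <torch/csrc/distributed/Placement.h> // assumed include path

#include <iostream>

int main() {
  using namespace torch::distributed;

  Shard shard{0};
  StridedShard strided{0, /*split_factor=*/2};
  Replicate replicate;
  Partial partial{"avg"};

  std::cout << std::boolalpha;
  std::cout << shard.is_shard(0) << '\n'; // true: sharded along dim 0
  std::cout << shard.is_shard(1) << '\n'; // false: different dim
  std::cout << (strided == shard) << '\n'; // true: dims match (see the TODO in StridedShard)
  std::cout << replicate.is_replicate() << '\n'; // true
  std::cout << partial.is_partial("avg") << '\n'; // true: reduce_op matches
  std::cout << partial.is_partial("sum") << '\n'; // false: different reduce_op
  return 0;
}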