This makes Placement data representation available in C++ via pybind11. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163030 Approved by: https://github.com/ezyang
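The change described above is about making Python DTensor placement data usable from C++. Below is a minimal, hypothetical sketch of how such a conversion could look with pybind11; the helper name placement_from_python and the dispatch on Python class names are illustrative assumptions, not the PR's actual binding code. The attribute names ("dim", "split_factor", "reduce_op") mirror torch/distributed/tensor/placement_types.py.

// Hypothetical sketch only: converting a Python DTensor placement object into
// the C++ types declared in the header below via pybind11.
// (The #include of that header is omitted; its install path is not shown here.)
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

#include <pybind11/pybind11.h>

namespace py = pybind11;

std::unique_ptr<torch::distributed::Placement> placement_from_python(
    const py::object& obj) {
  using namespace torch::distributed;
  // Dispatch on the Python class name; attribute names follow
  // torch/distributed/tensor/placement_types.py.
  const std::string cls = py::str(obj.attr("__class__").attr("__name__"));
  if (cls == "_StridedShard") {
    return std::make_unique<StridedShard>(
        obj.attr("dim").cast<std::int64_t>(),
        obj.attr("split_factor").cast<std::int64_t>());
  }
  if (cls == "Shard") {
    return std::make_unique<Shard>(obj.attr("dim").cast<std::int64_t>());
  }
  if (cls == "Replicate") {
    return std::make_unique<Replicate>();
  }
  if (cls == "Partial") {
    return std::make_unique<Partial>(
        obj.attr("reduce_op").cast<std::string>());
  }
  throw std::runtime_error("unrecognized placement type: " + cls);
}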
122 lines
2.8 KiB
C++
#pragma once

/**
 * The implementations in this file are coupled with
 * torch/distributed/tensor/placement_types.py.
 */

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>

namespace torch::distributed {

class Placement {
 public:
  Placement() = default;
  virtual ~Placement() = default;

  Placement(const Placement&) = default;
  Placement& operator=(const Placement&) = default;
  Placement(Placement&&) noexcept = default;
  Placement& operator=(Placement&&) noexcept = default;

  virtual bool is_shard(std::optional<std::int64_t> dim) const {
    return false;
  }

  virtual bool is_replicate() const {
    return false;
  }

  virtual bool is_partial(
      std::optional<std::string_view> reduce_op = std::nullopt) const {
    return false;
  }
};

class Shard : public Placement {
 public:
  std::int64_t dim;
  explicit Shard(std::int64_t dim_) : dim(dim_) {}

  bool is_shard(std::optional<std::int64_t> dim_) const override {
    return !dim_.has_value() || *dim_ == dim;
  }

  bool operator==(const Shard& rhs) const {
    return dim == rhs.dim;
  }

  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};

class StridedShard : public Shard {
 public:
  std::int64_t split_factor;
  explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
      : Shard(dim), split_factor(split_factor_) {}

  bool operator==(const StridedShard& rhs) const {
    return dim == rhs.dim && split_factor == rhs.split_factor;
  }

  bool operator==(const Shard& rhs) const {
    if (auto* rhs_strided = dynamic_cast<const StridedShard*>(&rhs)) {
      return operator==(*rhs_strided);
    }
    // TODO: this is to avoid extra all-gather in dtensor op dispatch
    // note that sharding prop would not produce _StridedShard and a
    // placement inequality would introduce an all-gather for resharding
    return dim == rhs.dim;
  }

  bool operator!=(const Shard& rhs) const {
    return !operator==(rhs);
  }
};

class Replicate : public Placement {
 public:
  bool is_replicate() const override {
    return true;
  }

  bool operator==(const Replicate& rhs) const {
    return true;
  }

  bool operator!=(const Replicate& rhs) const {
    return false;
  }
};

class Partial : public Placement {
 public:
  std::string reduce_op;

  Partial() : Partial("sum") {}

  explicit Partial(std::optional<std::string> reduce_op_)
      : reduce_op(
            reduce_op_.has_value() ? std::move(*reduce_op_)
                                   : std::string("sum")) {}

  bool is_partial(
      std::optional<std::string_view> op = std::nullopt) const override {
    return !op.has_value() || *op == reduce_op;
  }

  bool operator==(const Partial& rhs) const {
    return reduce_op == rhs.reduce_op;
  }

  bool operator!=(const Partial& rhs) const {
    return !operator==(rhs);
  }
};

} // namespace torch::distributed
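For reference, here is a small standalone usage sketch (not part of the file) that exercises the query methods and the equality semantics above, assuming the declarations in this header are in scope.

// Usage sketch only: demonstrates the polymorphic checks and the relaxed
// StridedShard == Shard comparison noted in the TODO above.
#include <cassert>
#include <memory>
#include <vector>

using namespace torch::distributed;

int main() {
  std::vector<std::unique_ptr<Placement>> placements;
  placements.push_back(std::make_unique<Shard>(0));
  placements.push_back(std::make_unique<Replicate>());
  placements.push_back(std::make_unique<Partial>());  // defaults to "sum"

  assert(placements[0]->is_shard(std::nullopt));  // any dim matches
  assert(placements[0]->is_shard(0));             // exact dim matches
  assert(!placements[0]->is_shard(1));            // different dim does not
  assert(placements[1]->is_replicate());
  assert(placements[2]->is_partial());            // any reduce op matches
  assert(placements[2]->is_partial("sum"));
  assert(!placements[2]->is_partial("avg"));

  // Value equality on the concrete types.
  assert(Shard(0) == Shard(0));
  assert(StridedShard(0, 2) != StridedShard(0, 3));
  assert(StridedShard(0, 2) == Shard(0));  // dim-only comparison, see TODO
  return 0;
}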