Files
pytorch/c10/util/IntrusiveList.h
dolpm a766c1d117 [nativert] move intrusive list to c10/util (#152754)
Summary:
nativert RFC: https://github.com/zhxchen17/rfcs/blob/master/RFC-0043-torch-native-runtime.md

To land the runtime into PyTorch core, we will gradually land logical parts of the code into the Github issue and get each piece properly reviewed.

This diff moves intrusive list to c10/util

Test Plan: CI

Differential Revision: D74104595

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152754
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-05-05 18:49:56 +00:00

207 lines
4.4 KiB
C++

#pragma once
#include <c10/util/Exception.h>
namespace c10 {
template <typename T>
class IntrusiveList;
class IntrusiveListHook {
template <typename P, typename T>
friend class ListIterator;
template <typename T>
friend class IntrusiveList;
IntrusiveListHook* next_{nullptr};
IntrusiveListHook* prev_{nullptr};
void link_before(IntrusiveListHook* next_node) {
next_ = next_node;
prev_ = next_node->prev_;
next_node->prev_ = this;
prev_->next_ = this;
}
public:
IntrusiveListHook() : next_(this), prev_(this) {}
IntrusiveListHook(const IntrusiveListHook&) = delete;
IntrusiveListHook& operator=(const IntrusiveListHook&) = delete;
IntrusiveListHook(IntrusiveListHook&&) = delete;
IntrusiveListHook& operator=(IntrusiveListHook&&) = delete;
void unlink() {
TORCH_CHECK(is_linked());
next_->prev_ = prev_;
prev_->next_ = next_;
next_ = this;
prev_ = this;
}
~IntrusiveListHook() {
if (is_linked()) {
unlink();
}
}
bool is_linked() const {
return next_ != this;
}
};
template <typename P, typename T>
class ListIterator {
static_assert(std::is_same_v<std::remove_const_t<P>, IntrusiveListHook>);
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
P* ptr_;
friend class IntrusiveList<T>;
public:
using iterator_category = std::bidirectional_iterator_tag;
using value_type = std::conditional_t<std::is_const_v<P>, const T, T>;
using difference_type = std::ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;
explicit ListIterator(P* ptr) : ptr_(ptr) {}
~ListIterator() = default;
ListIterator(const ListIterator&) = default;
ListIterator& operator=(const ListIterator&) = default;
ListIterator(ListIterator&&) = default;
ListIterator& operator=(ListIterator&&) = default;
template <
typename Q,
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
ListIterator(const ListIterator<Q, T>& rhs) : ptr_(rhs.ptr_) {}
template <
typename Q,
class = std::enable_if_t<std::is_const_v<P> && !std::is_const_v<Q>>>
ListIterator& operator=(const ListIterator<Q, T>& rhs) {
ptr_ = rhs.ptr_;
return *this;
}
template <typename Q>
bool operator==(const ListIterator<Q, T>& other) const {
return ptr_ == other.ptr_;
}
template <typename Q>
bool operator!=(const ListIterator<Q, T>& other) const {
return !(*this == other);
}
auto& operator*() const {
return static_cast<reference>(*ptr_);
}
ListIterator& operator++() {
TORCH_CHECK(ptr_);
ptr_ = ptr_->next_;
return *this;
}
ListIterator& operator--() {
TORCH_CHECK(ptr_);
ptr_ = ptr_->prev_;
return *this;
}
auto* operator->() const {
return static_cast<pointer>(ptr_);
}
};
template <typename T>
class IntrusiveList {
static_assert(std::is_base_of_v<IntrusiveListHook, T>);
public:
IntrusiveList() = default;
IntrusiveList(const std::initializer_list<std::reference_wrapper<T>>& items) {
for (auto& item : items) {
insert(this->end(), item);
}
}
~IntrusiveList() {
while (head_.is_linked()) {
head_.next_->unlink();
}
}
IntrusiveList(const IntrusiveList&) = delete;
IntrusiveList& operator=(const IntrusiveList&) = delete;
IntrusiveList(IntrusiveList&&) = delete;
IntrusiveList& operator=(IntrusiveList&&) = delete;
using iterator = ListIterator<IntrusiveListHook, T>;
using const_iterator = ListIterator<const IntrusiveListHook, T>;
auto begin() const {
return ++const_iterator{&head_};
}
auto begin() {
return ++iterator{&head_};
}
auto end() const {
return const_iterator{&head_};
}
auto end() {
return iterator{&head_};
}
auto rbegin() const {
return std::reverse_iterator{end()};
}
auto rbegin() {
return std::reverse_iterator{end()};
}
auto rend() const {
return std::reverse_iterator{begin()};
}
auto rend() {
return std::reverse_iterator{begin()};
}
auto iterator_to(const T& n) const {
return const_iterator{&n};
}
auto iterator_to(T& n) {
return iterator{&n};
}
iterator insert(iterator pos, T& n) {
n.link_before(pos.ptr_);
return iterator{&n};
}
size_t size() const {
size_t ret = 0;
for ([[maybe_unused]] auto& _ : *this) {
ret++;
}
return ret;
}
bool empty() const {
return !head_.is_linked();
}
private:
IntrusiveListHook head_;
};
} // namespace c10