Files
pytorch/c10/core/impl/TorchDispatchModeTLS.cpp
Yuanyuan Chen 0f0b4bf029 [1/N] Remove unused header inclusion (#165763)
This PR removes unused header inclusion in C++ files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165763
Approved by: https://github.com/Skylion007
2025-10-18 05:23:11 +00:00

196 lines
6.6 KiB
C++

#include <c10/core/DispatchKey.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/irange.h>
#include <utility>
namespace c10::impl {
// Per-thread backing storage for every TorchDispatchModeTLS accessor in this
// file. User (non-infra) modes live in stack_; infra modes live in the
// per-key slots of infra_modes_. `static` keeps the symbol local to this
// translation unit, and `thread_local` gives each thread its own copy.
static thread_local TorchDispatchModeTLS torchDispatchModeState;
// Reports whether this thread currently has any dispatch mode active.
// Any entry on the user-mode stack counts. Occupied infra-mode slots count
// too, unless `skip_infra_modes` is true, in which case only the user-mode
// stack is consulted.
bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
  const bool has_user_modes = !torchDispatchModeState.stack_.empty();
  if (has_user_modes || skip_infra_modes) {
    return has_user_modes;
  }
  constexpr auto num_keys =
      static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS);
  for (size_t key = 0; key < num_keys; ++key) {
    if (torchDispatchModeState.infra_modes_[key].has_value()) {
      return true;
    }
  }
  return false;
}
void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
std::shared_ptr<PyObject_TorchDispatchMode> mode) {
if (!any_modes_set()) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
}
torchDispatchModeState.stack_.push_back(std::move(mode));
}
// Pops and returns the top of this thread's logical mode stack.
// User modes (stack_) sit above infra modes, so they are popped first. Once
// stack_ is empty, the occupied infra-mode slot with the highest
// TorchDispatchModeKey index is emptied and its mode returned instead.
// Fails via TORCH_CHECK when nothing is active. If the pop leaves no modes
// at all, Python and PythonTLSSnapshot are un-included for this thread.
const std::shared_ptr<PyObject_TorchDispatchMode> TorchDispatchModeTLS::
    pop_stack() {
  auto& state = torchDispatchModeState;
  std::shared_ptr<PyObject_TorchDispatchMode> popped;
  if (state.stack_.empty()) {
    constexpr auto num_keys =
        static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS);
    // Walk the infra slots from the highest key index down to 0.
    for (size_t k = num_keys; k > 0; --k) {
      auto& slot = state.infra_modes_[k - 1];
      if (!slot.has_value()) {
        continue;
      }
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      popped = std::move(*slot);
      slot.reset();
      break;
    }
  } else {
    popped = std::move(state.stack_.back());
    state.stack_.pop_back();
  }
  TORCH_CHECK(popped, "trying to pop from empty mode stack");
  if (!any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, false);
  }
  return popped;
}
// Removes and returns the highest-priority active infra mode, together with
// the TorchDispatchModeKey of the slot it occupied. Slots with a larger key
// index are higher priority (get_stack_at orders them lowest-to-highest), so
// the scan runs from the top slot downward. If this removal leaves no modes
// active, Python and PythonTLSSnapshot are un-included for this thread.
// Fails via TORCH_CHECK when no infra mode is active.
const std::
    tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
    TorchDispatchModeTLS::pop_highest_infra_mode() {
  constexpr auto num_keys =
      static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS);
  // Unsigned countdown: avoids converting a size_t bound into a signed index.
  for (size_t k = num_keys; k > 0; --k) {
    const size_t i = k - 1;
    auto& slot = torchDispatchModeState.infra_modes_[i];
    if (!slot.has_value()) {
      continue;
    }
    // Move rather than copy: the slot is cleared right away, so a copy would
    // only add an extra atomic refcount increment/decrement pair. This also
    // matches how pop_stack extracts infra modes.
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto out_mode = std::move(*slot);
    slot = std::nullopt;
    if (!any_modes_set()) {
      c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
      c10::impl::tls_set_dispatch_key_included(
          DispatchKey::PythonTLSSnapshot, false);
    }
    return std::make_tuple(
        std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
  }
  // The terminating ';' was missing here. It only compiled because
  // TORCH_CHECK currently expands to a bare `if {}` statement.
  TORCH_CHECK(
      false, "Called pop_highest_infra_mode, but no infra modes were active.");
}
// Returns the mode at position `idx` of this thread's logical mode stack.
// The returned reference aliases thread-local storage and is only valid until
// that stack is next modified. Fails via TORCH_CHECK when `idx` is negative or
// not smaller than stack_len().
const std::shared_ptr<PyObject_TorchDispatchMode>& TorchDispatchModeTLS::
    get_stack_at(int64_t idx) {
  // Without this check, a negative idx passes the upper-bound check below.
  // It then never reaches 0 while skipping infra slots, and is finally used
  // to index stack_. There it converts to a huge size_t: an out-of-bounds read.
  TORCH_CHECK(idx >= 0, "Tried to get stack at negative idx");
  TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
  // Our "logical" stack includes both:
  // - any user modes (the entire torchDispatchModeState.stack_)
  // - any infra modes (members of torchDispatchModeState.infra_modes_ that are
  //   not None)
  // idx == 0 means the "bottom" of the stack, which starts with any infra
  // modes (iterating from lowest-priority to highest-priority).
  auto curr_idx = idx;
  for (const auto i :
       c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
    const auto& slot = torchDispatchModeState.infra_modes_[i];
    if (!slot.has_value()) {
      continue;
    }
    if (curr_idx == 0) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      return slot.value();
    }
    curr_idx -= 1;
  }
  // The two checks above guarantee 0 <= curr_idx < stack_.size() here.
  return torchDispatchModeState.stack_[static_cast<size_t>(curr_idx)];
}
// Size of this thread's logical mode stack: every user mode on stack_ plus
// every infra-mode slot that currently holds a mode.
int64_t TorchDispatchModeTLS::stack_len() {
  auto total = static_cast<int64_t>(torchDispatchModeState.stack_.size());
  constexpr auto num_keys =
      static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS);
  for (size_t key = 0; key < num_keys; ++key) {
    total += torchDispatchModeState.infra_modes_[key].has_value() ? 1 : 0;
  }
  return total;
}
// Returns a copy of the infra-mode slot for `mode_key`. The result is
// std::nullopt when no mode of that kind is active on this thread.
const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
TorchDispatchModeTLS::get_mode(TorchDispatchModeKey mode_key) {
  const auto slot_index = static_cast<size_t>(mode_key);
  return torchDispatchModeState.infra_modes_[slot_index];
}
// Installs `mode` as this thread's infra mode for `mode_key`. Fails via
// TORCH_CHECK if that slot is already occupied. When no mode of any kind was
// active beforehand, Python and PythonTLSSnapshot are first marked as
// included for this thread.
void TorchDispatchModeTLS::set_mode(
    const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
    TorchDispatchModeKey mode_key) {
  auto& slot =
      torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
  TORCH_CHECK(
      !slot.has_value(),
      "trying to set the current ",
      to_string(mode_key),
      ", but one already exists");
  // Evaluated before the slot is filled, so it detects the
  // "no modes" -> "one mode" transition.
  const bool first_active_mode = !any_modes_set();
  if (first_active_mode) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, true);
  }
  slot = mode;
}
// Clears this thread's infra-mode slot for `mode_key` and returns whatever
// it held (std::nullopt if it was already empty). Python and
// PythonTLSSnapshot are un-included only when a mode was actually removed
// and no other mode remains active.
const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) {
  auto& slot =
      torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
  std::optional<std::shared_ptr<PyObject_TorchDispatchMode>> previous =
      std::exchange(slot, std::nullopt);
  if (!previous.has_value()) {
    // Nothing was removed, so the set of active modes did not change.
    return previous;
  }
  if (!any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, false);
  }
  return previous;
}
// Read-only access to this thread's complete dispatch-mode state. The
// reference aliases the thread_local torchDispatchModeState itself.
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
  const TorchDispatchModeTLS& current = torchDispatchModeState;
  return current;
}
// Replaces this thread's entire dispatch-mode state with `state`. It then
// re-syncs the Python and PythonTLSSnapshot dispatch keys: they are included
// exactly when the new state has at least one active mode.
void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
  torchDispatchModeState = std::move(state);
  const bool has_active_modes = any_modes_set();
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::Python, has_active_modes);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonTLSSnapshot, has_active_modes);
}
// UTIL
// True when Python dispatch modes should be honored on this thread. That
// requires two things: the Python dispatch key is not in the thread-local
// excluded set, and at least one mode (user or infra) is active.
bool dispatch_mode_enabled() {
  if (c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    return false;
  }
  return TorchDispatchModeTLS::any_modes_set();
}
// Human-readable name for an infra-mode key. It is used in error messages;
// any key without a dedicated name maps to "UNKNOWN_MODE".
std::string to_string(TorchDispatchModeKey mode_key) {
  const char* name = "UNKNOWN_MODE";
  switch (mode_key) {
    case TorchDispatchModeKey::PROXY:
      name = "ProxyTorchDispatchMode";
      break;
    case TorchDispatchModeKey::FAKE:
      name = "FakeTensorMode";
      break;
    default:
      break;
  }
  return name;
}
} // namespace c10::impl