pytorch/c10/core/impl/TorchDispatchModeTLS.cpp
cyy 8fa81a6066 Enable misc-use-internal-linkage check and apply fixes (#148948)
Enables the clang-tidy check [`misc-use-internal-linkage`](https://clang.llvm.org/extra/clang-tidy/checks/misc/use-internal-linkage.html). This check was introduced in Clang-Tidy 18 and became available here with the recent update to Clang-Tidy 19.

The check flags functions and variables that are used only within their translation unit so they can be marked static. This keeps unwanted symbols from leaking into other translation units, enables more link-time optimisations, and can yield smaller binaries.

Most of the detected violations were fixed by marking the symbols static. In other cases the symbols were in fact consumed by other files, so the headers declaring them were included instead. A few declarations were simply incorrect and have been fixed.
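
As an illustration of the kind of change the check drives (a sketch with made-up names, not code from this PR), a helper that no other translation unit references gets internal linkage, either via `static` or an anonymous namespace:

```cpp
// Hypothetical translation unit showing the before/after of such a fix.

// Before: the helper had external linkage even though nothing outside this
// file used it.
//
//   int add_one(int x) { return x + 1; }
//
// After: internal linkage via `static`; the symbol stays private to this
// translation unit.
static int add_one(int x) {
  return x + 1;
}

// An anonymous namespace has the same effect and is convenient for grouping
// several helpers.
namespace {
int scale(int x) {
  return x * 2;
}
} // namespace

// Only this function is meant to be used from other files, so it keeps
// external linkage (and would be declared in a header).
int combined(int x) {
  return scale(add_one(x));
}
```

Either spelling keeps the symbol out of the object file's external symbol table, which is what lets the linker and LTO discard or inline it more aggressively.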

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148948
Approved by: https://github.com/Skylion007
2025-03-12 14:22:56 +00:00


#include <c10/core/DispatchKey.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/irange.h>
#include <utility>
namespace c10::impl {
thread_local static TorchDispatchModeTLS torchDispatchModeState;
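// Returns true if any torch_dispatch mode is active on this thread.
// If skip_infra_modes is true, only the user-mode stack is considered.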
bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) {
  if (!torchDispatchModeState.stack_.empty())
    return true;
  if (!skip_infra_modes) {
    for (const auto i : c10::irange(
             static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
      if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
        return true;
      }
    }
  }
  return false;
}
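// Pushes a user (non-infra) mode onto the thread-local stack. If this is the
// first active mode, the Python and PythonTLSSnapshot dispatch keys are turned
// on so that dispatch routes through the Python handlers.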
void TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
    std::shared_ptr<PyObject_TorchDispatchMode> mode) {
  if (!any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, true);
  }
  torchDispatchModeState.stack_.push_back(std::move(mode));
}
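// Pops the top of the logical mode stack: user modes are popped first, and
// only once the user stack is empty is the highest-priority infra mode
// removed. The Python dispatch keys are turned off when no modes remain.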
const std::shared_ptr<PyObject_TorchDispatchMode> TorchDispatchModeTLS::
    pop_stack() {
  std::shared_ptr<PyObject_TorchDispatchMode> out;
  if (!torchDispatchModeState.stack_.empty()) {
    out = torchDispatchModeState.stack_.back();
    torchDispatchModeState.stack_.pop_back();
  } else {
    for (int64_t i =
             static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
         i >= 0;
         --i) {
      if (torchDispatchModeState.infra_modes_[i].has_value()) {
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        out = std::move(torchDispatchModeState.infra_modes_[i].value());
        torchDispatchModeState.infra_modes_[i] = std::nullopt;
        break;
      }
    }
  }
  TORCH_CHECK(out, "trying to pop from empty mode stack");
  if (!any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, false);
  }
  return out;
}
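// Removes and returns the highest-priority infra mode together with its key,
// erroring if no infra mode is currently set.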
const std::
    tuple<std::shared_ptr<PyObject_TorchDispatchMode>, TorchDispatchModeKey>
    TorchDispatchModeTLS::pop_highest_infra_mode() {
  for (int64_t i = static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS) - 1;
       i >= 0;
       --i) {
    if (torchDispatchModeState.infra_modes_[i].has_value()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      auto out_mode = torchDispatchModeState.infra_modes_[i].value();
      torchDispatchModeState.infra_modes_[i] = std::nullopt;
      if (!any_modes_set()) {
        c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
        c10::impl::tls_set_dispatch_key_included(
            DispatchKey::PythonTLSSnapshot, false);
      }
      return std::make_tuple(
          std::move(out_mode), static_cast<TorchDispatchModeKey>(i));
    }
  }
  TORCH_CHECK(
      false, "Called pop_highest_infra_mode, but no infra modes were active.")
}
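// Indexes into the logical mode stack, where idx == 0 is the bottom: infra
// modes come first (lowest priority to highest), followed by user modes.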
const std::shared_ptr<PyObject_TorchDispatchMode>& TorchDispatchModeTLS::
    get_stack_at(int64_t idx) {
  TORCH_CHECK(idx < stack_len(), "Tried to get stack at idx that's too big");
  // Our "logical" stack includes both:
  // - any user modes (the entire torchDispatchModeState.stack_)
  // - any infra modes (members of torchDispatchModeState.infra_modes_ that are
  //   not None)
  // idx == 0 means the "bottom" of the stack, which starts with any infra
  // modes (iterating from lowest-priority to highest-priority).
  auto curr_idx = idx;
  for (const auto i :
       c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
    if (torchDispatchModeState.infra_modes_[i].has_value()) {
      if (curr_idx == 0) {
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        return torchDispatchModeState.infra_modes_[i].value();
      }
      curr_idx -= 1;
    }
  }
  // At this point, we're guaranteed that curr_idx < stack_.size()
  return torchDispatchModeState.stack_[curr_idx];
}
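// Length of the logical stack: user modes plus whichever infra modes are set.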
int64_t TorchDispatchModeTLS::stack_len() {
  auto stack_len = static_cast<int64_t>(torchDispatchModeState.stack_.size());
  int64_t infra_modes_len = 0;
  for (const auto i :
       c10::irange(static_cast<size_t>(TorchDispatchModeKey::NUM_MODE_KEYS))) {
    if (torchDispatchModeState.infra_modes_[i] != std::nullopt) {
      infra_modes_len += 1;
    }
  }
  return stack_len + infra_modes_len;
}
const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
TorchDispatchModeTLS::get_mode(TorchDispatchModeKey mode_key) {
  return torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
}
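// Installs an infra mode for the given key, which must currently be unset;
// turns on the Python dispatch keys if this is the first active mode.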
void TorchDispatchModeTLS::set_mode(
    const std::shared_ptr<PyObject_TorchDispatchMode>& mode,
    TorchDispatchModeKey mode_key) {
  TORCH_CHECK(
      torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] ==
          std::nullopt,
      "trying to set the current ",
      to_string(mode_key),
      ", but one already exists");
  if (!any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, true);
  }
  torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] = mode;
}
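// Clears and returns the infra mode for the given key (if any); turns off the
// Python dispatch keys once no modes remain.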
const std::optional<std::shared_ptr<PyObject_TorchDispatchMode>>
TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) {
  auto out = torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)];
  torchDispatchModeState.infra_modes_[static_cast<size_t>(mode_key)] =
      std::nullopt;
  if (out.has_value() && !any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, false);
  }
  return out;
}
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
  return torchDispatchModeState;
}
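// Replaces the whole thread-local mode state and syncs the Python dispatch
// keys with whether any modes are now active.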
void TorchDispatchModeTLS::set_state(TorchDispatchModeTLS state) {
  torchDispatchModeState = std::move(state);
  if (!any_modes_set()) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, false);
  } else {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, true);
  }
}
// UTIL
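// True when the Python dispatch key is not excluded and at least one mode
// (user or infra) is active on this thread.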
bool dispatch_mode_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python) &&
      TorchDispatchModeTLS::any_modes_set();
}
std::string to_string(TorchDispatchModeKey mode_key) {
  switch (mode_key) {
    case TorchDispatchModeKey::PROXY:
      return "ProxyTorchDispatchMode";
    case TorchDispatchModeKey::FAKE:
      return "FakeTensorMode";
    default:
      return "UNKNOWN_MODE";
  }
}
} // namespace c10::impl