mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/140679 Approved by: https://github.com/ezyang
69 lines
2.2 KiB
C++
69 lines
2.2 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
|
|
|
namespace torch::torch_dispatch_mode {
|
|
|
|
struct StashTorchDispatchModeGuard {
|
|
public:
|
|
StashTorchDispatchModeGuard() {
|
|
if (c10::impl::TorchDispatchModeTLS::any_modes_set(
|
|
/*skip_infra_modes=*/true)) {
|
|
saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack();
|
|
} else {
|
|
auto mode_and_key =
|
|
c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode();
|
|
saved_mode_ = std::move(std::get<0>(mode_and_key));
|
|
saved_mode_key_ = std::get<1>(mode_and_key);
|
|
}
|
|
}
|
|
|
|
~StashTorchDispatchModeGuard() {
|
|
if (saved_mode_key_.has_value()) {
|
|
c10::impl::TorchDispatchModeTLS::set_mode(
|
|
saved_mode_, saved_mode_key_.value());
|
|
} else {
|
|
c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
|
|
std::move(saved_mode_));
|
|
}
|
|
}
|
|
StashTorchDispatchModeGuard(const StashTorchDispatchModeGuard&) = delete;
|
|
StashTorchDispatchModeGuard(StashTorchDispatchModeGuard&&) = delete;
|
|
StashTorchDispatchModeGuard& operator=(const StashTorchDispatchModeGuard&) =
|
|
delete;
|
|
StashTorchDispatchModeGuard& operator=(StashTorchDispatchModeGuard&&) =
|
|
delete;
|
|
|
|
const std::shared_ptr<c10::impl::PyObject_TorchDispatchMode>& get_cur_mode() {
|
|
return saved_mode_;
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<c10::impl::PyObject_TorchDispatchMode> saved_mode_;
|
|
std::optional<c10::impl::TorchDispatchModeKey> saved_mode_key_;
|
|
};
|
|
|
|
struct StashTorchDispatchStackGuard {
|
|
public:
|
|
StashTorchDispatchStackGuard() {
|
|
auto old = c10::impl::TorchDispatchModeTLS::get_state();
|
|
c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
|
|
saved_state_ = std::move(old);
|
|
}
|
|
StashTorchDispatchStackGuard(const StashTorchDispatchStackGuard&) = delete;
|
|
StashTorchDispatchStackGuard(StashTorchDispatchStackGuard&&) = delete;
|
|
StashTorchDispatchStackGuard& operator=(const StashTorchDispatchStackGuard&) =
|
|
delete;
|
|
StashTorchDispatchStackGuard& operator=(StashTorchDispatchStackGuard&&) =
|
|
delete;
|
|
|
|
~StashTorchDispatchStackGuard() {
|
|
c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
|
|
}
|
|
|
|
private:
|
|
c10::impl::TorchDispatchModeTLS saved_state_;
|
|
};
|
|
|
|
} // namespace torch::torch_dispatch_mode
|