mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
This PR adds a C function to check if all torch function is disabled. Recall that there are three torch function enablement states: * All disabled * Torch Function Subclass disabled * All enabled The API before this change provides two functions: * `_is_torch_function_enabled` - returns True iff the current TF state is All enabled * `_is_torch_function_mode_enabled` - returns True iff the state is not All disabled and the torch function mode stack is non-empty. The crux of why a new API is needed is the following: If dynamo enters a frame with the torch function mode stack empty, `_is_torch_function_enabled` == False, it is impossible to determine if after a new mode is pushed whether we should enter the mode or not. This is because we don't know if the enablement state is All disabled or only subclass disabled. Adding this API to check if All disabled is True allows us to disambiguate this case. In the next PR, Dynamo InstructionTranslator will have clearer flags than the underlying C API: * A flag to indicate if subclasses are disabled (ie All disabled or Subclass Disabled is the current state) * A flag to indicate if modes are disabled (ie if All disabled is the current state) * A symbolic stack which can be checked if any modes are present Pull Request resolved: https://github.com/pytorch/pytorch/pull/133136 Approved by: https://github.com/bdhirsh ghstack dependencies: #133130, #133729, #133131, #133132, #133133, #133134
55 lines
1.9 KiB
C++
#include <ATen/PythonTorchFunctionTLS.h>
|
|
#include <c10/core/TensorImpl.h>
|
|
|
|
namespace at::impl {
|
|
|
|
// Per-thread torch-function TLS: holds the disabled-state flag and the stack
// of active torch function modes. Every static method below reads or writes
// this one thread_local instance.
static thread_local PythonTorchFunctionTLS pythonTorchFunctionState;
|
|
|
|
// Push a torch function mode onto this thread's mode stack, taking shared
// ownership of it.
void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
  auto& modes = pythonTorchFunctionState.stack_;
  modes.emplace_back(std::move(mode));
}
// Remove and return the top-most torch function mode on this thread's stack.
// Fails via TORCH_CHECK when the stack is empty.
const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
  auto& modes = pythonTorchFunctionState.stack_;
  TORCH_CHECK(!modes.empty(), "trying to pop from empty mode stack");
  // Move rather than copy out of back(): a shared_ptr copy costs an atomic
  // refcount round-trip, and the slot is popped on the next line anyway.
  auto out = std::move(modes.back());
  modes.pop_back();
  return out;
}
// Return (by reference) the mode at position idx on this thread's mode stack.
// Fails via TORCH_CHECK when idx is out of range.
const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
  // Also reject negative indices: the original check only caught
  // idx >= size, leaving idx < 0 as undefined behavior in operator[].
  TORCH_CHECK(idx >= 0 && idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big");
  return pythonTorchFunctionState.stack_[idx];
}
int64_t PythonTorchFunctionTLS::stack_len() {
|
|
return static_cast<int64_t>(pythonTorchFunctionState.stack_.size());
|
|
}
|
|
|
|
// Overwrite this thread's torch-function disabled state (one of the ternary
// enablement states; see torch_function_all_disabled below).
void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) {
  pythonTorchFunctionState.disabled_state_ = disabled_state;
}
// Read back this thread's current torch-function disabled state.
TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() {
  return pythonTorchFunctionState.disabled_state_;
}
// Replace this thread's entire torch-function TLS (disabled state and mode
// stack) with a copy of the given snapshot.
void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) {
  pythonTorchFunctionState = state;
}
// Expose a read-only view of this thread's torch-function TLS, suitable for
// snapshotting and later restoring via set_state.
const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
  return pythonTorchFunctionState;
}
bool torch_function_mode_enabled() {
|
|
return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED &&
|
|
PythonTorchFunctionTLS::stack_len() > 0;
|
|
}
|
|
|
|
// This is needed to disambiguate the ternary torch function disabled states
|
|
bool torch_function_all_disabled() {
|
|
return PythonTorchFunctionTLS::get_disabled_state() == TorchFunctionDisabledState::ALL_DISABLED;
|
|
}
|
|
|
|
} // namespace at::impl
|