Files
pytorch/aten/src/ATen/PythonTorchFunctionTLS.h
Scott Wolchok cad2d473bf Force inlining into torch_function_mode_enabled (#164617)
This function is relatively hot; inlining here reduces time reported by `python -m timeit --setup 'import torch; t = torch.tensor([1])' 't._cdata'` from about 125 nsec/loop to about 110 nsec/loop. (To be fair, variance is high, but I did confirm with perf that time in this path seems to have roughly halved during torchtitan training.)

Note that locally I am getting bitten by a GCC bug, which I documented in a comment. I would be interested to hear whether this makes any difference with clang.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164617
Approved by: https://github.com/ezyang
2025-10-13 19:25:51 +00:00

38 lines
1.2 KiB
C++

#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>

#include <cstdint>
#include <memory>
#include <vector>
namespace at::impl {
enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
struct TORCH_API PythonTorchFunctionTLS {
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
static TorchFunctionDisabledState get_disabled_state();
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();
static const PythonTorchFunctionTLS& get_state();
static void set_state(const PythonTorchFunctionTLS& state);
private:
// The mode TLS is split into
// - disabled_state, which says which part of torch function are disabled
// - stack_, which is a vector of modes representing the stack of user
// defined modes
TorchFunctionDisabledState disabled_state_ =
TorchFunctionDisabledState::ENABLED;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
friend TORCH_API bool torch_function_mode_enabled();
};
TORCH_API bool torch_function_mode_enabled();
TORCH_API bool torch_function_all_disabled();
} // namespace at::impl