mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This function is relatively hot; inlining here reduces time reported by `python -m timeit --setup 'import torch; t = torch.tensor([1])' 't._cdata'` from about 125 nsec/loop to about 110 nsec/loop. (To be fair, variance is high, but I did confirm with perf that time in this path seems to have roughly halved during torchtitan training.) Note that locally I am getting bitten by a GCC bug that I documented in a comment. Would be interested to hear if this does anything for clang. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164617 Approved by: https://github.com/ezyang
38 lines
1.2 KiB
C++
38 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/SafePyObject.h>
|
|
#include <c10/macros/Macros.h>
|
|
|
|
namespace at::impl {
|
|
|
|
// Selects which part of torch function handling is switched off.
// Kept as an unscoped enum so the enumerators stay visible directly in the
// enclosing namespace; the explicit values match the implicit numbering of
// the original declaration order.
// NOTE(review): the per-enumerator meanings below are read off the names;
// confirm against the code that consumes the state.
enum TorchFunctionDisabledState {
  // Presumably: nothing is disabled.
  ENABLED = 0,
  // Presumably: torch function is disabled for subclasses only.
  SUBCLASSES_DISABLED = 1,
  // Presumably: torch function is disabled entirely.
  ALL_DISABLED = 2
};
|
|
|
|
struct TORCH_API PythonTorchFunctionTLS {
|
|
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
|
|
static TorchFunctionDisabledState get_disabled_state();
|
|
|
|
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
|
|
static const std::shared_ptr<SafePyObject> pop_stack();
|
|
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
|
|
static int64_t stack_len();
|
|
|
|
static const PythonTorchFunctionTLS& get_state();
|
|
static void set_state(const PythonTorchFunctionTLS& state);
|
|
|
|
private:
|
|
// The mode TLS is split into
|
|
// - disabled_state, which says which part of torch function are disabled
|
|
// - stack_, which is a vector of modes representing the stack of user
|
|
// defined modes
|
|
TorchFunctionDisabledState disabled_state_ =
|
|
TorchFunctionDisabledState::ENABLED;
|
|
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
|
|
friend TORCH_API bool torch_function_mode_enabled();
|
|
};
|
|
|
|
// Free function that is also declared as a friend of PythonTorchFunctionTLS,
// so its out-of-line definition can read that struct's private fields
// directly.
// NOTE(review): presumably returns true when torch function has not been
// fully disabled and at least one mode is on the stack -- confirm against the
// definition, which is not part of this header.
TORCH_API bool torch_function_mode_enabled();
|
|
|
|
// Free function; unlike torch_function_mode_enabled() it is not a friend of
// PythonTorchFunctionTLS, so it can only use that struct's public interface.
// NOTE(review): presumably returns true when the disabled state is
// ALL_DISABLED -- confirm against the definition, which is not part of this
// header.
TORCH_API bool torch_function_all_disabled();
|
|
|
|
} // namespace at::impl
|