mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[kineto] Optimize getStepCallbacks for common case of no active callbacks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77804 IIUC, the result of this function will be empty and unused if there are no sampled callbacks, which is the common case. We can accelerate this case by wrapping the result in an optional to save initializing an empty SmallVector. Differential Revision: [D36497279](https://our.internmc.facebook.com/intern/diff/D36497279/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36497279/)! Approved by: https://github.com/robieta
This commit is contained in:
committed by
PyTorch MergeBot
parent
02c4d877b4
commit
c083489f46
@ -545,9 +545,9 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl
|
||||
.template getDispatchKeySetUnboxed<Args...>(args...);
|
||||
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
|
||||
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
|
||||
if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) {
|
||||
return callWithDispatchKeySlowPath<Return, Args...>(op, step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
|
||||
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
|
||||
if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
|
||||
return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
|
||||
}
|
||||
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
|
||||
@ -568,9 +568,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
|
||||
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
|
||||
const auto& kernel = entry.lookup(dispatchKeySet);
|
||||
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
|
||||
if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
|
||||
if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
|
||||
at::RecordFunction guard(std::move(*step_callbacks));
|
||||
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
|
||||
auto& schema = op.schema();
|
||||
auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
|
||||
|
@ -130,6 +130,7 @@ class CacheEntry {
|
||||
// The caller is expected to check `GlobalCallbackManager::get().version()'
|
||||
// and call CacheEntry::update() if necessary.
|
||||
StepCallbacks getActiveCallbacks();
|
||||
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();
|
||||
|
||||
// Full rebuild. (E.g. during registration)
|
||||
void update(const std::vector<RecordFunctionCallback>& callbacks);
|
||||
@ -142,6 +143,8 @@ class CacheEntry {
|
||||
int tries_left_{-1};
|
||||
};
|
||||
|
||||
C10_ALWAYS_INLINE void getActiveCallbacksImpl();
|
||||
|
||||
void rebuildActiveCallbacks();
|
||||
int sampleTries(double p) const;
|
||||
|
||||
@ -169,6 +172,7 @@ class LocalCallbackManager {
|
||||
public:
|
||||
const RecordFunctionTLS& getTLS() const;
|
||||
StepCallbacks getActiveCallbacks(const RecordScope scope);
|
||||
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);
|
||||
|
||||
void setTLS(const RecordFunctionTLS& tls);
|
||||
void seed(uint32_t seed);
|
||||
@ -178,6 +182,8 @@ class LocalCallbackManager {
|
||||
void clearCallbacks();
|
||||
|
||||
private:
|
||||
void rebuildActiveCallbacksIfNeeded();
|
||||
|
||||
void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);
|
||||
|
||||
void rebuild_callback_scopes(
|
||||
@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
|
||||
rebuildActiveCallbacks();
|
||||
}
|
||||
|
||||
StepCallbacks CacheEntry::getActiveCallbacks() {
|
||||
void CacheEntry::getActiveCallbacksImpl() {
|
||||
// We rebuild the active set when `sampling_countdown_` reaches zero, so if it
|
||||
// reaches zero at the start of this function something has gone wrong.
|
||||
TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);
|
||||
@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StepCallbacks CacheEntry::getActiveCallbacks() {
|
||||
getActiveCallbacksImpl();
|
||||
return active_callbacks_;
|
||||
}
|
||||
|
||||
c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
|
||||
getActiveCallbacksImpl();
|
||||
if (C10_LIKELY(active_callbacks_.empty())) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
return active_callbacks_;
|
||||
}
|
||||
|
||||
@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
|
||||
return registered_callbacks_;
|
||||
}
|
||||
|
||||
StepCallbacks LocalCallbackManager::getActiveCallbacks(
|
||||
const RecordScope scope) {
|
||||
void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
|
||||
const auto global_version = GlobalCallbackManager::get().version();
|
||||
if (C10_UNLIKELY(global_version != global_version_)) {
|
||||
rebuild_all(GlobalCallbackManager::get().getSnapshot());
|
||||
}
|
||||
}
|
||||
|
||||
StepCallbacks LocalCallbackManager::getActiveCallbacks(
|
||||
const RecordScope scope) {
|
||||
rebuildActiveCallbacksIfNeeded();
|
||||
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
|
||||
}
|
||||
|
||||
c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
|
||||
const RecordScope scope) {
|
||||
rebuildActiveCallbacksIfNeeded();
|
||||
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
|
||||
}
|
||||
|
||||
void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
|
||||
registered_callbacks_ = tls;
|
||||
rebuild_all(GlobalCallbackManager::get().getSnapshot());
|
||||
@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) {
|
||||
return LocalCallbackManager::get().getActiveCallbacks(scope);
|
||||
}
|
||||
|
||||
c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
|
||||
return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
|
||||
}
|
||||
|
||||
const RecordFunctionTLS& get_record_function_tls_() {
|
||||
return LocalCallbackManager::get().getTLS();
|
||||
}
|
||||
|
@ -478,6 +478,8 @@ struct TORCH_API RecordFunction {
|
||||
|
||||
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
|
||||
|
||||
TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope);
|
||||
|
||||
namespace detail {
|
||||
template <typename Inputs, typename F, typename... Args>
|
||||
void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) {
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
@ -49,9 +50,9 @@ float runPureRecordFunctionBench(int iter) {
|
||||
typedef std::chrono::microseconds us;
|
||||
std::chrono::time_point<clock> start_time = clock::now();
|
||||
for (auto idx = 0; idx < iter; ++idx) {
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
|
||||
if (step_callbacks.has_value()) {
|
||||
at::RecordFunction guard(std::move(*step_callbacks));
|
||||
guard.before("Test", -1);
|
||||
}
|
||||
}
|
||||
|
@ -151,9 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
// probably operate with names.
|
||||
at::NoNamesGuard no_names_guard;
|
||||
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
|
||||
if (C10_UNLIKELY(step_callbacks.has_value())) {
|
||||
at::RecordFunction guard(std::move(*step_callbacks));
|
||||
// Using sequence number and thread id to correlate with
|
||||
// the forward pass function
|
||||
guard.setForwardThreadId(thread_id_);
|
||||
|
@ -845,11 +845,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
|
||||
static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
|
||||
if (!frame.record_function) {
|
||||
auto step_callbacks =
|
||||
at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION);
|
||||
if (!step_callbacks.empty()) {
|
||||
auto step_callbacks = at::getStepCallbacksUnlessEmpty(
|
||||
at::RecordScope::TORCHSCRIPT_FUNCTION);
|
||||
if (C10_UNLIKELY(step_callbacks.has_value())) {
|
||||
auto rec_fn =
|
||||
std::make_unique<at::RecordFunction>(std::move(step_callbacks));
|
||||
std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
|
||||
if (rec_fn->needsInputs()) {
|
||||
rec_fn->before(
|
||||
|
@ -1201,9 +1201,9 @@ c10::IValue BlockRunner::run_impl_record_functions(
|
||||
IValueList&& args,
|
||||
const KeywordArgs& kwargs) {
|
||||
auto step_callbacks =
|
||||
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
|
||||
if (C10_UNLIKELY(step_callbacks.has_value())) {
|
||||
at::RecordFunction guard(std::move(*step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
guard.needsInputs()
|
||||
? guard.before(
|
||||
@ -1845,9 +1845,9 @@ std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
|
||||
void ProcessedNode::run() {
|
||||
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
auto step_callbacks =
|
||||
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_OP);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
|
||||
if (C10_UNLIKELY(step_callbacks.has_value())) {
|
||||
at::RecordFunction guard(std::move(*step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
if (guard.needsInputs()) {
|
||||
const auto inputs = inputs_ivalue_vec();
|
||||
|
Reference in New Issue
Block a user