From c083489f46e3885fbfc8888fd308e238bfcabfd6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 19 May 2022 15:26:22 -0700 Subject: [PATCH] [kineto] Optimize getStepCallbacks for common case of no active callbacks Pull Request resolved: https://github.com/pytorch/pytorch/pull/77804 IIUC, the result of this function will be empty and unused if there are no sampled callbacks, which is the common case. We can accelerate this case by wrapping the result in an optional to save initializing an empty SmallVector. Differential Revision: [D36497279](https://our.internmc.facebook.com/intern/diff/D36497279/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36497279/)! Approved by: https://github.com/robieta --- aten/src/ATen/core/dispatch/Dispatcher.h | 12 ++++---- aten/src/ATen/record_function.cpp | 37 ++++++++++++++++++++++-- aten/src/ATen/record_function.h | 2 ++ binaries/record_function_benchmark.cc | 7 +++-- torch/csrc/autograd/function.h | 6 ++-- torch/csrc/jit/runtime/interpreter.cpp | 8 ++--- torch/csrc/jit/runtime/static/impl.cpp | 12 ++++---- 7 files changed, 59 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 4f3476133834..7f8e1532ae0b 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -545,9 +545,9 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl .template getDispatchKeySetUnboxed(args...); const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) { - return callWithDispatchKeySlowPath(op, step_callbacks, dispatchKeySet, kernel, std::forward(args)...); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) { + return callWithDispatchKeySlowPath(op, *step_callbacks, dispatchKeySet, kernel, std::forward(args)...); } #endif // PYTORCH_DISABLE_PER_OP_PROFILING return kernel.template call(op, dispatchKeySet, std::forward(args)...); @@ -568,9 +568,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); const auto& kernel = entry.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) { - at::RecordFunction guard(std::move(step_callbacks)); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) { + at::RecordFunction guard(std::move(*step_callbacks)); auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); auto& schema = op.schema(); auto schema_ref = std::reference_wrapper(schema); diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 8d9160135cc1..dda05ffc53c2 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -130,6 +130,7 @@ class CacheEntry { // The caller is expected to check `GlobalCallbackManager::get().version()' // and call CacheEntry::update() if necessary. StepCallbacks getActiveCallbacks(); + c10::optional getActiveCallbacksUnlessEmpty(); // Full rebuild. (E.g. during registration) void update(const std::vector& callbacks); @@ -142,6 +143,8 @@ class CacheEntry { int tries_left_{-1}; }; + C10_ALWAYS_INLINE void getActiveCallbacksImpl(); + void rebuildActiveCallbacks(); int sampleTries(double p) const; @@ -169,6 +172,7 @@ class LocalCallbackManager { public: const RecordFunctionTLS& getTLS() const; StepCallbacks getActiveCallbacks(const RecordScope scope); + c10::optional getActiveCallbacksUnlessEmpty(const RecordScope scope); void setTLS(const RecordFunctionTLS& tls); void seed(uint32_t seed); @@ -178,6 +182,8 @@ class LocalCallbackManager { void clearCallbacks(); private: + void rebuildActiveCallbacksIfNeeded(); + void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot); void rebuild_callback_scopes( @@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector& callbacks) { rebuildActiveCallbacks(); } -StepCallbacks CacheEntry::getActiveCallbacks() { +void CacheEntry::getActiveCallbacksImpl() { // We rebuild the active set when `sampling_countdown_` reaches zero, so if it // reaches zero at the start of this function something has gone wrong. TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_); @@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() { } } } +} +StepCallbacks CacheEntry::getActiveCallbacks() { + getActiveCallbacksImpl(); + return active_callbacks_; +} + +c10::optional CacheEntry::getActiveCallbacksUnlessEmpty() { + getActiveCallbacksImpl(); + if (C10_LIKELY(active_callbacks_.empty())) { + return c10::nullopt; + } return active_callbacks_; } @@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const { return registered_callbacks_; } -StepCallbacks LocalCallbackManager::getActiveCallbacks( - const RecordScope scope) { +void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() { const auto global_version = GlobalCallbackManager::get().version(); if (C10_UNLIKELY(global_version != global_version_)) { rebuild_all(GlobalCallbackManager::get().getSnapshot()); } +} + +StepCallbacks LocalCallbackManager::getActiveCallbacks( + const RecordScope scope) { + rebuildActiveCallbacksIfNeeded(); return active_callbacks_[static_cast(scope)].getActiveCallbacks(); } +c10::optional LocalCallbackManager::getActiveCallbacksUnlessEmpty( + const RecordScope scope) { + rebuildActiveCallbacksIfNeeded(); + return active_callbacks_[static_cast(scope)].getActiveCallbacksUnlessEmpty(); +} + void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) { registered_callbacks_ = tls; rebuild_all(GlobalCallbackManager::get().getSnapshot()); @@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) { return LocalCallbackManager::get().getActiveCallbacks(scope); } +c10::optional getStepCallbacksUnlessEmpty(RecordScope scope) { + return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope); +} + const RecordFunctionTLS& get_record_function_tls_() { return LocalCallbackManager::get().getTLS(); } diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index af594f47e789..e1b762a4319a 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -478,6 +478,8 @@ struct TORCH_API RecordFunction { TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); +TORCH_API c10::optional getStepCallbacksUnlessEmpty(RecordScope scope); + namespace detail { template void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) { diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index 8d53007bc8ef..6a48052d3638 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -1,3 +1,4 @@ + #include #include @@ -49,9 +50,9 @@ float runPureRecordFunctionBench(int iter) { typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); for (auto idx = 0; idx < iter; ++idx) { - auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE); + if (step_callbacks.has_value()) { + at::RecordFunction guard(std::move(*step_callbacks)); guard.before("Test", -1); } } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index dfeb1c973df5..9c18eced91e0 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -151,9 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this { // probably operate with names. at::NoNamesGuard no_names_guard; - auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); // Using sequence number and thread id to correlate with // the forward pass function guard.setForwardThreadId(thread_id_); diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 81e6b951c598..d7a6465c8d7b 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -845,11 +845,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { static void checkAndStartRecordFunction(Frame& frame, Stack& stack) { if (!frame.record_function) { - auto step_callbacks = - at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION); - if (!step_callbacks.empty()) { + auto step_callbacks = at::getStepCallbacksUnlessEmpty( + at::RecordScope::TORCHSCRIPT_FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value())) { auto rec_fn = - std::make_unique(std::move(step_callbacks)); + std::make_unique(std::move(*step_callbacks)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive()); if (rec_fn->needsInputs()) { rec_fn->before( diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index b58c90029aac..929d40fe4deb 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1201,9 +1201,9 @@ c10::IValue BlockRunner::run_impl_record_functions( IValueList&& args, const KeywordArgs& kwargs) { auto step_callbacks = - at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); guard.needsInputs() ? guard.before( @@ -1845,9 +1845,9 @@ std::vector ProcessedNode::inputs_ivalue_vec() const { void ProcessedNode::run() { #ifndef PYTORCH_DISABLE_PER_OP_PROFILING auto step_callbacks = - at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_OP); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); if (guard.needsInputs()) { const auto inputs = inputs_ivalue_vec();