diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 4f3476133834..7f8e1532ae0b 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -545,9 +545,9 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl .template getDispatchKeySetUnboxed(args...); const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) { - return callWithDispatchKeySlowPath(op, step_callbacks, dispatchKeySet, kernel, std::forward(args)...); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) { + return callWithDispatchKeySlowPath(op, *step_callbacks, dispatchKeySet, kernel, std::forward(args)...); } #endif // PYTORCH_DISABLE_PER_OP_PROFILING return kernel.template call(op, dispatchKeySet, std::forward(args)...); @@ -568,9 +568,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); const auto& kernel = entry.lookup(dispatchKeySet); #ifndef PYTORCH_DISABLE_PER_OP_PROFILING - auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION); - if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) { - at::RecordFunction guard(std::move(step_callbacks)); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) { + at::RecordFunction guard(std::move(*step_callbacks)); auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); auto& schema = op.schema(); auto schema_ref = std::reference_wrapper(schema); diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 8d9160135cc1..dda05ffc53c2 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -130,6 +130,7 @@ class CacheEntry { // The caller is expected to check `GlobalCallbackManager::get().version()' // and call CacheEntry::update() if necessary. StepCallbacks getActiveCallbacks(); + c10::optional getActiveCallbacksUnlessEmpty(); // Full rebuild. (E.g. during registration) void update(const std::vector& callbacks); @@ -142,6 +143,8 @@ class CacheEntry { int tries_left_{-1}; }; + C10_ALWAYS_INLINE void getActiveCallbacksImpl(); + void rebuildActiveCallbacks(); int sampleTries(double p) const; @@ -169,6 +172,7 @@ class LocalCallbackManager { public: const RecordFunctionTLS& getTLS() const; StepCallbacks getActiveCallbacks(const RecordScope scope); + c10::optional getActiveCallbacksUnlessEmpty(const RecordScope scope); void setTLS(const RecordFunctionTLS& tls); void seed(uint32_t seed); @@ -178,6 +182,8 @@ class LocalCallbackManager { void clearCallbacks(); private: + void rebuildActiveCallbacksIfNeeded(); + void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot); void rebuild_callback_scopes( @@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector& callbacks) { rebuildActiveCallbacks(); } -StepCallbacks CacheEntry::getActiveCallbacks() { +void CacheEntry::getActiveCallbacksImpl() { // We rebuild the active set when `sampling_countdown_` reaches zero, so if it // reaches zero at the start of this function something has gone wrong. TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_); @@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() { } } } +} +StepCallbacks CacheEntry::getActiveCallbacks() { + getActiveCallbacksImpl(); + return active_callbacks_; +} + +c10::optional CacheEntry::getActiveCallbacksUnlessEmpty() { + getActiveCallbacksImpl(); + if (C10_LIKELY(active_callbacks_.empty())) { + return c10::nullopt; + } return active_callbacks_; } @@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const { return registered_callbacks_; } -StepCallbacks LocalCallbackManager::getActiveCallbacks( - const RecordScope scope) { +void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() { const auto global_version = GlobalCallbackManager::get().version(); if (C10_UNLIKELY(global_version != global_version_)) { rebuild_all(GlobalCallbackManager::get().getSnapshot()); } +} + +StepCallbacks LocalCallbackManager::getActiveCallbacks( + const RecordScope scope) { + rebuildActiveCallbacksIfNeeded(); return active_callbacks_[static_cast(scope)].getActiveCallbacks(); } +c10::optional LocalCallbackManager::getActiveCallbacksUnlessEmpty( + const RecordScope scope) { + rebuildActiveCallbacksIfNeeded(); + return active_callbacks_[static_cast(scope)].getActiveCallbacksUnlessEmpty(); +} + void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) { registered_callbacks_ = tls; rebuild_all(GlobalCallbackManager::get().getSnapshot()); @@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) { return LocalCallbackManager::get().getActiveCallbacks(scope); } +c10::optional getStepCallbacksUnlessEmpty(RecordScope scope) { + return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope); +} + const RecordFunctionTLS& get_record_function_tls_() { return LocalCallbackManager::get().getTLS(); } diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index af594f47e789..e1b762a4319a 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -478,6 +478,8 @@ struct TORCH_API RecordFunction { TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); +TORCH_API c10::optional getStepCallbacksUnlessEmpty(RecordScope scope); + namespace detail { template void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) { diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index 8d53007bc8ef..6a48052d3638 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -1,3 +1,4 @@ + #include #include @@ -49,9 +50,9 @@ float runPureRecordFunctionBench(int iter) { typedef std::chrono::microseconds us; std::chrono::time_point start_time = clock::now(); for (auto idx = 0; idx < iter; ++idx) { - auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE); + if (step_callbacks.has_value()) { + at::RecordFunction guard(std::move(*step_callbacks)); guard.before("Test", -1); } } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index dfeb1c973df5..9c18eced91e0 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -151,9 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this { // probably operate with names. at::NoNamesGuard no_names_guard; - auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); // Using sequence number and thread id to correlate with // the forward pass function guard.setForwardThreadId(thread_id_); diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 81e6b951c598..d7a6465c8d7b 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -845,11 +845,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { static void checkAndStartRecordFunction(Frame& frame, Stack& stack) { if (!frame.record_function) { - auto step_callbacks = - at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION); - if (!step_callbacks.empty()) { + auto step_callbacks = at::getStepCallbacksUnlessEmpty( + at::RecordScope::TORCHSCRIPT_FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value())) { auto rec_fn = - std::make_unique(std::move(step_callbacks)); + std::make_unique(std::move(*step_callbacks)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive()); if (rec_fn->needsInputs()) { rec_fn->before( diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index b58c90029aac..929d40fe4deb 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1201,9 +1201,9 @@ c10::IValue BlockRunner::run_impl_record_functions( IValueList&& args, const KeywordArgs& kwargs) { auto step_callbacks = - at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); guard.needsInputs() ? guard.before( @@ -1845,9 +1845,9 @@ std::vector ProcessedNode::inputs_ivalue_vec() const { void ProcessedNode::run() { #ifndef PYTORCH_DISABLE_PER_OP_PROFILING auto step_callbacks = - at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_OP); - if (!step_callbacks.empty()) { - at::RecordFunction guard(std::move(step_callbacks)); + at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); if (guard.needsInputs()) { const auto inputs = inputs_ivalue_vec();