mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[RecordFunction] More effecient machinery to determine which callbacks to run. (#75807)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75807 There is a tension in RecordFunction between two use cases: 1) In the normal eager path we don't run any callbacks, so we need to bail out of the profiling path as soon as possible to minimize eager overhead. 2) When profiling we want to determine which callbacks to run as efficiently as possible to minimize instrumentation overhead. The confounding factor in all of this is sampling callbacks because they change which callbacks will run on each call, even in steady state operation. This has traditionally been handled with a two stage procedure: first we flip a coin to determine if a sampled callback *might* run. If false (which it usually is), do nothing. This solves (1). If true, check to see if we need to build the full callback set or if it was a false positive. This procedure has two negative effects: * It forces us to rebuild the set of callbacks to run on every step when profiling * It leaks the sampling abstraction, requiring other parts of the code to bump certain values and forces RecordFunction to lazily initialize. This change introduces a multi-level cache which can (in the common case) quickly determine which callbacks *will* run, rather than if callbacks *might* run. This means that rather than call `shouldRunRecordFunction`, we can simply get the callbacks for an invocation and check if they are empty. (And completely removes the pre-sampling heuristic.) Another major benefit of the new cache structure is that it allows thread-safe registration and unregistration of global callbacks. It's worth briefly discussing how this maintains eager performance. In the standard eager case (only sampling callbacks registered) the cache first checks that the global callbacks haven't changed (atomic read), decrements a counter to see if a sampling callback fired, and then returns the active callbacks which is simply a SmallVector of pointer pairs and a couple POD values (scope, needs inputs/outputs/ids). The biggest cost according to perf is the SmallVector logic; we could consider adopting a hard limit on active callbacks; more than half a dozen callbacks *running* in a single step would be quite a lot. But the total cost relative to `PYTORCH_DISABLE_PER_OP_PROFILING` is only ~10ns, so debatable if it's worth it to switch to `std::array`. The primary change is in `record_function.cpp`, which has a more detailed description of the new cache structure. `record_function.h` has some minor changes to align with the new calling convention and the remaining files are simply changes to the call sites. Future work: * RecordFunction no longer needs to be lazily initialized. * We can deprecate the disable/reenable APIs, since we can not safely add and remove global callbacks. Test Plan: I tested eager mode performance using the overhead benchmark and found that the non-profiled path was unaffected. However the no-op observer dropped from 0.41us to 0.37us (0.25us if no observers are active) which is about 1/3rd reduction in the cost of the callback selection machinery. I also added several C++ unit tests, as the core RecordFunction machinery (especially sampling) was largely untested. Reviewed By: swolchok, davidberard98 Differential Revision: D35276158 fbshipit-source-id: 35135f444724fba4eb97c0ae7f3f710f0f9016fd (cherry picked from commit 9e359b87422c18f2a195185f32e7e85c82f956fd)
This commit is contained in:
committed by
PyTorch MergeBot
parent
6ac2ce9abc
commit
a5e338a826
@ -19,7 +19,6 @@ ThreadLocalState::ThreadLocalState()
|
||||
|
||||
saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();
|
||||
|
||||
bumped_record_all_functions_ = at::checkRecordAllFunctions();
|
||||
python_mode_state_ = at::impl::PythonModeTLS::get_state();
|
||||
}
|
||||
|
||||
|
@ -63,9 +63,6 @@ class TORCH_API ThreadLocalState {
|
||||
// TLS for saved tensors default hooks
|
||||
std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;
|
||||
|
||||
// Whether pre-sampling RecordFunction optimization was enabled
|
||||
bool bumped_record_all_functions_ = false;
|
||||
|
||||
friend class ThreadLocalStateGuard;
|
||||
};
|
||||
|
||||
@ -73,21 +70,7 @@ class TORCH_API ThreadLocalState {
|
||||
class TORCH_API ThreadLocalStateGuard {
|
||||
public:
|
||||
explicit ThreadLocalStateGuard(const ThreadLocalState& state)
|
||||
: prev_state_(ThreadLocalState()),
|
||||
bumped_record_all_functions_(state.bumped_record_all_functions_) {
|
||||
// Special handling of RecordFunction pre-sampling optimization:
|
||||
// pre-samping is enabled (bumped) when there're non-sampled
|
||||
// (or high-frequency) global or TLS callbacks.
|
||||
//
|
||||
// ThreadLocalStateGuard simply resets RecordFunction's TLS and
|
||||
// hence its thread local callbacks.
|
||||
//
|
||||
// Checking if the pre-sampling was enabled and preserving it in the
|
||||
// async task by calling bumpRecordAllFunctions() and the corresponding
|
||||
// releaseRecordAllFunctions()
|
||||
if (bumped_record_all_functions_) {
|
||||
at::bumpRecordAllFunctions();
|
||||
}
|
||||
: prev_state_(ThreadLocalState()) {
|
||||
// set the given state across the thread boundary
|
||||
ThreadLocalState::setThreadLocalState(state);
|
||||
}
|
||||
@ -95,15 +78,10 @@ class TORCH_API ThreadLocalStateGuard {
|
||||
~ThreadLocalStateGuard() {
|
||||
// restore previously set variables
|
||||
ThreadLocalState::setThreadLocalState(prev_state_);
|
||||
if (bumped_record_all_functions_) {
|
||||
at::releaseRecordAllFunctions();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const ThreadLocalState prev_state_;
|
||||
// Whether pre-sampling RecordFunction optimization was enabled
|
||||
bool bumped_record_all_functions_ = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -152,7 +152,7 @@ public:
|
||||
|
||||
|
||||
template<class Return, class... Args>
|
||||
static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
|
||||
static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
|
||||
|
||||
// Like call, but intended for use in a redispatch in kernels that have explicitly performed the DispatchKey update calculatulation.
|
||||
// This will take the DispatchKeySet completely as is and dispatch to the kernel of the corresponding highest priority key in the set.
|
||||
@ -494,22 +494,17 @@ struct CaptureKernelCall<void> {
|
||||
|
||||
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
|
||||
template<class Return, class... Args>
|
||||
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
|
||||
// Check if we need to run callbacks registered with RecordFunction
|
||||
// If true and callbacks need inputs, we box the arguments and pass
|
||||
// them into the callbacks and also into the kernel call
|
||||
|
||||
// Note: for perf reasons we wouldn't want to pass arguments into
|
||||
// the function call or prematurely box them
|
||||
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
|
||||
if (C10_UNLIKELY(guard.isActive())) {
|
||||
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
|
||||
// If callbacks need inputs, we box the arguments and pass them to the guard.
|
||||
// Note: For perf reasons we wouldn't want to prematurely box the arguments.
|
||||
at::RecordFunction guard(std::move(stepCallbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
|
||||
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
|
||||
if (op.operatorDef_->op.isObserved()) {
|
||||
if (guard.needsInputs()) {
|
||||
runRecordFunction(guard, op, dispatchKey, impl::boxArgs(args...));
|
||||
} else {
|
||||
runRecordFunction(guard, op, dispatchKey);
|
||||
}
|
||||
guard.needsInputs()
|
||||
? runRecordFunction(guard, op, dispatchKey, impl::boxArgs(args...))
|
||||
: runRecordFunction(guard, op, dispatchKey);
|
||||
|
||||
if (C10_UNLIKELY(guard.needsOutputs())) {
|
||||
// Calls the kernel and capture the output temporarily to pass to
|
||||
// RecordFunction.
|
||||
@ -519,8 +514,7 @@ inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<
|
||||
// Releases the captured output to return to caller.
|
||||
return std::move(captureKernelCall).release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keeping the guard alive while executing the kernel
|
||||
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
|
||||
}
|
||||
@ -533,15 +527,9 @@ C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorH
|
||||
.template getDispatchKeySetUnboxed<Args...>(args...);
|
||||
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
|
||||
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
// By default, when there're no high-frequency or non-sampled callbacks,
|
||||
// RecordFunction is pre-sampled as a perf optimization;
|
||||
// shouldRunRecordFunction checks whether RecordFunction should be executed,
|
||||
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
|
||||
// this boolean is passed into RecordFunction to adjust the sampling rates of
|
||||
// the callbacks
|
||||
bool pre_sampled = false;
|
||||
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
|
||||
return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
|
||||
if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) {
|
||||
return callWithDispatchKeySlowPath<Return, Args...>(op, step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
|
||||
}
|
||||
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
|
||||
@ -562,25 +550,18 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
|
||||
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
|
||||
const auto& kernel = entry.lookup(dispatchKeySet);
|
||||
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
bool pre_sampled = false;
|
||||
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
|
||||
// using already existing stack to record function execution in observers
|
||||
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
|
||||
if (C10_UNLIKELY(guard.isActive())) {
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
|
||||
if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
|
||||
if (entry.isObserved()) {
|
||||
if (guard.needsInputs()) {
|
||||
runRecordFunction(guard, op, dispatchKey, *stack);
|
||||
} else {
|
||||
runRecordFunction(guard, op, dispatchKey);
|
||||
}
|
||||
}
|
||||
}
|
||||
guard.needsInputs() ? runRecordFunction(guard, op, dispatchKey, *stack)
|
||||
: runRecordFunction(guard, op, dispatchKey);
|
||||
|
||||
// keeping the guard alive while executing the kernel
|
||||
kernel.callBoxed(op, dispatchKeySet, stack);
|
||||
// track outputs
|
||||
if (C10_UNLIKELY(
|
||||
guard.isActive() && entry.isObserved() && guard.needsOutputs())) {
|
||||
|
||||
if (C10_UNLIKELY(guard.needsOutputs())) {
|
||||
guard.setOutputs(*stack);
|
||||
}
|
||||
return;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -201,7 +201,6 @@ class TORCH_API RecordFunctionCallback {
|
||||
}
|
||||
|
||||
private:
|
||||
friend class CallbackManager;
|
||||
StartCallback start_;
|
||||
EndCallback end_;
|
||||
double sampling_prob_ = 1.0;
|
||||
@ -231,51 +230,56 @@ class TORCH_API RecordFunctionCallback {
|
||||
// execution tracer
|
||||
// - note, thread local callbacks are automatically propagated with
|
||||
// ThreadLocalState across JIT continuations and async tasks (at::launch)
|
||||
// - adding/removing global callbacks is not thread safe and should be done
|
||||
// only when no other code is running, e.g. during the initialization
|
||||
|
||||
typedef uint64_t CallbackHandle;
|
||||
|
||||
// It is unnecessary to use atomic operations for enabling
|
||||
// thread-local function callbacks. Moreover, it prevents saving to
|
||||
// ThreadLocalState because std::atomic is non-copyable.
|
||||
struct ThreadLocalRecordFunctionCallbacksEntry {
|
||||
RecordFunctionCallback callback;
|
||||
bool enabled = true;
|
||||
CallbackHandle handle;
|
||||
struct RecordFunctionCallbacksEntry {
|
||||
RecordFunctionCallbacksEntry(RecordFunctionCallback&& cb, CallbackHandle h)
|
||||
: callback_(cb), handle_(h) {}
|
||||
|
||||
ThreadLocalRecordFunctionCallbacksEntry(RecordFunctionCallback&& cb, CallbackHandle h)
|
||||
: callback(std::move(cb)), handle(h) {}
|
||||
|
||||
bool disable() {
|
||||
auto old = enabled;
|
||||
enabled = false;
|
||||
return old != enabled;
|
||||
}
|
||||
|
||||
bool enable() {
|
||||
auto old = enabled;
|
||||
enabled = true;
|
||||
return old != enabled;
|
||||
}
|
||||
|
||||
bool isEnabled() const {
|
||||
return enabled;
|
||||
}
|
||||
RecordFunctionCallback callback_;
|
||||
bool enabled_{true};
|
||||
CallbackHandle handle_;
|
||||
};
|
||||
|
||||
// Holds pairs (callbacks, unique_id)
|
||||
using ThreadLocalRecordFunctionCallbacks =
|
||||
std::vector<ThreadLocalRecordFunctionCallbacksEntry>;
|
||||
using RecordFunctionCallbacks = std::vector<RecordFunctionCallbacksEntry>;
|
||||
|
||||
// Generated by the callback managers to determine which functions to run.
|
||||
struct StepCallbacks {
|
||||
StepCallbacks() = default;
|
||||
StepCallbacks(uint64_t thread_id, RecordScope scope)
|
||||
: thread_id_{thread_id}, scope_{scope} {}
|
||||
|
||||
bool empty() const {
|
||||
return callbacks_.empty();
|
||||
}
|
||||
|
||||
struct StartEndPair {
|
||||
RecordFunctionCallback::StartCallback start_;
|
||||
RecordFunctionCallback::EndCallback end_;
|
||||
};
|
||||
|
||||
using StartEndPairs = c10::SmallVector<StartEndPair, kSoftLimitCallbacks>;
|
||||
|
||||
StartEndPairs callbacks_;
|
||||
uint64_t thread_id_{0};
|
||||
RecordScope scope_{RecordScope::FUNCTION};
|
||||
bool needs_inputs_{false};
|
||||
bool needs_outputs_{false};
|
||||
bool needs_ids_{false};
|
||||
};
|
||||
|
||||
struct TORCH_API RecordFunction {
|
||||
// Default constructor is used with before function called afterwards:
|
||||
// scope - record scope that this function tracks
|
||||
// pre_sampled - whether this RecordFunction was already pre-sampled with
|
||||
// kLowProb probability
|
||||
RecordFunction(
|
||||
RecordScope scope = RecordScope::FUNCTION,
|
||||
bool pre_sampled = false);
|
||||
explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION);
|
||||
explicit RecordFunction(StepCallbacks&& step_callbacks);
|
||||
|
||||
template <typename F>
|
||||
void before(
|
||||
@ -340,7 +344,7 @@ struct TORCH_API RecordFunction {
|
||||
// executed in a different thread (async ops)
|
||||
uint64_t threadId() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called threadId() on inactive RecordFunction");
|
||||
return state_->thread_id_;
|
||||
return state_->step_callbacks_.thread_id_;
|
||||
}
|
||||
|
||||
// For backward functions - thread id of the corresponding forward function,
|
||||
@ -359,7 +363,7 @@ struct TORCH_API RecordFunction {
|
||||
|
||||
RecordScope scope() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called scope() on inactive RecordFunction");
|
||||
return state_->scope_;
|
||||
return state_->step_callbacks_.scope_;
|
||||
}
|
||||
|
||||
// Returns logical thread_id for the current thread
|
||||
@ -434,12 +438,12 @@ struct TORCH_API RecordFunction {
|
||||
|
||||
bool needsInputs() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsInputs() on inactive RecordFunction");
|
||||
return state_->needs_inputs;
|
||||
return state_->step_callbacks_.needs_inputs_;
|
||||
}
|
||||
|
||||
bool needsOutputs() const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(state_, "Called needsOutputs() on inactive RecordFunction");
|
||||
return state_->needs_outputs;
|
||||
return state_->step_callbacks_.needs_outputs_;
|
||||
}
|
||||
|
||||
int64_t debugHandle() const {
|
||||
@ -453,39 +457,21 @@ struct TORCH_API RecordFunction {
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
// Allows the modification of some internal states for callbacks.
|
||||
friend class CallbackManager;
|
||||
void runStartCallbacks();
|
||||
|
||||
struct State {
|
||||
explicit State(RecordScope scope) : scope_(scope) {}
|
||||
explicit State(StepCallbacks&& step_callbacks)
|
||||
: step_callbacks_{std::move(step_callbacks)} {}
|
||||
|
||||
// Whether any of the picked callbacks require inputs
|
||||
bool needs_inputs = false;
|
||||
|
||||
// Whether any of the picked callbacks require outputs
|
||||
bool needs_outputs = false;
|
||||
StepCallbacks step_callbacks_;
|
||||
|
||||
// In cases when RecordFunction might be active but we chose not to
|
||||
// use the observers (e.g. operator is not observed), this boolean
|
||||
// flag is used to check whether the start callbacks were called
|
||||
bool called_start_callbacks_ = false;
|
||||
|
||||
// Whether the RecordFunction is pre-sampled
|
||||
bool pre_sampled_ = false;
|
||||
|
||||
// Used internally to keep track of thread local and global callbacks
|
||||
// that were picked to run; must be sorted;
|
||||
CallbackHandles sorted_active_tls_handles_;
|
||||
CallbackHandles sorted_active_global_handles_;
|
||||
|
||||
// Stores various ObserverContext objects with event metadata for thread local
|
||||
// callbacks.
|
||||
ObserverContextList tls_ctx_;
|
||||
|
||||
// Stores various ObserverContext objects with event metadata for global
|
||||
// callbacks.
|
||||
ObserverContextList global_ctx_;
|
||||
// Stores various ObserverContext objects with event metadata for callbacks.
|
||||
ObserverContextList ctx_;
|
||||
|
||||
std::string name_;
|
||||
int64_t sequence_nr_ = -1;
|
||||
@ -496,12 +482,6 @@ struct TORCH_API RecordFunction {
|
||||
size_t op_input_size{0};
|
||||
size_t op_output_size{0};
|
||||
|
||||
// Kind of scope this RecordFunction is observing
|
||||
const RecordScope scope_;
|
||||
|
||||
// The logical thread_id that this RecordFunction was created with
|
||||
uint64_t thread_id_ = 0;
|
||||
|
||||
// For backward functions - thread id of the the forward function
|
||||
uint64_t fwd_thread_id_ = 0;
|
||||
|
||||
@ -524,6 +504,8 @@ struct TORCH_API RecordFunction {
|
||||
c10::optional<State> state_;
|
||||
};
|
||||
|
||||
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
|
||||
|
||||
// Using macro to minimize inputs copies,
|
||||
// optional argument - function's seq_no
|
||||
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
|
||||
@ -596,7 +578,6 @@ TORCH_API void clearThreadLocalCallbacks();
|
||||
/**
|
||||
* addGlobalCallback adds a global callback to run with RecordFunction:
|
||||
*
|
||||
* WARNING: not thread safe, typically addGlobalCallback can be called
|
||||
* only during the program initialization
|
||||
*/
|
||||
TORCH_API CallbackHandle addGlobalCallback(
|
||||
@ -606,7 +587,6 @@ TORCH_API CallbackHandle addGlobalCallback(
|
||||
* removeCallback removes a callback given the handle returned by
|
||||
* addThreadLocalCallback or addGlobalCallback;
|
||||
*
|
||||
* WARNING: removing a global callback is not thread safe,
|
||||
* no other code can run simultaneously
|
||||
*/
|
||||
TORCH_API void removeCallback(CallbackHandle handle);
|
||||
@ -631,13 +611,12 @@ TORCH_API bool hasGlobalCallbacks();
|
||||
|
||||
/**
|
||||
* clearGlobalCallbacks removes all global callbacks
|
||||
* WARNING: not thread safe
|
||||
*/
|
||||
TORCH_API void clearGlobalCallbacks();
|
||||
|
||||
// for both thread local and global callbacks
|
||||
TORCH_API bool hasCallbacks();
|
||||
TORCH_API void clearCallbacks(); // not thread safe
|
||||
TORCH_API void clearCallbacks();
|
||||
|
||||
/**
|
||||
* enableRecordFunction enables RecordFunction thread locally
|
||||
@ -674,30 +653,15 @@ class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
|
||||
struct TORCH_API RecordFunctionTLS {
|
||||
// Thread local vector of callbacks, holds pairs (callbacks, unique_id);
|
||||
// must be sorted in increasing handles order
|
||||
ThreadLocalRecordFunctionCallbacks sorted_tls_callbacks_;
|
||||
RecordFunctionCallbacks sorted_tls_callbacks_;
|
||||
|
||||
bool tls_record_function_enabled_ = true;
|
||||
|
||||
// Stores the number of coin flips before the next successful coin flip
|
||||
int tries_left_ = 0;
|
||||
};
|
||||
|
||||
TORCH_API const RecordFunctionTLS& get_record_function_tls_();
|
||||
|
||||
TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);
|
||||
|
||||
// Checks whether RecordFunction should be called,
|
||||
// sets boolean pointed by the argument to whether pre-sampling was used
|
||||
TORCH_API bool shouldRunRecordFunction(bool*);
|
||||
|
||||
// The following functions are used to disable/enable pre-sampling of RecordFunction
|
||||
// when high-frequency/non-sampled callbacks are added/removed.
|
||||
// Note: every call to bumpRecordAllFunctions() is supposed to be matched with
|
||||
// the corresponding releaseRecordAllFunctions() call.
|
||||
// Note: disabling pre-sampling of RecordFunction incurs an extra overhead, since
|
||||
// RecordFunction will be created for each operator call.
|
||||
TORCH_API void bumpRecordAllFunctions();
|
||||
TORCH_API void releaseRecordAllFunctions();
|
||||
TORCH_API bool checkRecordAllFunctions();
|
||||
TORCH_API void set_record_function_seed_for_testing(uint32_t seed);
|
||||
|
||||
} // namespace at
|
||||
|
@ -49,14 +49,12 @@ float runPureRecordFunctionBench(int iter) {
|
||||
typedef std::chrono::microseconds us;
|
||||
std::chrono::time_point<clock> start_time = clock::now();
|
||||
for (auto idx = 0; idx < iter; ++idx) {
|
||||
bool pre_sampled = false;
|
||||
if (at::shouldRunRecordFunction(&pre_sampled)) {
|
||||
at::RecordFunction guard(at::RecordScope::USER_SCOPE, pre_sampled);
|
||||
if (C10_UNLIKELY(guard.isActive())) {
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
guard.before("Test", -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto duration = static_cast<float>(
|
||||
std::chrono::duration_cast<us>(clock::now() - start_time).count());
|
||||
return duration;
|
||||
|
307
test/cpp/profiler/record_function.cpp
Normal file
307
test/cpp/profiler/record_function.cpp
Normal file
@ -0,0 +1,307 @@
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
// Test that we can add and remove callbacks (both global and thread local.)
|
||||
TEST(RecordFunctionTest, AddRemove) {
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
auto start_callback =
|
||||
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
return nullptr;
|
||||
};
|
||||
auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
|
||||
|
||||
auto handle = at::addThreadLocalCallback(
|
||||
at::RecordFunctionCallback(start_callback, end_callback));
|
||||
|
||||
ASSERT_TRUE(at::hasCallbacks());
|
||||
ASSERT_TRUE(at::hasThreadLocalCallbacks());
|
||||
ASSERT_FALSE(at::hasGlobalCallbacks());
|
||||
|
||||
at::removeCallback(handle);
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
handle = at::addGlobalCallback(
|
||||
at::RecordFunctionCallback(start_callback, end_callback));
|
||||
|
||||
ASSERT_TRUE(at::hasCallbacks());
|
||||
ASSERT_FALSE(at::hasThreadLocalCallbacks());
|
||||
ASSERT_TRUE(at::hasGlobalCallbacks());
|
||||
|
||||
at::removeCallback(handle);
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
}
|
||||
|
||||
// Test that the callbacks that we register are actually run.
|
||||
TEST(RecordFunctionTest, ThreadLocalState) {
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
static int tls_test_start_counter;
|
||||
static int tls_test_end_counter;
|
||||
tls_test_start_counter = 0;
|
||||
tls_test_end_counter = 0;
|
||||
|
||||
auto start_callback =
|
||||
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
|
||||
++tls_test_start_counter;
|
||||
return nullptr;
|
||||
};
|
||||
auto end_callback = [](const at::RecordFunction&, at::ObserverContext*) {
|
||||
++tls_test_end_counter;
|
||||
};
|
||||
|
||||
auto handle = at::addThreadLocalCallback(
|
||||
at::RecordFunctionCallback(start_callback, end_callback));
|
||||
|
||||
{
|
||||
at::RecordFunction guard(at::RecordScope::USER_SCOPE);
|
||||
guard.before("Test");
|
||||
EXPECT_EQ(tls_test_start_counter, 1);
|
||||
EXPECT_EQ(tls_test_end_counter, 0);
|
||||
}
|
||||
EXPECT_EQ(tls_test_start_counter, 1);
|
||||
EXPECT_EQ(tls_test_end_counter, 1);
|
||||
|
||||
{
|
||||
tls_test_start_counter = 0;
|
||||
tls_test_end_counter = 0;
|
||||
at::DisableRecordFunctionGuard no_profile_guard;
|
||||
at::RecordFunction guard(at::RecordScope::USER_SCOPE);
|
||||
guard.before("Test");
|
||||
EXPECT_EQ(tls_test_start_counter, 0);
|
||||
EXPECT_EQ(tls_test_end_counter, 0);
|
||||
}
|
||||
EXPECT_EQ(tls_test_start_counter, 0);
|
||||
EXPECT_EQ(tls_test_end_counter, 0);
|
||||
|
||||
{
|
||||
tls_test_start_counter = 0;
|
||||
tls_test_end_counter = 0;
|
||||
RECORD_FUNCTION("Test", {});
|
||||
EXPECT_EQ(tls_test_start_counter, 1);
|
||||
EXPECT_EQ(tls_test_end_counter, 0);
|
||||
}
|
||||
EXPECT_EQ(tls_test_start_counter, 1);
|
||||
EXPECT_EQ(tls_test_end_counter, 1);
|
||||
|
||||
at::removeCallback(handle);
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
}
|
||||
|
||||
// Test that callbacks are run in the order that they are registered.
|
||||
TEST(RecordFunctionTest, CallOrder) {
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
static int current_index;
|
||||
current_index = 0;
|
||||
|
||||
static std::array<std::string, 8> expected_order = {
|
||||
"Start Callback 0 Outer",
|
||||
"Start Callback 1 Outer",
|
||||
"Start Callback 0 Inner",
|
||||
"Start Callback 1 Inner",
|
||||
"End Callback 0 Inner",
|
||||
"End Callback 1 Inner",
|
||||
"End Callback 0 Outer",
|
||||
"End Callback 1 Outer",
|
||||
};
|
||||
|
||||
#define REGISTER_CALLBACK(index) \
|
||||
at::addThreadLocalCallback( \
|
||||
at::RecordFunctionCallback( \
|
||||
[](const at::RecordFunction& fn) \
|
||||
-> std::unique_ptr<at::ObserverContext> { \
|
||||
EXPECT_EQ( \
|
||||
fmt::format("Start Callback {} {}", index, fn.name()), \
|
||||
expected_order[current_index++]); \
|
||||
return nullptr; \
|
||||
}, \
|
||||
[](const at::RecordFunction& fn, at::ObserverContext*) { \
|
||||
EXPECT_EQ( \
|
||||
fmt::format("End Callback {} {}", index, fn.name()), \
|
||||
expected_order[current_index++]); \
|
||||
}) \
|
||||
.scopes({at::RecordScope::FUNCTION}))
|
||||
|
||||
REGISTER_CALLBACK(0);
|
||||
REGISTER_CALLBACK(1);
|
||||
#undef REGISTER_CALLBACK
|
||||
|
||||
RECORD_FUNCTION("Outer", {});
|
||||
{ RECORD_FUNCTION("Inner", {}); }
|
||||
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
}
|
||||
|
||||
// Make sure TLS migrates when tasks are launched.
|
||||
TEST(RecordFunctionTest, ThreadMigration) {
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
static int call_count;
|
||||
call_count = 0;
|
||||
|
||||
auto handle = at::addThreadLocalCallback(
|
||||
at::RecordFunctionCallback(
|
||||
[](const at::RecordFunction&)
|
||||
-> std::unique_ptr<at::ObserverContext> { return nullptr; },
|
||||
[](const at::RecordFunction&, at::ObserverContext*) {
|
||||
++call_count;
|
||||
})
|
||||
.scopes({at::RecordScope::FUNCTION}));
|
||||
|
||||
EXPECT_EQ(call_count, 0);
|
||||
|
||||
std::condition_variable cv;
|
||||
std::mutex lock;
|
||||
at::launch([&cv]() {
|
||||
RECORD_FUNCTION("Test", {});
|
||||
cv.notify_all();
|
||||
});
|
||||
auto guard = std::unique_lock<std::mutex>(lock);
|
||||
cv.wait(guard, []{ return call_count > 0; });
|
||||
|
||||
EXPECT_EQ(call_count, 1);
|
||||
|
||||
at::removeCallback(handle);
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
}
|
||||
|
||||
// Test sampling logic and validate that callbacks fire at the correct times.
|
||||
TEST(RecordFunctionTest, Sampling) {
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
static int sample_test_counter;
|
||||
sample_test_counter = 0;
|
||||
|
||||
uint32_t seed = 12345;
|
||||
double p = 0.25;
|
||||
|
||||
at::set_record_function_seed_for_testing(seed);
|
||||
std::mt19937 generator;
|
||||
generator.seed(seed);
|
||||
auto dist = std::geometric_distribution<int>(p);
|
||||
|
||||
// Make sure we know which steps should fire.
|
||||
auto outcomes = std::array<int, 5>{7, 0, 0, 6, 2};
|
||||
for (const auto i : c10::irange(outcomes.size())) {
|
||||
ASSERT_EQ(dist(generator), outcomes[i]);
|
||||
}
|
||||
|
||||
std::vector<int> expected_counts;
|
||||
int running_count = 0;
|
||||
for (const auto i : c10::irange(outcomes.size())) {
|
||||
for (const auto j : c10::irange(outcomes[i])) {
|
||||
expected_counts.push_back(running_count);
|
||||
}
|
||||
expected_counts.push_back(++running_count);
|
||||
}
|
||||
|
||||
auto start_callback =
|
||||
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
++sample_test_counter;
|
||||
return nullptr;
|
||||
};
|
||||
auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
|
||||
|
||||
auto handle = at::addThreadLocalCallback(
|
||||
at::RecordFunctionCallback(start_callback, end_callback)
|
||||
.samplingProb(p)
|
||||
.scopes({at::RecordScope::FUNCTION}));
|
||||
|
||||
for (const auto i : c10::irange(expected_counts.size())) {
|
||||
RECORD_FUNCTION("Test", {});
|
||||
EXPECT_EQ(sample_test_counter, expected_counts[i]);
|
||||
}
|
||||
|
||||
at::removeCallback(handle);
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
}
|
||||
|
||||
// Validate sampling against a simple reference implementation for a complex set
|
||||
// of registered callbacks.
|
||||
TEST(RecordFunctionTest, MultipleCallbacks) {
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
|
||||
uint32_t seed = 54321;
|
||||
|
||||
std::mt19937 generator;
|
||||
generator.seed(seed);
|
||||
|
||||
auto sample = [&](double p) {
|
||||
return (p < 1.0 ? std::geometric_distribution<int>(p)(generator) : 0) + 1;
|
||||
};
|
||||
|
||||
std::array<double, 4> probabilities{0.1, 1.0, 1.0, 0.3};
|
||||
std::array<int, 4> next_call;
|
||||
std::array<int, 4> counts;
|
||||
static std::array<int, 4> counts_from_rec_fn;
|
||||
counts_from_rec_fn.fill(0);
|
||||
|
||||
auto start_callback_0 =
|
||||
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
++counts_from_rec_fn[0];
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {};
|
||||
|
||||
#define REGISTER_CALLBACK(register_fn, index) \
|
||||
register_fn(at::RecordFunctionCallback( \
|
||||
[](const at::RecordFunction& fn) \
|
||||
-> std::unique_ptr<at::ObserverContext> { \
|
||||
++counts_from_rec_fn[index]; \
|
||||
return nullptr; \
|
||||
}, \
|
||||
end_callback) \
|
||||
.samplingProb(probabilities[index]) \
|
||||
.scopes({at::RecordScope::FUNCTION}))
|
||||
|
||||
REGISTER_CALLBACK(at::addGlobalCallback, 0);
|
||||
REGISTER_CALLBACK(at::addGlobalCallback, 1);
|
||||
REGISTER_CALLBACK(at::addThreadLocalCallback, 2);
|
||||
|
||||
// The RecordFunction machinery will rebuild callbacks whenever a new observer
|
||||
// is registered, so we need to wait until the last callback to seed the
|
||||
// random number generator.
|
||||
at::set_record_function_seed_for_testing(seed);
|
||||
REGISTER_CALLBACK(at::addThreadLocalCallback, 3);
|
||||
#undef REGISTER_CALLBACK
|
||||
|
||||
for (const auto i : c10::irange(probabilities.size())) {
|
||||
next_call[i] = sample(probabilities[i]);
|
||||
}
|
||||
|
||||
for (const auto i : c10::irange(50)) {
|
||||
RECORD_FUNCTION("Test", {});
|
||||
for (const auto j : c10::irange(next_call.size())) {
|
||||
if (!(--next_call[j])) {
|
||||
++counts[j];
|
||||
next_call[j] = sample(probabilities[j]);
|
||||
}
|
||||
EXPECT_EQ(counts[j], counts_from_rec_fn[j]);
|
||||
}
|
||||
}
|
||||
|
||||
at::clearCallbacks();
|
||||
ASSERT_FALSE(at::hasCallbacks());
|
||||
}
|
@ -151,11 +151,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
// probably operate with names.
|
||||
at::NoNamesGuard no_names_guard;
|
||||
|
||||
bool pre_sampled = false;
|
||||
if (at::shouldRunRecordFunction(&pre_sampled)) {
|
||||
// Using RecordFunction to trigger observers in the backward pass
|
||||
at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION, pre_sampled);
|
||||
if (guard.isActive()) {
|
||||
auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
// Using sequence number and thread id to correlate with
|
||||
// the forward pass function
|
||||
guard.setForwardThreadId(thread_id_);
|
||||
@ -167,8 +166,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
} else {
|
||||
guard.before(name(), sequence_nr());
|
||||
}
|
||||
}
|
||||
// keeping stack guard object alive during the call
|
||||
return apply(std::move(inputs));
|
||||
} else {
|
||||
return apply(std::move(inputs));
|
||||
|
@ -7,14 +7,12 @@ namespace torch {
|
||||
namespace autograd {
|
||||
namespace profiler {
|
||||
|
||||
struct PythonRecordFunction: public torch::CustomClassHolder {
|
||||
struct PythonRecordFunction : public torch::CustomClassHolder {
|
||||
at::RecordFunction record;
|
||||
|
||||
PythonRecordFunction(
|
||||
at::RecordScope scope = at::RecordScope::FUNCTION,
|
||||
bool pre_sampled = false)
|
||||
: record(scope, pre_sampled)
|
||||
{}
|
||||
explicit PythonRecordFunction(
|
||||
at::RecordScope scope = at::RecordScope::FUNCTION)
|
||||
: record(scope) {}
|
||||
};
|
||||
|
||||
// Creates a new profiling scope using RecordFunction and invokes its starting
|
||||
|
@ -841,12 +841,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
}
|
||||
|
||||
static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
|
||||
bool pre_sampled = false;
|
||||
if (!frame.record_function && at::hasCallbacks() &&
|
||||
at::shouldRunRecordFunction(&pre_sampled)) {
|
||||
auto rec_fn = std::make_unique<at::RecordFunction>(
|
||||
at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled);
|
||||
if (rec_fn->isActive()) {
|
||||
if (!frame.record_function) {
|
||||
auto step_callbacks =
|
||||
at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION);
|
||||
if (!step_callbacks.empty()) {
|
||||
auto rec_fn =
|
||||
std::make_unique<at::RecordFunction>(std::move(step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
|
||||
if (rec_fn->needsInputs()) {
|
||||
rec_fn->before(
|
||||
frame.function->function_name_,
|
||||
|
@ -1196,17 +1196,14 @@ template <typename IValueList>
|
||||
c10::IValue BlockRunner::run_impl_record_functions(
|
||||
IValueList&& args,
|
||||
const KeywordArgs& kwargs) {
|
||||
bool pre_sampled = false;
|
||||
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
|
||||
at::RecordFunction guard(
|
||||
at::RecordScope::STATIC_RUNTIME_MODEL, pre_sampled);
|
||||
if (guard.isActive()) {
|
||||
if (guard.needsInputs()) {
|
||||
guard.before("forward", &args);
|
||||
} else {
|
||||
guard.before("forward");
|
||||
}
|
||||
}
|
||||
auto step_callbacks =
|
||||
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
guard.needsInputs() ? guard.before("forward", &args)
|
||||
: guard.before("forward");
|
||||
|
||||
return run_impl(std::forward<IValueList>(args), kwargs);
|
||||
}
|
||||
return run_impl(std::forward<IValueList>(args), kwargs);
|
||||
@ -1841,16 +1838,14 @@ std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
|
||||
|
||||
void ProcessedNode::run() {
|
||||
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
|
||||
bool pre_sampled = false;
|
||||
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
|
||||
at::RecordFunction guard(at::RecordScope::STATIC_RUNTIME_OP, pre_sampled);
|
||||
if (guard.isActive()) {
|
||||
if (guard.needsInputs()) {
|
||||
guard.before(get_op_name(), inputs_ivalue_vec());
|
||||
} else {
|
||||
guard.before(get_op_name());
|
||||
}
|
||||
}
|
||||
auto step_callbacks =
|
||||
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL);
|
||||
if (!step_callbacks.empty()) {
|
||||
at::RecordFunction guard(std::move(step_callbacks));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
|
||||
guard.needsInputs() ? guard.before(get_op_name(), inputs_ivalue_vec())
|
||||
: guard.before(get_op_name());
|
||||
|
||||
fn_->run(this);
|
||||
} else {
|
||||
fn_->run(this);
|
||||
|
Reference in New Issue
Block a user