mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Extra sampling of record function events (#48289)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48289

Adding extra sampling step when dispatching RecordFunction.

(Note: this ignores all push blocking failures!)

Reviewed By: swolchok

Differential Revision: D25111515

Pulled By: ilia-cher

fbshipit-source-id: 0d572a3636fe649a47ec47901826bbfc08368937
Committed by: Facebook GitHub Bot
Parent: a20d4511e4
Commit: 09b974c2d5
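The call-site pattern this change introduces (used below in the Dispatcher, autograd Node and TorchScript interpreter hunks) can be summarized with a small illustrative sketch; this snippet is not part of the diff and simply assumes the at::shouldRunRecordFunction / pre_sampled RecordFunction APIs added here:

    bool pre_sampled = false;
    if (at::shouldRunRecordFunction(&pre_sampled)) {
      // RecordFunction is only constructed on the (pre-)sampled slow path;
      // pre_sampled tells it to rescale the callbacks' sampling probabilities.
      at::RecordFunction guard(at::RecordScope::USER_SCOPE, pre_sampled);
      if (guard.isActive()) {
        guard.before("my_op", /*sequence_nr=*/-1);
      }
      // ... do the actual work while the guard is alive ...
    }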
@@ -19,6 +19,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
     grad_mode_enabled_ = GradMode::is_enabled();
   }
 #endif
+  bumped_record_all_functions_ = at::checkRecordAllFunctions();
 }
 
 /* static */
@@ -38,6 +38,9 @@ class TORCH_API ThreadLocalState {
   bool grad_mode_enabled_;
 #endif
 
+  // Whether pre-sampling RecordFunction optimization was enabled
+  bool bumped_record_all_functions_ = false;
+
   friend class ThreadLocalStateGuard;
 };
 
@@ -45,7 +48,21 @@ class TORCH_API ThreadLocalState {
 class TORCH_API ThreadLocalStateGuard {
  public:
   explicit ThreadLocalStateGuard(const ThreadLocalState& state)
-      : prev_state_(ThreadLocalState()) {
+      : prev_state_(ThreadLocalState()),
+        bumped_record_all_functions_(state.bumped_record_all_functions_) {
+    // Special handling of RecordFunction pre-sampling optimization:
+    // pre-sampling is enabled (bumped) when there are non-sampled
+    // (or high-frequency) global or TLS callbacks.
+    //
+    // ThreadLocalStateGuard simply resets RecordFunction's TLS and
+    // hence its thread local callbacks.
+    //
+    // Checking if the pre-sampling was enabled and preserving it in the
+    // async task by calling bumpRecordAllFunctions() and the corresponding
+    // releaseRecordAllFunctions().
+    if (bumped_record_all_functions_) {
+      at::bumpRecordAllFunctions();
+    }
     // set the given state across the thread boundary
     ThreadLocalState::setThreadLocalState(state);
   }
@@ -53,10 +70,15 @@ class TORCH_API ThreadLocalStateGuard {
   ~ThreadLocalStateGuard() {
     // restore previously set variables
     ThreadLocalState::setThreadLocalState(prev_state_);
+    if (bumped_record_all_functions_) {
+      at::releaseRecordAllFunctions();
+    }
   }
 
  private:
   const ThreadLocalState prev_state_;
+  // Whether pre-sampling RecordFunction optimization was enabled
+  bool bumped_record_all_functions_ = false;
 };
 
 template <typename T>
@@ -371,28 +371,39 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(A
   const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);
 
 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
-  // Check if we need to run callbacks registered with RecordFunction
-  // If true and callbacks need inputs, we box the arguments and pass
-  // them into the callbacks and also into the kernel call
-
-  // Note: for perf reasons we wouldn't want to pass arguments into
-  // the function call or prematurely box them
-  at::RecordFunction guard(at::RecordScope::FUNCTION);
-  if (C10_UNLIKELY(guard.isActive())) {
-    if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
-      int64_t seq_num = -1;
-      // Setting sequence number in the Autograd case to associate
-      // the forward range with the corresponding Autograd's node
-      if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
-        seq_num = at::sequence_number::peek();
-      }
-      if (guard.needsInputs()) {
-        torch::jit::Stack stack = impl::boxArgs(args...);
-        guard.before(op, stack, seq_num);
-      } else {
-        guard.before(op, seq_num);
-      }
-    }
-  }
+  // By default, when there are no high-frequency or non-sampled callbacks,
+  // RecordFunction is pre-sampled as a perf optimization;
+  // shouldRunRecordFunction checks whether RecordFunction should be executed,
+  // and sets the pre_sampled boolean argument to whether pre-sampling was used -
+  // this boolean is passed into RecordFunction to adjust the sampling rates of
+  // the callbacks
+  bool pre_sampled = false;
+  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
+    // Check if we need to run callbacks registered with RecordFunction
+    // If true and callbacks need inputs, we box the arguments and pass
+    // them into the callbacks and also into the kernel call
+
+    // Note: for perf reasons we wouldn't want to pass arguments into
+    // the function call or prematurely box them
+    at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
+    if (C10_UNLIKELY(guard.isActive())) {
+      if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
+        int64_t seq_num = -1;
+        // Setting sequence number in the Autograd case to associate
+        // the forward range with the corresponding Autograd's node
+        if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
+          seq_num = at::sequence_number::peek();
+        }
+        if (guard.needsInputs()) {
+          torch::jit::Stack stack = impl::boxArgs(args...);
+          guard.before(op, stack, seq_num);
+        } else {
+          guard.before(op, seq_num);
+        }
+      }
+    }
+    // keeping the guard alive while executing the kernel
+    return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
+  }
 #endif // PYTORCH_DISABLE_PER_OP_PROFILING
   return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
@@ -429,20 +440,26 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
   const auto& kernel = entry.lookup(dispatchKey);
 
 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
-  // using already existing stack to record function execution in observers
-  at::RecordFunction guard(at::RecordScope::FUNCTION);
-  if (C10_UNLIKELY(guard.isActive())) {
-    if (shouldRecord(dispatchKey) && entry.isObserved()) {
-      int64_t seq_num = -1;
-      if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
-        seq_num = at::sequence_number::peek();
-      }
-      if (guard.needsInputs()) {
-        guard.before(op, *stack, seq_num);
-      } else {
-        guard.before(op, seq_num);
-      }
-    }
-  }
+  bool pre_sampled = false;
+  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
+    // using already existing stack to record function execution in observers
+    at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
+    if (C10_UNLIKELY(guard.isActive())) {
+      if (shouldRecord(dispatchKey) && entry.isObserved()) {
+        int64_t seq_num = -1;
+        if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
+          seq_num = at::sequence_number::peek();
+        }
+        if (guard.needsInputs()) {
+          guard.before(op, *stack, seq_num);
+        } else {
+          guard.before(op, seq_num);
+        }
+      }
+    }
+    // keeping the guard alive while executing the kernel
+    kernel.callBoxed(op, stack);
+    return;
+  }
 #endif // PYTORCH_DISABLE_PER_OP_PROFILING
   kernel.callBoxed(op, stack);
@@ -30,8 +30,6 @@ std::atomic<int64_t> defaultNodeId(-1);
 std::atomic<uint64_t> next_thread_id_ {0};
 thread_local uint64_t current_thread_id_ = 0;
 
-thread_local bool tls_record_function_enabled_ = true;
-
 // Low probability constant
 static const double kLowProb = 0.001;
 struct CoinflipTLS {
@@ -68,6 +66,10 @@ void set_record_function_tls_(const RecordFunctionTLS& tls) {
 class CallbackManager {
  public:
   CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) {
+    if (cb.samplingProb() > kLowProb) {
+      // pre-sampling of RecordFunction with prob. kLowProb cannot be used
+      at::bumpRecordAllFunctions();
+    }
     // note: monotonically increasing callbacks_unique_id keeps
     // sorted_tls_callbacks_ sorted
     auto handle = next_unique_callback_handle();
@@ -76,6 +78,10 @@ class CallbackManager {
   }
 
   CallbackHandle addGlobalCallback(RecordFunctionCallback cb) {
+    if (cb.samplingProb() > kLowProb) {
+      // pre-sampling of RecordFunction with prob. kLowProb cannot be used
+      at::bumpRecordAllFunctions();
+    }
     auto handle = next_unique_callback_handle();
     sorted_global_callbacks_.emplace_back(std::move(cb), handle);
     return handle;
@@ -92,6 +98,10 @@ class CallbackManager {
         return el.second == handle;
       });
     if (it != cbs.end()) {
+      if (it->first.samplingProb() > kLowProb) {
+        // try to restore pre-sampling of RecordFunction
+        at::releaseRecordAllFunctions();
+      }
       // keeps it sorted
       cbs.erase(it);
       return true;
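In other words, registering a callback whose sampling probability is above kLowProb keeps pre-sampling disabled for as long as that callback is installed. A hedged usage sketch (not part of the diff; handle-based removal via at::removeCallback is assumed to be the existing API):

    // A frequently-firing observer: samplingProb() > kLowProb (0.001), so
    // addGlobalCallback() calls at::bumpRecordAllFunctions() internally.
    auto handle = at::addGlobalCallback(
        at::RecordFunctionCallback(
            [](const at::RecordFunction& fn) { /* start */ },
            [](const at::RecordFunction& fn) { /* end */ })
          .samplingProb(0.5));

    // ... profiling session ...

    // Removing it triggers at::releaseRecordAllFunctions(), so pre-sampling
    // can resume once no high-frequency callbacks remain.
    at::removeCallback(handle);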
@@ -127,7 +137,13 @@ class CallbackManager {
   // callbackShouldRun is even hotter because it's called multiple
   // times per init(). Profiling shows that the function prologue is
   // taking up a significant fraction of the time.
-  static bool C10_ALWAYS_INLINE callbackShouldRun(const RecordFunctionCallback& cb, RecordScope scope) {
+  static bool C10_ALWAYS_INLINE callbackShouldRun(
+      const RecordFunctionCallback& cb, RecordScope scope, bool pre_sampled) {
+    TORCH_INTERNAL_ASSERT(
+        !pre_sampled || (cb.sampling_prob_ <= kLowProb),
+        "Incorrect usage of a pre-sampled RecordFunction with a high-frequency "
+        " or non-sampled callback");
+
     // first check whether this callback is interested in
     // the given scope type
     if (!cb.checkScope(scope)) {
@@ -138,36 +154,45 @@ class CallbackManager {
       return cb.should_run_(cb);
     }
 
-    if (cb.sampling_prob_ == 1.0) {
-      return true;
-    }
-    // model the low probability events as events happening
-    // with probability kLowProb followed by another sampling with
-    // probability (sampling_prob_ / kLowProb), then replace the coin
-    // flip for kLowProb with a thread local number of tries tries_left_
-    // sampled from the geometric distribution.
-    if (cb.sampling_prob_ < kLowProb) {
-      if (coinflip_tls_.tries_left_ == 0) {
-        coinflip_tls_.tries_left_ = sample_geometric();
-        return (sample_zero_one() < cb.sampling_prob_ / kLowProb);
-      } else {
-        --coinflip_tls_.tries_left_;
-        return false;
-      }
-    } else {
-      return (sample_zero_one() < cb.sampling_prob_);
-    }
-  }
+    // otherwise potentially do the sampling
+    double sampling_prob = cb.sampling_prob_;
+    if (pre_sampled) {
+      // adjust the sampling rate to account for kLowProb pre-sampling of
+      // the RecordFunction
+      sampling_prob /= kLowProb;
+    }
+
+    if (sampling_prob < 1.0) {
+      // model the low probability events as events happening
+      // with probability kLowProb followed by another sampling with
+      // probability (sampling_prob / kLowProb), then replace the coin
+      // flip for kLowProb with a thread local number of tries tries_left_
+      // sampled from the geometric distribution.
+      if (sampling_prob < kLowProb) {
+        if (coinflip_tls_.tries_left_ == 0) {
+          coinflip_tls_.tries_left_ = sample_geometric();
+          return (sample_zero_one() < sampling_prob / kLowProb);
+        } else {
+          --coinflip_tls_.tries_left_;
+          return false;
+        }
+      } else {
+        return (sample_zero_one() < sampling_prob);
+      }
+    }
+
+    return true;
+  }
 
   // init is called by RecordFunction in constructor to
   // determine which thread local and global callbacks are going
   // to be executed and whether any of them need inputs
-  inline void init(RecordFunction& rec_fn, RecordScope scope) {
+  inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) {
     bool found_needs_inputs = false;
     bool found_needs_ids = false;
 
     for (const auto& cb: rf_tls_.sorted_tls_callbacks_) {
-      if (callbackShouldRun(cb.first, scope)) {
+      if (callbackShouldRun(cb.first, scope, pre_sampled)) {
         if (cb.first.needsInputs()) {
           found_needs_inputs = true;
         }
@@ -182,7 +207,7 @@ class CallbackManager {
     }
 
     for (const auto& cb: sorted_global_callbacks_) {
-      if (callbackShouldRun(cb.first, scope)) {
+      if (callbackShouldRun(cb.first, scope, pre_sampled)) {
         if (cb.first.needsInputs()) {
           found_needs_inputs = true;
         }
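The rescaling above keeps the end-to-end callback rate unchanged: with pre-sampling, a RecordFunction is only created with probability kLowProb, and callbackShouldRun() then fires with probability sampling_prob / kLowProb, so overall kLowProb * (sampling_prob / kLowProb) = sampling_prob. A small self-contained check of that arithmetic (illustrative only; the real code replaces the per-call kLowProb coin flip with the geometric tries_left_ counter, which is statistically equivalent):

    #include <iostream>
    #include <random>

    int main() {
      const double kLowProb = 0.001;        // pre-sampling probability
      const double sampling_prob = 0.0001;  // callback's own rate (<= kLowProb)
      std::mt19937 gen(42);
      std::bernoulli_distribution pre(kLowProb);                   // stage 1: pre-sampling
      std::bernoulli_distribution cond(sampling_prob / kLowProb);  // stage 2: adjusted rate

      const long iters = 10000000;
      long hits = 0;
      for (long i = 0; i < iters; ++i) {
        if (pre(gen) && cond(gen)) {
          ++hits;
        }
      }
      // Expect roughly iters * sampling_prob = 1000 hits.
      std::cout << "observed " << hits << ", expected ~"
                << static_cast<long>(iters * sampling_prob) << std::endl;
    }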
@@ -308,7 +333,6 @@ namespace {
   }
 } // namespace
 
-
 RecordFunctionCallbacks _getTLSCallbacks() {
   return rf_tls_.sorted_tls_callbacks_;
 }
@@ -374,12 +398,12 @@ void enableRecordFunction(bool enable) {
   rf_tls_.tls_record_function_enabled_ = enable;
 }
 
-RecordFunction::RecordFunction(RecordScope scope) {
+RecordFunction::RecordFunction(RecordScope scope, bool pre_sampled) {
   auto* rf_tls_ptr = &rf_tls_;
   if (rf_tls_ptr->tls_record_function_enabled_) {
     auto& m = manager();
     if (!m.sorted_global_callbacks_.empty() || !rf_tls_ptr->sorted_tls_callbacks_.empty()) {
-      m.init(*this, scope);
+      m.init(*this, scope, pre_sampled);
     }
   }
 }
@@ -451,4 +475,46 @@ void RecordFunction::end() {
   }
 }
 
+// RecordFunction pre-sampling
+namespace {
+// Whether to try to create RecordFunction on each call (>0) or
+// use pre-sampling (=0)
+std::atomic<int> global_record_all_functions_ {0};
+}
+
+void bumpRecordAllFunctions() {
+  global_record_all_functions_.fetch_add(1, std::memory_order_relaxed);
+}
+
+void releaseRecordAllFunctions() {
+  TORCH_CHECK(global_record_all_functions_.fetch_sub(1, std::memory_order_relaxed) >= 0);
+}
+
+bool checkRecordAllFunctions() {
+  return (global_record_all_functions_.load(std::memory_order_relaxed) > 0);
+}
+
+bool shouldRunRecordFunction(bool* pre_sampled) {
+  auto* rf_tls_ptr = &rf_tls_;
+  if (!rf_tls_ptr->tls_record_function_enabled_) {
+    *pre_sampled = false;
+    return false;
+  }
+
+  if (global_record_all_functions_.load(std::memory_order_relaxed) > 0) {
+    *pre_sampled = false;
+    return true;
+  }
+
+  *pre_sampled = true;
+  auto* coinflip_tls_ptr = &coinflip_tls_;
+  if (coinflip_tls_ptr->tries_left_ == 0) {
+    coinflip_tls_ptr->tries_left_ = sample_geometric();
+    return true;
+  } else {
+    --coinflip_tls_ptr->tries_left_;
+    return false;
+  }
+}
+
 } // namespace at
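As the header change further below notes, every bumpRecordAllFunctions() has to be paired with a releaseRecordAllFunctions(). A minimal RAII sketch of that pairing, mirroring what the ThreadLocalStateGuard hunk at the top of this diff does when carrying state into an async task (not part of the diff):

    // Temporarily disable pre-sampling so shouldRunRecordFunction() returns
    // true (with pre_sampled = false) for every call in this scope.
    struct DisablePreSamplingGuard {
      DisablePreSamplingGuard() { at::bumpRecordAllFunctions(); }
      ~DisablePreSamplingGuard() { at::releaseRecordAllFunctions(); }
    };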
@@ -90,8 +90,11 @@ typedef uint64_t RecordFunctionHandle;
 struct TORCH_API RecordFunction {
   // Default constructor is used with before function called afterwards:
   // scope - record scope that this function tracks
+  // pre_sampled - whether this RecordFunction was already pre-sampled with
+  //   kLowProb probability
   RecordFunction(
-      RecordScope scope = RecordScope::FUNCTION);
+      RecordScope scope = RecordScope::FUNCTION,
+      bool pre_sampled = false);
 
   template <typename F>
   void before(
@@ -238,6 +241,9 @@ struct TORCH_API RecordFunction {
   // flag is used to check whether the start callbacks were called
   bool called_start_callbacks_ = false;
 
+  // Whether the RecordFunction is pre-sampled
+  bool pre_sampled_ = false;
+
   // Used internally to keep track of thread local and global callbacks
   // that were picked to run; must be sorted;
   CallbackHandles sorted_active_tls_handles_;
@@ -330,7 +336,7 @@ class TORCH_API RecordFunctionCallback {
   }
 
   RecordFunctionCallback& samplingProb(double sampling_prob) {
-    TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob_ <= 1.0,
+    TORCH_CHECK(sampling_prob >= 0.0 && sampling_prob <= 1.0,
         "Invalid sampling probability");
     sampling_prob_ = sampling_prob;
     return *this;
@@ -544,10 +550,27 @@ struct TORCH_API RecordFunctionTLS {
   RecordFunctionCallbacks sorted_tls_callbacks_;
 
   bool tls_record_function_enabled_ = true;
+
+  // Stores the number of coin flips before the next successful coin flip
+  int tries_left_ = 0;
 };
 
 TORCH_API const RecordFunctionTLS& get_record_function_tls_();
 
 TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);
 
+// Checks whether RecordFunction should be called,
+// sets boolean pointed by the argument to whether pre-sampling was used
+TORCH_API bool shouldRunRecordFunction(bool*);
+
+// The following functions are used to disable/enable pre-sampling of RecordFunction
+// when high-frequency/non-sampled callbacks are added/removed.
+// Note: every call to bumpRecordAllFunctions() is supposed to be matched with
+// the corresponding releaseRecordAllFunctions() call.
+// Note: disabling pre-sampling of RecordFunction incurs an extra overhead, since
+// RecordFunction will be created for each operator call.
+TORCH_API void bumpRecordAllFunctions();
+TORCH_API void releaseRecordAllFunctions();
+TORCH_API bool checkRecordAllFunctions();
+
 } // namespace at
@@ -7,61 +7,55 @@
 #include <iostream>
 #include <ctime>
 
-C10_DEFINE_int(iter, 100, "Number of iterations");
-C10_DEFINE_int(warmup_iter, 10, "Number of warmup iterations");
+C10_DEFINE_int(iter, 10000, "Number of iterations");
 C10_DEFINE_int(sampled_iter, 10e6,
     "Number of iterations for the sampled observer benchmark");
 
 namespace {
-const int kInnerIter = 100;
-const int kNumSampledCb = 2;
 const int kTensorSize = 16;
 const int kSmallTensorSize = 1;
-const float kSampingProb = 0.1;
-
 const float kLowSamplingProb = 0.0001;
 }
 
-void setupBenchmarkCallbacks() {
-  at::enableRecordFunction();
-  at::clearCallbacks();
-  // non-sampled callback
-  at::addGlobalCallback(at::RecordFunctionCallback(
-      [&](const at::RecordFunction& fn) {},
-      [](const at::RecordFunction&) {})
-    .needsInputs(true));
-
-  // sampled
-  for (auto idx = 0; idx < kNumSampledCb; ++idx) {
-    at::addGlobalCallback(at::RecordFunctionCallback(
-        [](const at::RecordFunction& fn) {},
-        [](const at::RecordFunction&) {})
-      .needsInputs(true)
-      .samplingProb(kSampingProb)
-    );
-  }
+void addTestCallback(
+    double sampling_prob = 1.0,
+    std::function<void(const at::RecordFunction&)> fn =
+        [](const at::RecordFunction&) {}) {
+  auto cb = at::RecordFunctionCallback(
+      std::move(fn),
+      [](const at::RecordFunction&) {})
+    .needsInputs(false);
+  if (sampling_prob < 1.0) {
+    cb.samplingProb(sampling_prob);
+  }
+  at::addGlobalCallback(cb);
 }
 
-float runTensorBench(int tensor_size, int outer_iter) {
+float runTensorGEMMBench(int tensor_size, int iter) {
   typedef std::chrono::high_resolution_clock clock;
   typedef std::chrono::microseconds us;
   std::chrono::time_point<clock> start_time = clock::now();
-  for (auto idx = 0; idx < kInnerIter * outer_iter; ++idx) {
-    torch::mm(
-        torch::randn({tensor_size, tensor_size}),
-        torch::randn({tensor_size, tensor_size}));
+  auto inp = torch::randn({tensor_size, tensor_size});
+  for (auto idx = 0; idx < iter; ++idx) {
+    torch::mm(inp, inp);
   }
   auto duration = static_cast<float>(
       std::chrono::duration_cast<us>(clock::now() - start_time).count());
   return duration;
 }
 
-float runPureRecordFunctionBench(int outer_iter) {
+float runPureRecordFunctionBench(int iter) {
   typedef std::chrono::high_resolution_clock clock;
   typedef std::chrono::microseconds us;
   std::chrono::time_point<clock> start_time = clock::now();
-  for (auto n = 0; n < outer_iter; ++n) {
-    RECORD_USER_SCOPE("test");
+  for (auto idx = 0; idx < iter; ++idx) {
+    bool pre_sampled = false;
+    if (at::shouldRunRecordFunction(&pre_sampled)) {
+      at::RecordFunction guard(at::RecordScope::USER_SCOPE, pre_sampled);
+      if (C10_UNLIKELY(guard.isActive())) {
+        guard.before("Test", -1);
+      }
+    }
   }
   auto duration = static_cast<float>(
       std::chrono::duration_cast<us>(clock::now() - start_time).count());
@@ -71,18 +65,19 @@ float runPureRecordFunctionBench(int outer_iter) {
 void runBenchmark() {
   float duration = 0;
   for (auto tensor_size : std::set<int>({kSmallTensorSize, kTensorSize})) {
-    duration = runTensorBench(tensor_size, FLAGS_iter);
-    std::cout << "Running tensor benchmark, time per iteration ("
+    duration = runTensorGEMMBench(tensor_size, FLAGS_iter);
+    std::cout << "Tensor GEMM benchmark ("
               << tensor_size
               << "x"
               << tensor_size
-              << "): " << (duration/FLAGS_iter)
+              << ", " << FLAGS_iter << "): " << duration
               << " us." << std::endl;
   }
-  duration = runPureRecordFunctionBench(FLAGS_iter * 100);
-  std::cout << "Running pure RecordFunction benchmark, time per iteration: "
-            << (duration/FLAGS_iter)
-            << " us." << std::endl;
+  duration = runPureRecordFunctionBench(FLAGS_iter);
+  std::cout << "Pure RecordFunction benchmark ("
+            << FLAGS_iter << "): "
+            << duration
+            << " us." << std::endl;
 }
 
 int main(int argc, char** argv) {
@@ -91,32 +86,38 @@ int main(int argc, char** argv) {
     return -1;
   }
 
-  auto duration = runTensorBench(kSmallTensorSize, FLAGS_warmup_iter);
-  std::cout << "Warmup time: " << duration << " us." << std::endl;
+  at::enableRecordFunction();
+  at::clearCallbacks();
 
-  setupBenchmarkCallbacks();
-  std::cout << "Running with empty observers" << std::endl;
+  std::cout << "Warm up" << std::endl;
   runBenchmark();
 
-  at::clearCallbacks();
   std::cout << "Running without observers" << std::endl;
   runBenchmark();
 
-  std::cout << "Running sampled observer benchmark" << std::endl;
+  addTestCallback();
+  std::cout << "Running with empty non-sampled observer" << std::endl;
+  runBenchmark();
+  at::clearCallbacks();
+
+  addTestCallback(kLowSamplingProb);
+  std::cout << "Running with empty sampled observer" << std::endl;
+  runBenchmark();
+  at::clearCallbacks();
+
+  std::cout << "Checking number of sampled observer invocations" << std::endl;
   int cb_count = 0;
-  at::addGlobalCallback(at::RecordFunctionCallback(
+  addTestCallback(
+      kLowSamplingProb,
       [&](const at::RecordFunction& fn) {
         ++cb_count;
-      },
-      [](const at::RecordFunction&) {})
-    .needsInputs(true)
-    .samplingProb(kLowSamplingProb)
+      }
   );
 
-  runPureRecordFunctionBench(FLAGS_sampled_iter);
+  auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter);
 
   std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter
-            << " iterations " << duration
+            << " iterations: " << duration
            << " us, number of callback invocations: " << cb_count
            << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb)
            << " invocations" << std::endl;
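The expected invocation count printed at the end of main() is simply FLAGS_sampled_iter * kLowSamplingProb; a quick check with the defaults assumed from this file:

    // Assumed defaults from the benchmark above.
    constexpr long kSampledIterDefault = 10000000;   // C10_DEFINE_int(sampled_iter, 10e6, ...)
    constexpr double kLowSamplingProbDefault = 0.0001;
    // ~1000 callback invocations expected over the run.
    constexpr double kExpectedInvocations = kSampledIterDefault * kLowSamplingProbDefault;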
@@ -133,26 +133,33 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   /// Evaluates the function on the given inputs and returns the result of the
   /// function call.
   variable_list operator()(variable_list&& inputs) {
-    // Using RecordFunction to trigger observers in the backward pass
-    at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION);
-    if (guard.isActive()) {
-      // Using sequence number and thread id to correlate with
-      // the forward pass function
-      guard.setForwardThreadId(thread_id_);
-      if (guard.needsInputs()) {
-        guard.before(
-            name(),
-            std::vector<c10::IValue>(inputs.begin(), inputs.end()),
-            sequence_nr());
-      } else {
-        guard.before(name(), sequence_nr());
-      }
-    }
     // In the first iteration of named tensors, autograd ignores names and
     // operates on unnamed tensors. In the long term, autograd should
     // probably operate with names.
     at::NoNamesGuard no_names_guard;
-    return apply(std::move(inputs));
+
+    bool pre_sampled = false;
+    if (at::shouldRunRecordFunction(&pre_sampled)) {
+      // Using RecordFunction to trigger observers in the backward pass
+      at::RecordFunction guard(at::RecordScope::BACKWARD_FUNCTION, pre_sampled);
+      if (guard.isActive()) {
+        // Using sequence number and thread id to correlate with
+        // the forward pass function
+        guard.setForwardThreadId(thread_id_);
+        if (guard.needsInputs()) {
+          guard.before(
+              name(),
+              std::vector<c10::IValue>(inputs.begin(), inputs.end()),
+              sequence_nr());
+        } else {
+          guard.before(name(), sequence_nr());
+        }
+      }
+      // keeping stack guard object alive during the call
+      return apply(std::move(inputs));
+    } else {
+      return apply(std::move(inputs));
+    }
   }
 
   // Graph Connectivity API
@@ -1607,10 +1607,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   }
 
   static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
+    bool pre_sampled = false;
     if (!frame.record_function && at::hasCallbacks() &&
-        at::isRecordFunctionEnabled()) {
+        at::shouldRunRecordFunction(&pre_sampled)) {
       auto rec_fn = std::make_unique<at::RecordFunction>(
-          at::RecordScope::TORCHSCRIPT_FUNCTION);
+          at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled);
       if (rec_fn->isActive()) {
         if (rec_fn->needsInputs()) {
           rec_fn->before(