mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Only profiling when it's enabled. (#121404)
Summary: The profiling, even when disabled, takes up about 1.5% cpu for a model I'm looking into. This patch just splits into with/without profile runs. The potential downside is that now the script can't enable profiling in itself. It doesn't seem to be used anywhere. If that's a crusial usecase, we can do something about it but ideally we wouldn't. Test Plan: Link with profiles: https://fburl.com/scuba/strobelight_services/ihxsl7pj ``` buck2 run fbcode//caffe2/test/cpp/jit:jit ``` Reviewed By: zhxchen17 Differential Revision: D54066589 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121404 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
df06b94778
commit
2c9c57c061
0
test/dynamo_skips/TestScriptProfile.test_script
Normal file
0
test/dynamo_skips/TestScriptProfile.test_script
Normal file
0
test/dynamo_skips/TestScriptProfile.test_section
Normal file
0
test/dynamo_skips/TestScriptProfile.test_section
Normal file
@ -51,15 +51,16 @@ class TestScriptProfile(JitTestCase):
|
||||
def test_script(self):
|
||||
seq = Sequence()
|
||||
|
||||
p = torch.jit._ScriptProfile()
|
||||
p.enable()
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
p = torch.jit._ScriptProfile()
|
||||
p.enable()
|
||||
_ = seq(torch.rand((10, 100)))
|
||||
p.disable()
|
||||
return p
|
||||
fn()
|
||||
p.disable()
|
||||
|
||||
self.assertNotEqual(fn().dump_string(), "")
|
||||
self.assertNotEqual(p.dump_string(), "")
|
||||
|
||||
def test_multi(self):
|
||||
seq = torch.jit.script(Sequence())
|
||||
@ -82,25 +83,24 @@ class TestScriptProfile(JitTestCase):
|
||||
seq = Sequence()
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
p = torch.jit._ScriptProfile()
|
||||
p.enable()
|
||||
_ = seq(torch.rand((10, 100)))
|
||||
p.disable()
|
||||
stats0 = p.dump_string()
|
||||
def fn(max : int):
|
||||
_ = seq(torch.rand((10, max)))
|
||||
|
||||
_ = seq(torch.rand((10, 10)))
|
||||
stats1 = p.dump_string()
|
||||
p = torch.jit._ScriptProfile()
|
||||
p.enable()
|
||||
fn(100)
|
||||
p.disable()
|
||||
s0 = p.dump_string()
|
||||
|
||||
p.enable()
|
||||
_ = seq(torch.rand((10, 10)))
|
||||
p.disable()
|
||||
stats2 = p.dump_string()
|
||||
fn(10)
|
||||
p.disable()
|
||||
s1 = p.dump_string()
|
||||
|
||||
p.enable()
|
||||
return stats0, stats1, stats2
|
||||
p.enable()
|
||||
fn(10)
|
||||
p.disable()
|
||||
s2 = p.dump_string()
|
||||
|
||||
s0, s1, s2 = fn()
|
||||
self.assertEqual(s0, s1)
|
||||
self.assertNotEqual(s1, s2)
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/record_function.h>
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/autograd/edge.h>
|
||||
@ -239,6 +240,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
std::size_t initialSize_{stack_.size()};
|
||||
};
|
||||
|
||||
struct C10_UNUSED DoNothing {};
|
||||
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
#define JIT_USE_COMPUTED_GOTO
|
||||
#endif
|
||||
@ -265,7 +268,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
inst = instFetch(1); \
|
||||
INST_DISPATCH
|
||||
|
||||
bool runImpl(Stack& stack) {
|
||||
template <bool EnableProfiling>
|
||||
bool runTemplate(Stack& stack) {
|
||||
// if we have never run before, then we might have to return the
|
||||
// stack when we suspend, record where it starts so we return the right
|
||||
// stack
|
||||
@ -301,8 +305,12 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
};
|
||||
|
||||
auto instGuard = [&] {
|
||||
return profiling::InstructionSpan{
|
||||
*frame.function->instructions_source()[frame.pc]};
|
||||
if constexpr (!EnableProfiling) {
|
||||
return DoNothing{};
|
||||
} else {
|
||||
return profiling::InstructionSpan{
|
||||
*frame.function->instructions_source()[frame.pc]};
|
||||
}
|
||||
};
|
||||
|
||||
Instruction inst = instFetch(0);
|
||||
@ -889,6 +897,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
#undef INST
|
||||
#undef JIT_USE_COMPUTED_GOTO
|
||||
|
||||
bool runImpl(Stack& stack) {
|
||||
if (!profiling::isProfilingOngoing()) {
|
||||
return runTemplate</*EnableProfiling*/ false>(stack);
|
||||
} else {
|
||||
return runTemplate</*EnableProfiling*/ true>(stack);
|
||||
}
|
||||
}
|
||||
|
||||
void formatStackTrace(std::ostream& out) {
|
||||
format_stack_trace(out, callstack());
|
||||
}
|
||||
|
@ -109,22 +109,18 @@ const auto C10_UNUSED torchBindInitializer = initBindings();
|
||||
namespace profiling {
|
||||
|
||||
InstructionSpan::InstructionSpan(Node& node) {
|
||||
if (getProfilesRegistry().empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
|
||||
}
|
||||
|
||||
InstructionSpan::~InstructionSpan() {
|
||||
if (!datapoint_) {
|
||||
return;
|
||||
}
|
||||
|
||||
datapoint_->end = std::chrono::steady_clock::now();
|
||||
getProfilesRegistry().send(std::move(datapoint_));
|
||||
}
|
||||
|
||||
bool isProfilingOngoing() {
|
||||
return !getProfilesRegistry().empty();
|
||||
}
|
||||
|
||||
} // namespace profiling
|
||||
|
||||
void ScriptProfile::enable() {
|
||||
|
@ -33,6 +33,8 @@ class TORCH_API InstructionSpan {
|
||||
std::unique_ptr<Datapoint> datapoint_;
|
||||
};
|
||||
|
||||
bool TORCH_API isProfilingOngoing();
|
||||
|
||||
} // namespace profiling
|
||||
|
||||
struct TORCH_API InstructionStats : public CustomClassHolder {
|
||||
@ -72,6 +74,8 @@ class TORCH_API SourceStats : public CustomClassHolder {
|
||||
* scriptProfile.disable();
|
||||
* ...
|
||||
*
|
||||
* NOTE: you cannot attach the profiler while the script is running.
|
||||
*
|
||||
* To retrieve collected runtime data, users may call dumpStats() and do
|
||||
* arbitrary filtering on the data they want. Note that dumpStats() should
|
||||
* not be called inside a profiling section.
|
||||
|
Reference in New Issue
Block a user