Only profiling when it's enabled. (#121404)

Summary:
Profiling, even when disabled, takes up about 1.5% of CPU time for a model I'm looking into.

This patch just splits into with/without profile runs.

The potential downside is that now the script can't enable profiling on itself. That doesn't seem to be used anywhere. If that's a crucial use case, we can do something about it, but ideally we wouldn't.

Test Plan:
Link with profiles:
https://fburl.com/scuba/strobelight_services/ihxsl7pj

```
buck2 run fbcode//caffe2/test/cpp/jit:jit
```

Reviewed By: zhxchen17

Differential Revision: D54066589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121404
Approved by: https://github.com/zhxchen17
This commit is contained in:
Denis Yaroshevskiy
2024-03-08 19:23:09 +00:00
committed by PyTorch MergeBot
parent df06b94778
commit 2c9c57c061
6 changed files with 47 additions and 31 deletions

View File

@ -51,15 +51,16 @@ class TestScriptProfile(JitTestCase):
def test_script(self):
seq = Sequence()
p = torch.jit._ScriptProfile()
p.enable()
@torch.jit.script
def fn():
p = torch.jit._ScriptProfile()
p.enable()
_ = seq(torch.rand((10, 100)))
p.disable()
return p
fn()
p.disable()
self.assertNotEqual(fn().dump_string(), "")
self.assertNotEqual(p.dump_string(), "")
def test_multi(self):
seq = torch.jit.script(Sequence())
@ -82,25 +83,24 @@ class TestScriptProfile(JitTestCase):
seq = Sequence()
@torch.jit.script
def fn():
p = torch.jit._ScriptProfile()
p.enable()
_ = seq(torch.rand((10, 100)))
p.disable()
stats0 = p.dump_string()
def fn(max : int):
_ = seq(torch.rand((10, max)))
_ = seq(torch.rand((10, 10)))
stats1 = p.dump_string()
p = torch.jit._ScriptProfile()
p.enable()
fn(100)
p.disable()
s0 = p.dump_string()
p.enable()
_ = seq(torch.rand((10, 10)))
p.disable()
stats2 = p.dump_string()
fn(10)
p.disable()
s1 = p.dump_string()
p.enable()
return stats0, stats1, stats2
p.enable()
fn(10)
p.disable()
s2 = p.dump_string()
s0, s1, s2 = fn()
self.assertEqual(s0, s1)
self.assertNotEqual(s1, s2)

View File

@ -4,6 +4,7 @@
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/core/thread_pool.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/edge.h>
@ -239,6 +240,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
std::size_t initialSize_{stack_.size()};
};
// Zero-sized stand-in returned by instGuard when profiling is compiled out
// (EnableProfiling == false); takes the place of profiling::InstructionSpan
// so the guard object has no effect and no per-instruction cost.
struct C10_UNUSED DoNothing {};
#if defined(__GNUC__) || defined(__clang__)
#define JIT_USE_COMPUTED_GOTO
#endif
@ -265,7 +268,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
inst = instFetch(1); \
INST_DISPATCH
bool runImpl(Stack& stack) {
template <bool EnableProfiling>
bool runTemplate(Stack& stack) {
// if we have never run before, then we might have to return the
// stack when we suspend, record where it starts so we return the right
// stack
@ -301,8 +305,12 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
};
auto instGuard = [&] {
return profiling::InstructionSpan{
*frame.function->instructions_source()[frame.pc]};
if constexpr (!EnableProfiling) {
return DoNothing{};
} else {
return profiling::InstructionSpan{
*frame.function->instructions_source()[frame.pc]};
}
};
Instruction inst = instFetch(0);
@ -889,6 +897,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
#undef INST
#undef JIT_USE_COMPUTED_GOTO
// Dispatches to the profiling or non-profiling instantiation of runTemplate,
// so the common (non-profiling) path carries none of the per-instruction
// profiling overhead. NOTE: the decision is made once per call — a profiler
// attached while the script is already running is not picked up.
bool runImpl(Stack& stack) {
if (!profiling::isProfilingOngoing()) {
return runTemplate</*EnableProfiling*/ false>(stack);
} else {
return runTemplate</*EnableProfiling*/ true>(stack);
}
}
void formatStackTrace(std::ostream& out) {
format_stack_trace(out, callstack());
}

View File

@ -109,22 +109,18 @@ const auto C10_UNUSED torchBindInitializer = initBindings();
namespace profiling {
// Starts timing one interpreter instruction. If no profiles are registered
// (profiling disabled), datapoint_ is left null so the destructor is a no-op.
InstructionSpan::InstructionSpan(Node& node) {
if (getProfilesRegistry().empty()) {
return;
}
// Capture the instruction's source range; the start timestamp is presumably
// recorded by the Datapoint constructor — confirm against its definition.
datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
}
// Finishes timing the instruction and hands the datapoint to the registry.
// A null datapoint_ means profiling was disabled at construction time.
InstructionSpan::~InstructionSpan() {
if (!datapoint_) {
return;
}
datapoint_->end = std::chrono::steady_clock::now();
getProfilesRegistry().send(std::move(datapoint_));
}
// True when at least one profile is registered, i.e. some _ScriptProfile is
// currently enabled. Used by the interpreter to pick the profiling-free
// runTemplate instantiation when no one is listening.
bool isProfilingOngoing() {
return !getProfilesRegistry().empty();
}
} // namespace profiling
void ScriptProfile::enable() {

View File

@ -33,6 +33,8 @@ class TORCH_API InstructionSpan {
std::unique_ptr<Datapoint> datapoint_;
};
bool TORCH_API isProfilingOngoing();
} // namespace profiling
struct TORCH_API InstructionStats : public CustomClassHolder {
@ -72,6 +74,8 @@ class TORCH_API SourceStats : public CustomClassHolder {
* scriptProfile.disable();
* ...
*
* NOTE: you cannot attach the profiler while the script is running.
*
* To retrieve collected runtime data, users may call dumpStats() and do
* arbitrary filtering on the data they want. Note that dumpStats() should
* not be called inside a profiling section.