[PyPer][ET] Refactor EG to ET (#99694)

Summary:
Rename the profiler's execution graph (EG) to execution trace (ET): the observer class, Python/C++ APIs, source files, tests, and trace node names are updated accordingly.
See post: https://fb.workplace.com/groups/873291503156329/permalink/1529496217535851/
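
For reference, the renamed user-facing API keeps the same shape as before. A minimal usage sketch with the new names (the output file name is hypothetical; the call sequence mirrors the updated tests below):

import torch
from torch.profiler import ExecutionTraceObserver, record_function

et = ExecutionTraceObserver()
et.register_callback("pytorch.et.json")  # hypothetical output path; the tests use a ".et.json" suffix
et.start()
for idx in range(5):
    with record_function(f"## LOOP {idx} ##"):
        x = torch.randn(3, 4, requires_grad=True)
        (x * x).sum().backward()  # small autograd workload to capture
et.stop()
print(et.get_output_file_path())  # same path that was passed to register_callback()
et.unregister_callback()  # finalizes and closes the JSON trace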

Test Plan: Run a job.

Reviewed By: chaekit

Differential Revision: D44121392

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99694
Approved by: https://github.com/chaekit
Louis Feng (2023-06-22 19:41:54 +00:00), committed by PyTorch MergeBot
parent ec922efe3b
commit 5847cb55e4
9 changed files with 188 additions and 182 deletions

View File

@@ -139,7 +139,7 @@ libtorch_profiler_sources = [
"torch/csrc/profiler/kineto_client_interface.cpp",
"torch/csrc/profiler/orchestration/observer.cpp",
"torch/csrc/profiler/orchestration/python_tracer.cpp",
"torch/csrc/profiler/standalone/execution_graph_observer.cpp",
"torch/csrc/profiler/standalone/execution_trace_observer.cpp",
"torch/csrc/profiler/standalone/itt_observer.cpp",
"torch/csrc/profiler/standalone/nvtx_observer.cpp",
"torch/csrc/profiler/stubs/base.cpp",

View File

@@ -32,7 +32,7 @@ from torch.autograd.profiler_legacy import profile as _profile_legacy
from torch.profiler import (
_utils,
DeviceType,
ExecutionGraphObserver,
ExecutionTraceObserver,
kineto_available,
profile,
ProfilerAction,
@@ -314,7 +314,7 @@ class TestRecordFunction(TestCase):
self.assertTrue(has_child)
class TestExecutionGraph(TestCase):
class TestExecutionTrace(TestCase):
def payload(self, use_cuda=False):
u = torch.randn(3, 4, 5, requires_grad=True)
with record_function("## TEST 1 ##", "1, 2, 3"):
@@ -338,16 +338,16 @@ class TestExecutionGraph(TestCase):
z = z.cpu()
_record_function_with_args_exit(rf_handle)
def get_execution_graph_root(self, output_file_name):
def get_execution_trace_root(self, output_file_name):
nodes = []
with open(output_file_name, 'r') as f:
eg_graph = json.load(f)
assert "nodes" in eg_graph
nodes = eg_graph["nodes"]
et_graph = json.load(f)
assert "nodes" in et_graph
nodes = et_graph["nodes"]
return nodes
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_execution_graph_with_kineto(self):
def test_execution_trace_with_kineto(self):
trace_called_num = 0
def trace_handler(p):
@@ -355,12 +355,12 @@ class TestExecutionGraph(TestCase):
trace_called_num += 1
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
# Create a temp file to save execution graph data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
fp.close()
expected_loop_events = 0
eg = ExecutionGraphObserver()
eg.register_callback(fp.name)
et = ExecutionTraceObserver()
et.register_callback(fp.name)
with profile(
activities=supported_activities(),
schedule=torch.profiler.schedule(
@@ -370,50 +370,50 @@ class TestExecutionGraph(TestCase):
active=2),
on_trace_ready=trace_handler,
) as p:
eg.start()
et.start()
for idx in range(10):
expected_loop_events += 1
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
p.step()
eg.stop()
et.stop()
assert trace_called_num == 2
assert fp.name == eg.get_output_file_path()
assert fp.name == et.get_output_file_path()
# cleanup
eg.unregister_callback()
nodes = self.get_execution_graph_root(fp.name)
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
loop_count = 0
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_graph|process]" in n["name"]:
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
if n["name"].startswith("## LOOP "):
loop_count += 1
assert found_root_node
assert loop_count == expected_loop_events
def test_execution_graph_alone(self):
def test_execution_trace_alone(self):
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
# Create a temp file to save execution graph data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
fp.close()
expected_loop_events = 0
eg = ExecutionGraphObserver()
eg.register_callback(fp.name)
eg.start()
et = ExecutionTraceObserver()
et.register_callback(fp.name)
et.start()
for idx in range(5):
expected_loop_events += 1
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
eg.stop()
et.stop()
assert fp.name == eg.get_output_file_path()
eg.unregister_callback()
nodes = self.get_execution_graph_root(fp.name)
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
loop_count = 0
# Expected tensor object tuple size, in the form of:
# [tensor_id, storage_id, offset, numel, itemsize, device_str]
@@ -421,7 +421,7 @@ class TestExecutionGraph(TestCase):
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_graph|process]" in n["name"]:
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
if n["name"].startswith("## LOOP "):
loop_count += 1
@@ -431,69 +431,69 @@ class TestExecutionGraph(TestCase):
assert found_root_node
assert loop_count == expected_loop_events
def test_execution_graph_start_stop(self):
def test_execution_trace_start_stop(self):
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
# Create a temp file to save execution graph data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
fp.close()
expected_loop_events = 0
eg = ExecutionGraphObserver()
eg.register_callback(fp.name)
et = ExecutionTraceObserver()
et.register_callback(fp.name)
for idx in range(10):
if idx == 3:
eg.start()
et.start()
elif idx == 5:
eg.stop()
et.stop()
elif idx == 8:
eg.start()
et.start()
elif idx == 9:
eg.stop()
if eg._execution_graph_running:
et.stop()
if et._execution_trace_running:
expected_loop_events += 1
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
assert fp.name == eg.get_output_file_path()
eg.unregister_callback()
nodes = self.get_execution_graph_root(fp.name)
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
loop_count = 0
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_graph|process]" in n["name"]:
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
if n["name"].startswith("## LOOP "):
loop_count += 1
assert found_root_node
assert loop_count == expected_loop_events
def test_execution_graph_repeat_in_loop(self):
def test_execution_trace_repeat_in_loop(self):
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
iter_list = {3, 4, 6, 8}
expected_loop_events = len(iter_list)
output_files = []
for idx in range(10):
if idx in iter_list:
# Create a temp file to save execution graph data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
fp.close()
output_files.append(fp.name)
eg = ExecutionGraphObserver()
eg.register_callback(fp.name)
eg.start()
et = ExecutionTraceObserver()
et.register_callback(fp.name)
et.start()
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
if idx in iter_list:
eg.stop()
eg.unregister_callback()
et.stop()
et.unregister_callback()
event_count = 0
for eg_file in output_files:
nodes = self.get_execution_graph_root(eg_file)
for et_file in output_files:
nodes = self.get_execution_trace_root(et_file)
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_graph|process]" in n["name"]:
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
assert n["id"] == 1
found_root_node = True
if n["name"].startswith("## LOOP "):
@@ -501,18 +501,18 @@ class TestExecutionGraph(TestCase):
assert found_root_node
assert event_count == expected_loop_events
def test_execution_graph_no_capture(self):
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
def test_execution_trace_no_capture(self):
fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
fp.close()
eg = ExecutionGraphObserver()
eg.register_callback(fp.name)
et = ExecutionTraceObserver()
et.register_callback(fp.name)
assert fp.name == eg.get_output_file_path()
eg.unregister_callback()
nodes = self.get_execution_graph_root(fp.name)
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_graph|process]" in n["name"]:
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
assert found_root_node
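
The tests above read the trace back with json.load and look for the root process node. A standalone sketch of the same inspection (the file name is hypothetical):

import json

with open("pytorch.et.json") as f:
    et_graph = json.load(f)
nodes = et_graph["nodes"]
roots = [n for n in nodes if "[pytorch|profiler|execution_trace|process]" in n["name"]]
assert roots and roots[0]["id"] == 1  # the observer writes the root node with id 1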

View File

@@ -210,9 +210,9 @@ class _ExtraFields_PyCall:
class _ExtraFields_Kineto: ...
def _add_execution_graph_observer(output_file_path: str) -> bool: ...
def _remove_execution_graph_observer() -> None: ...
def _enable_execution_graph_observer() -> None: ...
def _disable_execution_graph_observer() -> None: ...
def _add_execution_trace_observer(output_file_path: str) -> bool: ...
def _remove_execution_trace_observer() -> None: ...
def _enable_execution_trace_observer() -> None: ...
def _disable_execution_trace_observer() -> None: ...
def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
def _set_fwd_bwd_enabled_val(val: bool) -> None: ...

View File

@@ -7,7 +7,7 @@
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/python/combined_traceback.h>
#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
@@ -283,20 +283,20 @@ void initPythonBindings(PyObject* module) {
return r.endTimeNS() - r.start_time_ns_;
});
// PyTorch profiler execution graph internal interface.
// PyTorch profiler execution trace internal interface.
m.def(
"_add_execution_graph_observer",
&torch::profiler::impl::addExecutionGraphObserver,
"_add_execution_trace_observer",
&torch::profiler::impl::addExecutionTraceObserver,
py::arg("output_file_name"));
m.def(
"_remove_execution_graph_observer",
&torch::profiler::impl::removeExecutionGraphObserver);
"_remove_execution_trace_observer",
&torch::profiler::impl::removeExecutionTraceObserver);
m.def(
"_enable_execution_graph_observer",
&torch::profiler::impl::enableExecutionGraphObserver);
"_enable_execution_trace_observer",
&torch::profiler::impl::enableExecutionTraceObserver);
m.def(
"_disable_execution_graph_observer",
&torch::profiler::impl::disableExecutionGraphObserver);
"_disable_execution_trace_observer",
&torch::profiler::impl::disableExecutionTraceObserver);
m.def(
"_set_record_concrete_inputs_enabled_val",
&torch::profiler::impl::set_record_concrete_inputs_enabled_val);
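
These bindings back the Python ExecutionTraceObserver wrapper updated later in this diff. A minimal sketch of driving them directly (the file name is hypothetical; in practice the wrapper class is the intended interface):

from torch._C._profiler import (
    _add_execution_trace_observer,
    _disable_execution_trace_observer,
    _enable_execution_trace_observer,
    _remove_execution_trace_observer,
)

# Returns False if the output file cannot be opened.
if _add_execution_trace_observer("raw.et.json"):
    _enable_execution_trace_observer()
    # ... run the workload to capture ...
    _disable_execution_trace_observer()
    _remove_execution_trace_observer()  # writes the root node and closes the file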

View File

@@ -1,25 +0,0 @@
#pragma once
#include <c10/macros/Export.h>
#include <string>
namespace torch {
namespace profiler {
namespace impl {
// Adds the execution graph observer as a global callback function, the data
// will be written to output file path.
TORCH_API bool addExecutionGraphObserver(const std::string& output_file_path);
// Remove the execution graph observer from the global callback functions.
TORCH_API void removeExecutionGraphObserver();
// Enables execution graph observer.
TORCH_API void enableExecutionGraphObserver();
// Disables execution graph observer.
TORCH_API void disableExecutionGraphObserver();
} // namespace impl
} // namespace profiler
} // namespace torch

View File

@@ -27,7 +27,7 @@
#include <ATen/core/stack.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
#include <torch/csrc/profiler/util.h>
using namespace at;
@@ -147,17 +147,17 @@ inline int32_t processId() {
}
//******************************************************************************
// Main ExecutionGraphObserver implementation.
// Main ExecutionTraceObserver implementation.
//******************************************************************************
// ExecutionGraphObserver contains all the states of the observer. Some of them
// ExecutionTraceObserver contains all the states of the observer. Some of them
// are shared between the enter and exit RecordFunction callbacks, and some
// data, like the `op_stack`, may be accessed across different threads, so we
// should be careful about data races. A global mutex `g_mutex` is used to
// avoid these races at the cost of performance when many threads are involved.
// We may optimize this further with thread-local state, fine-grained locking,
// or thread-safe containers.
struct TORCH_API ExecutionGraphObserver {
struct TORCH_API ExecutionTraceObserver {
using ID = size_t;
// Mapping of each thread to its own operator stack
@@ -183,7 +183,7 @@ struct TORCH_API ExecutionGraphObserver {
int32_t pid{-1};
std::string record_time{};
ExecutionGraphObserver() = default;
ExecutionTraceObserver() = default;
// Returns a new unique ID.
ID getNewID() {
@@ -208,7 +208,7 @@ struct TORCH_API ExecutionGraphObserver {
private:
static bool callbackShouldBeEnabled(RunState run_state) {
return run_state == ExecutionGraphObserver::RunState::enabled;
return run_state == ExecutionTraceObserver::RunState::enabled;
}
// Must use accessors to change this so that we can keep the
@@ -224,31 +224,31 @@ struct TORCH_API ExecutionGraphObserver {
};
// Using a singleton manager here to allow initializing and deleting the observer object.
using ObserverManager = GlobalStateManager<ExecutionGraphObserver>;
using ObserverManager = GlobalStateManager<ExecutionTraceObserver>;
// Uninitialized node has id = 0
const ExecutionGraphObserver::ID uninitialized_id{0};
const ExecutionTraceObserver::ID uninitialized_id{0};
// Root node has id = 1
const ExecutionGraphObserver::ID root_id{1};
const ExecutionTraceObserver::ID root_id{1};
struct FunctionCallContext : public ObserverContext {
std::string name;
ExecutionGraphObserver::ID op_id{uninitialized_id};
ExecutionGraphObserver::ID parent_id{uninitialized_id};
ExecutionGraphObserver::ID fw_parent_id{uninitialized_id};
ExecutionTraceObserver::ID op_id{uninitialized_id};
ExecutionTraceObserver::ID parent_id{uninitialized_id};
ExecutionTraceObserver::ID fw_parent_id{uninitialized_id};
std::vector<std::string> input_types;
std::vector<std::string> input_shapes;
std::vector<std::string> input_values;
};
// Opens the json file to write the execution graph.
// Opens the json file to write the execution trace.
static std::ofstream openOutputFile(const std::string& name) {
std::ofstream stream;
stream.open(name, std::ofstream::out | std::ofstream::trunc);
if (!stream) {
LOG(ERROR) << "Failed to open '" << name << "'";
} else {
VLOG(1) << "Writing PyTorch execution graph to: " << name;
VLOG(1) << "PyTorch Execution Trace: writing to " << name;
}
return stream;
}
@@ -302,7 +302,7 @@ inline std::string timeString(const std::time_t timepoint) {
return oss.str();
}
static bool initExecutionGraphStart(ExecutionGraphObserver& ob) {
static bool initExecutionTraceStart(ExecutionTraceObserver& ob) {
ob.out = openOutputFile(ob.file_name);
// If somehow the output stream failed to open, finish the observer here.
if (!ob.out) {
@@ -330,11 +330,11 @@ static bool initExecutionGraphStart(ExecutionGraphObserver& ob) {
return true;
}
// Write out Execution Graph to file
static void finalizeExecutionGraphOutput(ExecutionGraphObserver& ob) {
// Write out Execution Trace to file
static void finalizeExecutionTraceOutput(ExecutionTraceObserver& ob) {
writeJsonNode(
ob.out,
"[pytorch|profiler|execution_graph|process]",
"[pytorch|profiler|execution_trace|process]",
root_id,
0, // rf_id
root_id, // parent is self
@@ -357,15 +357,15 @@ static void finalizeExecutionGraphOutput(ExecutionGraphObserver& ob) {
timestamp);
ob.out.close();
VLOG(1) << "PyTorch execution graph is written to file: " << ob.file_name;
VLOG(1) << "PyTorch Execution Trace: written to file " << ob.file_name;
}
inline ExecutionGraphObserver::ID getObjectID(
ExecutionGraphObserver& ob,
inline ExecutionTraceObserver::ID getObjectID(
ExecutionTraceObserver& ob,
const void* t) {
auto iter = ob.object_id.find(t);
if (iter == ob.object_id.end()) {
ExecutionGraphObserver::ID object_id = ob.getNewID();
ExecutionTraceObserver::ID object_id = ob.getNewID();
ob.object_id[t] = object_id;
return object_id;
}
@@ -374,13 +374,13 @@ inline ExecutionGraphObserver::ID getObjectID(
}
inline std::string convertIValue(
ExecutionGraphObserver& ob,
ExecutionTraceObserver& ob,
const c10::IValue& val,
const size_t maxArrayLen = maxNumElements) {
if (val.isTensor()) {
const auto t = val.toTensor().unsafeGetTensorImpl();
ExecutionGraphObserver::ID tensor_id = getObjectID(ob, t);
ExecutionGraphObserver::ID storage_id = 0;
ExecutionTraceObserver::ID tensor_id = getObjectID(ob, t);
ExecutionTraceObserver::ID storage_id = 0;
size_t offset = 0;
size_t numel = 0;
size_t itemsize = 0;
@@ -427,7 +427,7 @@ inline std::string convertIValue(
}
inline void appendValueInfo(
ExecutionGraphObserver& ob,
ExecutionTraceObserver& ob,
const c10::IValue& val,
std::vector<std::string>& values,
std::vector<std::string>& types,
@@ -438,7 +438,7 @@ inline void appendValueInfo(
}
static void recordOperatorStart(
ExecutionGraphObserver& ob,
ExecutionTraceObserver& ob,
FunctionCallContext& fc,
const RecordFunction& fn) {
auto tid = fn.threadId();
@@ -452,7 +452,7 @@ static void recordOperatorStart(
ob.op_stack[tid].push(thread_node_id);
writeJsonNode(
ob.out,
"[pytorch|profiler|execution_graph|thread]",
"[pytorch|profiler|execution_trace|thread]",
thread_node_id,
0, // rf_id
root_id,
@@ -499,13 +499,13 @@ static void recordOperatorStart(
ob.op_stack[tid].push(fc.op_id);
} catch (const std::exception& e) {
LOG(WARNING) << "Exception in execution graph observer: " << e.what();
LOG(WARNING) << "Exception in execution trace observer: " << e.what();
}
}
static std::unique_ptr<ObserverContext> onFunctionEnter(
const RecordFunction& fn) {
using RunState = ExecutionGraphObserver::RunState;
using RunState = ExecutionTraceObserver::RunState;
auto ob = ObserverManager::get();
if (ob != nullptr && ob->getState() == RunState::enabled) {
// record op
@@ -544,7 +544,7 @@ inline std::string json_str_escape(const std::string& str) {
}
static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
using RunState = ExecutionGraphObserver::RunState;
using RunState = ExecutionTraceObserver::RunState;
auto ob = ObserverManager::get();
if (ob == nullptr || ctx_ptr == nullptr) {
return;
@@ -613,23 +613,23 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
op_schema_str);
ob->out << ",";
} catch (const std::exception& e) {
LOG(WARNING) << "Exception in execution graph observer: [" << fc.name
LOG(WARNING) << "Exception in execution trace observer: [" << fc.name
<< " (" << fc.op_id << ")] " << e.what();
}
}
}
// Add execution graph observer callback functions to the RecordFunction global
// Add execution trace observer callback functions to the RecordFunction global
// observers.
bool addExecutionGraphObserver(const std::string& output_file_path) {
bool addExecutionTraceObserver(const std::string& output_file_path) {
// Check if the observer is already initialized.
if (ObserverManager::get() == nullptr) {
ObserverManager::push(std::make_shared<ExecutionGraphObserver>());
ObserverManager::push(std::make_shared<ExecutionTraceObserver>());
auto& ob = *ObserverManager::get();
ob.pid = processId();
// Set output
ob.file_name = output_file_path;
if (!initExecutionGraphStart(ob)) {
if (!initExecutionTraceStart(ob)) {
return false;
}
@@ -639,60 +639,60 @@ bool addExecutionGraphObserver(const std::string& output_file_path) {
.needsOutputs(true)
.needsIds(true));
// Default to disabled.
ob.setState(ExecutionGraphObserver::RunState::disabled);
ob.setState(ExecutionTraceObserver::RunState::disabled);
VLOG(1) << "Added PyTorch execution graph observer, output="
VLOG(1) << "PyTorch Execution Trace: added observer, output="
<< output_file_path;
} else if (ObserverManager::get()->cb_handle != INVALID_CALLBACK_HANDLE) {
LOG(WARNING) << "Execution graph observer is already registered.";
LOG(WARNING) << "Execution trace observer is already registered.";
}
return true;
}
void removeExecutionGraphObserver() {
void removeExecutionTraceObserver() {
auto ob = ObserverManager::get();
if (ob != nullptr) {
if (ob->getState() != ExecutionGraphObserver::RunState::disabled) {
disableExecutionGraphObserver();
if (ob->getState() != ExecutionTraceObserver::RunState::disabled) {
disableExecutionTraceObserver();
}
if (ob->cb_handle != INVALID_CALLBACK_HANDLE) {
finalizeExecutionGraphOutput(*ob);
finalizeExecutionTraceOutput(*ob);
removeCallback(ob->cb_handle);
ob->cb_handle = INVALID_CALLBACK_HANDLE;
// Release the current EG observer object and reset.
// Release the current ET observer object and reset.
TORCH_INTERNAL_ASSERT(
ObserverManager::pop() != nullptr,
"Global state ptr cannot be null before resetting");
VLOG(1) << "Removed PyTorch execution graph observer";
VLOG(1) << "PyTorch Execution Trace: removed observer";
} else {
LOG(WARNING) << "Execution graph observer was not registered.";
LOG(WARNING) << "Execution trace observer was not registered.";
}
} else {
LOG(WARNING) << "Execution graph observer was not initialized.";
LOG(WARNING) << "Execution trace observer was not initialized.";
}
}
void enableExecutionGraphObserver() {
VLOG(1) << "enableExecutionGraphObserver() ";
void enableExecutionTraceObserver() {
VLOG(1) << "enableExecutionTraceObserver() ";
auto& ob = *ObserverManager::get();
// Make sure we are not already enabled.
if (ob.getState() == ExecutionGraphObserver::RunState::enabled) {
if (ob.getState() == ExecutionTraceObserver::RunState::enabled) {
LOG(WARNING)
<< "Trying to enable Execution Graph Observer when it's already enabled.";
<< "Trying to enable Execution Trace Observer when it's already enabled.";
} else {
ob.setState(ExecutionGraphObserver::RunState::enabled);
ob.setState(ExecutionTraceObserver::RunState::enabled);
}
}
void disableExecutionGraphObserver() {
VLOG(1) << "disableExecutionGraphObserver()";
void disableExecutionTraceObserver() {
VLOG(1) << "disableExecutionTraceObserver()";
auto& ob = *ObserverManager::get();
if (ob.getState() != ExecutionGraphObserver::RunState::disabled) {
ob.setState(ExecutionGraphObserver::RunState::disabled);
if (ob.getState() != ExecutionTraceObserver::RunState::disabled) {
ob.setState(ExecutionTraceObserver::RunState::disabled);
} else {
LOG(WARNING)
<< "Trying to disable Execution Graph Observer when it's already disabled.";
<< "Trying to disable Execution Trace Observer when it's already disabled.";
}
}
} // namespace impl

View File

@@ -0,0 +1,25 @@
#pragma once
#include <c10/macros/Export.h>
#include <string>
namespace torch {
namespace profiler {
namespace impl {
// Adds the execution trace observer as a global callback function; the data
// will be written to the output file path.
TORCH_API bool addExecutionTraceObserver(const std::string& output_file_path);
// Removes the execution trace observer from the global callback functions.
TORCH_API void removeExecutionTraceObserver();
// Enables execution trace observer.
TORCH_API void enableExecutionTraceObserver();
// Disables execution trace observer.
TORCH_API void disableExecutionTraceObserver();
} // namespace impl
} // namespace profiler
} // namespace torch

View File

@@ -16,7 +16,7 @@ from torch.optim.optimizer import register_optimizer_step_post_hook
from .profiler import (
_KinetoProfile,
ExecutionGraphObserver,
ExecutionTraceObserver,
profile,
ProfilerAction,
schedule,
@@ -34,7 +34,7 @@ __all__ = [
"kineto_available",
"DeviceType",
"record_function",
"ExecutionGraphObserver",
"ExecutionTraceObserver",
]
from . import itt

View File

@@ -10,11 +10,11 @@ from warnings import warn
import torch
import torch.autograd.profiler as prof
from torch._C._profiler import (
_add_execution_graph_observer,
_disable_execution_graph_observer,
_enable_execution_graph_observer,
_add_execution_trace_observer,
_disable_execution_trace_observer,
_enable_execution_trace_observer,
_ExperimentalConfig,
_remove_execution_graph_observer,
_remove_execution_trace_observer,
)
from torch.autograd import kineto_available, ProfilerActivity
from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
@@ -26,7 +26,7 @@ __all__ = [
"schedule",
"tensorboard_trace_handler",
"profile",
"ExecutionGraphObserver",
"ExecutionTraceObserver",
]
PROFILER_STEP_NAME = "ProfilerStep"
@@ -601,14 +601,14 @@ class profile(_KinetoProfile):
class ExecutionGraphObserver:
"""Execution Graph Observer
class ExecutionTraceObserver:
"""Execution Trace Observer
Each process can have a single ExecutionGraphObserver instance. The observer
Each process can have a single ExecutionTraceObserver instance. The observer
can be added to record function callbacks by calling register_callback()
explicitly. Without calling unregister_callback(), repeated calls to
register_callback() will not add additional observers to record function
callbacks. Once an ExecutionGraphObserver is created, the start() and stop()
callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
methods control when the event data is recorded.
Deleting or calling unregister_callback() will remove the observer from the
@@ -620,7 +620,7 @@ class ExecutionGraphObserver:
Initializes the default states.
"""
self._registered = False
self._execution_graph_running = False
self._execution_trace_running = False
def __del__(self):
"""
@@ -630,44 +630,50 @@ class ExecutionGraphObserver:
def register_callback(self, output_file_path: str):
"""
Adds EG observer to record function callbacks. The the data will be
Adds ET observer to record function callbacks. The data will be
written to output_file_path.
"""
if not self._registered:
self._output_file_path = output_file_path
self._registered = _add_execution_graph_observer(output_file_path)
self._registered = _add_execution_trace_observer(output_file_path)
def unregister_callback(self):
"""
Removes EG observer from record function callbacks.
Removes ET observer from record function callbacks.
"""
if self._registered:
self.stop()
_remove_execution_graph_observer()
_remove_execution_trace_observer()
self._registered = False
@property
def is_registered(self):
"""
Return if the execution graph observer is registered.
Returns True if the execution trace observer is registered, otherwise False.
"""
return self._registered
def is_running(self):
"""
Returns True if the observer is running, otherwise False.
"""
return self._execution_trace_running
def start(self):
"""
Starts capturing.
"""
if self._registered and not self._execution_graph_running:
_enable_execution_graph_observer()
self._execution_graph_running = True
if self._registered and not self._execution_trace_running:
_enable_execution_trace_observer()
self._execution_trace_running = True
def stop(self):
"""
Stops capturing.
"""
if self._execution_graph_running:
_disable_execution_graph_observer()
self._execution_graph_running = False
if self._execution_trace_running:
_disable_execution_trace_observer()
self._execution_trace_running = False
def get_output_file_path(self) -> str:
"""
@@ -677,6 +683,6 @@ class ExecutionGraphObserver:
return self._output_file_path
else:
raise RuntimeError(
"A callback to the EG profiler needs to be registered "
"A callback to the ET profiler needs to be registered "
"first before getting the output file path"
)
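
Note that this diff also adds an is_running() accessor alongside the existing is_registered property. A short sketch of the resulting state transitions (the file name is hypothetical):

from torch.profiler import ExecutionTraceObserver

et = ExecutionTraceObserver()
et.register_callback("tmp.et.json")
assert et.is_registered       # property
assert not et.is_running()    # method added in this diff
et.start()
assert et.is_running()
et.stop()
assert not et.is_running()
et.unregister_callback()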