mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[PyPer][ET] Refactor EG to ET (#99694)
Summary: Change execution graph to execution trace. See post: https://fb.workplace.com/groups/873291503156329/permalink/1529496217535851/

Test Plan: Run a job.

Reviewed By: chaekit

Differential Revision: D44121392

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99694
Approved by: https://github.com/chaekit
Committed by: PyTorch MergeBot
Parent: ec922efe3b
Commit: 5847cb55e4
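For orientation (not part of the commit itself), here is a minimal sketch of the renamed Python API, pieced together from the tests changed below; the tensor workload and temp-file handling are illustrative:

import tempfile
import torch
from torch.profiler import ExecutionTraceObserver

# Create a temp file for the trace; the ".et.json" suffix follows the tests below.
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()

et = ExecutionTraceObserver()
et.register_callback(fp.name)   # attach observer; trace data is written to fp.name
et.start()                      # begin recording record-function events
for _ in range(5):
    x = torch.randn(3, 4, requires_grad=True)
    (x * x).sum().backward()    # arbitrary payload work to be traced
et.stop()                       # pause recording
assert fp.name == et.get_output_file_path()
et.unregister_callback()        # finalize and close the output JSON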
@@ -139,7 +139,7 @@ libtorch_profiler_sources = [
     "torch/csrc/profiler/kineto_client_interface.cpp",
     "torch/csrc/profiler/orchestration/observer.cpp",
     "torch/csrc/profiler/orchestration/python_tracer.cpp",
-    "torch/csrc/profiler/standalone/execution_graph_observer.cpp",
+    "torch/csrc/profiler/standalone/execution_trace_observer.cpp",
     "torch/csrc/profiler/standalone/itt_observer.cpp",
     "torch/csrc/profiler/standalone/nvtx_observer.cpp",
     "torch/csrc/profiler/stubs/base.cpp",
@@ -32,7 +32,7 @@ from torch.autograd.profiler_legacy import profile as _profile_legacy
 from torch.profiler import (
     _utils,
     DeviceType,
-    ExecutionGraphObserver,
+    ExecutionTraceObserver,
     kineto_available,
     profile,
     ProfilerAction,
@@ -314,7 +314,7 @@ class TestRecordFunction(TestCase):
         self.assertTrue(has_child)


-class TestExecutionGraph(TestCase):
+class TestExecutionTrace(TestCase):
     def payload(self, use_cuda=False):
         u = torch.randn(3, 4, 5, requires_grad=True)
         with record_function("## TEST 1 ##", "1, 2, 3"):
@@ -338,16 +338,16 @@ class TestExecutionGraph(TestCase):
             z = z.cpu()
         _record_function_with_args_exit(rf_handle)

-    def get_execution_graph_root(self, output_file_name):
+    def get_execution_trace_root(self, output_file_name):
         nodes = []
         with open(output_file_name, 'r') as f:
-            eg_graph = json.load(f)
-            assert "nodes" in eg_graph
-            nodes = eg_graph["nodes"]
+            et_graph = json.load(f)
+            assert "nodes" in et_graph
+            nodes = et_graph["nodes"]
         return nodes

     @unittest.skipIf(not kineto_available(), "Kineto is required")
-    def test_execution_graph_with_kineto(self):
+    def test_execution_trace_with_kineto(self):
         trace_called_num = 0

         def trace_handler(p):
@@ -355,12 +355,12 @@ class TestExecutionGraph(TestCase):
             trace_called_num += 1

         use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
-        # Create a temp file to save execution graph data.
-        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
+        # Create a temp file to save execution trace data.
+        fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
         fp.close()
         expected_loop_events = 0
-        eg = ExecutionGraphObserver()
-        eg.register_callback(fp.name)
+        et = ExecutionTraceObserver()
+        et.register_callback(fp.name)
         with profile(
             activities=supported_activities(),
             schedule=torch.profiler.schedule(
@@ -370,50 +370,50 @@ class TestExecutionGraph(TestCase):
                 active=2),
             on_trace_ready=trace_handler,
         ) as p:
-            eg.start()
+            et.start()
             for idx in range(10):
                 expected_loop_events += 1
                 with record_function(f"## LOOP {idx} ##"):
                     self.payload(use_cuda=use_cuda)
                 p.step()
-            eg.stop()
+            et.stop()

         assert trace_called_num == 2
-        assert fp.name == eg.get_output_file_path()
+        assert fp.name == et.get_output_file_path()

         # cleanup
-        eg.unregister_callback()
-        nodes = self.get_execution_graph_root(fp.name)
+        et.unregister_callback()
+        nodes = self.get_execution_trace_root(fp.name)
         loop_count = 0
         found_root_node = False
         for n in nodes:
             assert "name" in n
-            if "[pytorch|profiler|execution_graph|process]" in n["name"]:
+            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                 found_root_node = True
             if n["name"].startswith("## LOOP "):
                 loop_count += 1
         assert found_root_node
         assert loop_count == expected_loop_events

-    def test_execution_graph_alone(self):
+    def test_execution_trace_alone(self):
         use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
-        # Create a temp file to save execution graph data.
-        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
+        # Create a temp file to save execution trace data.
+        fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
         fp.close()
         expected_loop_events = 0

-        eg = ExecutionGraphObserver()
-        eg.register_callback(fp.name)
-        eg.start()
+        et = ExecutionTraceObserver()
+        et.register_callback(fp.name)
+        et.start()
         for idx in range(5):
             expected_loop_events += 1
             with record_function(f"## LOOP {idx} ##"):
                 self.payload(use_cuda=use_cuda)
-        eg.stop()
+        et.stop()

-        assert fp.name == eg.get_output_file_path()
-        eg.unregister_callback()
-        nodes = self.get_execution_graph_root(fp.name)
+        assert fp.name == et.get_output_file_path()
+        et.unregister_callback()
+        nodes = self.get_execution_trace_root(fp.name)
         loop_count = 0
         # Expected tensor object tuple size, in th form of:
         # [tensor_id, storage_id, offset, numel, itemsize, device_str]
@@ -421,7 +421,7 @@ class TestExecutionGraph(TestCase):
         found_root_node = False
         for n in nodes:
             assert "name" in n
-            if "[pytorch|profiler|execution_graph|process]" in n["name"]:
+            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                 found_root_node = True
             if n["name"].startswith("## LOOP "):
                 loop_count += 1
@@ -431,69 +431,69 @@ class TestExecutionGraph(TestCase):
         assert found_root_node
         assert loop_count == expected_loop_events

-    def test_execution_graph_start_stop(self):
+    def test_execution_trace_start_stop(self):
         use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
-        # Create a temp file to save execution graph data.
-        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
+        # Create a temp file to save execution trace data.
+        fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
         fp.close()
         expected_loop_events = 0
-        eg = ExecutionGraphObserver()
-        eg.register_callback(fp.name)
+        et = ExecutionTraceObserver()
+        et.register_callback(fp.name)
         for idx in range(10):
             if idx == 3:
-                eg.start()
+                et.start()
             elif idx == 5:
-                eg.stop()
+                et.stop()
             elif idx == 8:
-                eg.start()
+                et.start()
             elif idx == 9:
-                eg.stop()
-            if eg._execution_graph_running:
+                et.stop()
+            if et._execution_trace_running:
                 expected_loop_events += 1
             with record_function(f"## LOOP {idx} ##"):
                 self.payload(use_cuda=use_cuda)

-        assert fp.name == eg.get_output_file_path()
-        eg.unregister_callback()
-        nodes = self.get_execution_graph_root(fp.name)
+        assert fp.name == et.get_output_file_path()
+        et.unregister_callback()
+        nodes = self.get_execution_trace_root(fp.name)
         loop_count = 0
         found_root_node = False
         for n in nodes:
             assert "name" in n
-            if "[pytorch|profiler|execution_graph|process]" in n["name"]:
+            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                 found_root_node = True
             if n["name"].startswith("## LOOP "):
                 loop_count += 1
         assert found_root_node
         assert loop_count == expected_loop_events

-    def test_execution_graph_repeat_in_loop(self):
+    def test_execution_trace_repeat_in_loop(self):
         use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
         iter_list = {3, 4, 6, 8}
         expected_loop_events = len(iter_list)
         output_files = []
         for idx in range(10):
             if idx in iter_list:
-                # Create a temp file to save execution graph data.
-                fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
+                # Create a temp file to save execution trace data.
+                fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
                 fp.close()
                 output_files.append(fp.name)
-                eg = ExecutionGraphObserver()
-                eg.register_callback(fp.name)
-                eg.start()
+                et = ExecutionTraceObserver()
+                et.register_callback(fp.name)
+                et.start()
             with record_function(f"## LOOP {idx} ##"):
                 self.payload(use_cuda=use_cuda)
             if idx in iter_list:
-                eg.stop()
-                eg.unregister_callback()
+                et.stop()
+                et.unregister_callback()

         event_count = 0
-        for eg_file in output_files:
-            nodes = self.get_execution_graph_root(eg_file)
+        for et_file in output_files:
+            nodes = self.get_execution_trace_root(et_file)
             found_root_node = False
             for n in nodes:
                 assert "name" in n
-                if "[pytorch|profiler|execution_graph|process]" in n["name"]:
+                if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                     assert n["id"] == 1
                     found_root_node = True
                 if n["name"].startswith("## LOOP "):
@@ -501,18 +501,18 @@ class TestExecutionGraph(TestCase):
             assert found_root_node
         assert event_count == expected_loop_events

-    def test_execution_graph_no_capture(self):
-        fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
+    def test_execution_trace_no_capture(self):
+        fp = tempfile.NamedTemporaryFile('w+t', suffix='.et.json', delete=False)
         fp.close()
-        eg = ExecutionGraphObserver()
-        eg.register_callback(fp.name)
+        et = ExecutionTraceObserver()
+        et.register_callback(fp.name)

-        assert fp.name == eg.get_output_file_path()
-        eg.unregister_callback()
-        nodes = self.get_execution_graph_root(fp.name)
+        assert fp.name == et.get_output_file_path()
+        et.unregister_callback()
+        nodes = self.get_execution_trace_root(fp.name)
         for n in nodes:
             assert "name" in n
-            if "[pytorch|profiler|execution_graph|process]" in n["name"]:
+            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                 found_root_node = True
         assert found_root_node
@@ -210,9 +210,9 @@ class _ExtraFields_PyCall:

 class _ExtraFields_Kineto: ...

-def _add_execution_graph_observer(output_file_path: str) -> bool: ...
-def _remove_execution_graph_observer() -> None: ...
-def _enable_execution_graph_observer() -> None: ...
-def _disable_execution_graph_observer() -> None: ...
+def _add_execution_trace_observer(output_file_path: str) -> bool: ...
+def _remove_execution_trace_observer() -> None: ...
+def _enable_execution_trace_observer() -> None: ...
+def _disable_execution_trace_observer() -> None: ...
 def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
 def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
@@ -7,7 +7,7 @@
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/profiler/collection.h>
 #include <torch/csrc/profiler/python/combined_traceback.h>
-#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
+#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
 #include <torch/csrc/utils/pybind.h>

 namespace torch {
@@ -283,20 +283,20 @@ void initPythonBindings(PyObject* module) {
         return r.endTimeNS() - r.start_time_ns_;
       });

-  // PyTorch profiler execution graph internal interface.
+  // PyTorch profiler execution trace internal interface.
   m.def(
-      "_add_execution_graph_observer",
-      &torch::profiler::impl::addExecutionGraphObserver,
+      "_add_execution_trace_observer",
+      &torch::profiler::impl::addExecutionTraceObserver,
       py::arg("output_file_name"));
   m.def(
-      "_remove_execution_graph_observer",
-      &torch::profiler::impl::removeExecutionGraphObserver);
+      "_remove_execution_trace_observer",
+      &torch::profiler::impl::removeExecutionTraceObserver);
   m.def(
-      "_enable_execution_graph_observer",
-      &torch::profiler::impl::enableExecutionGraphObserver);
+      "_enable_execution_trace_observer",
+      &torch::profiler::impl::enableExecutionTraceObserver);
   m.def(
-      "_disable_execution_graph_observer",
-      &torch::profiler::impl::disableExecutionGraphObserver);
+      "_disable_execution_trace_observer",
+      &torch::profiler::impl::disableExecutionTraceObserver);
   m.def(
       "_set_record_concrete_inputs_enabled_val",
       &torch::profiler::impl::set_record_concrete_inputs_enabled_val);
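The private bindings registered above surface in Python under torch._C._profiler; the public ExecutionTraceObserver class is a thin wrapper over them, roughly as sketched here (the output path is hypothetical, and these underscore-prefixed APIs are internal):

from torch._C._profiler import (
    _add_execution_trace_observer,
    _disable_execution_trace_observer,
    _enable_execution_trace_observer,
    _remove_execution_trace_observer,
)

# Lifecycle mirrored by register_callback()/start()/stop()/unregister_callback():
assert _add_execution_trace_observer("/tmp/trace.et.json")  # register callback, open file
_enable_execution_trace_observer()   # start recording
_disable_execution_trace_observer()  # stop recording
_remove_execution_trace_observer()   # finalize the JSON and unregister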
torch/csrc/profiler/standalone/execution_graph_observer.h (deleted file)
@@ -1,25 +0,0 @@
-#pragma once
-
-#include <c10/macros/Export.h>
-#include <string>
-
-namespace torch {
-namespace profiler {
-namespace impl {
-
-// Adds the execution graph observer as a global callback function, the data
-// will be written to output file path.
-TORCH_API bool addExecutionGraphObserver(const std::string& output_file_path);
-
-// Remove the execution graph observer from the global callback functions.
-TORCH_API void removeExecutionGraphObserver();
-
-// Enables execution graph observer.
-TORCH_API void enableExecutionGraphObserver();
-
-// Disables execution graph observer.
-TORCH_API void disableExecutionGraphObserver();
-
-} // namespace impl
-} // namespace profiler
-} // namespace torch
@@ -27,7 +27,7 @@
 #include <ATen/core/stack.h>
 #include <ATen/record_function.h>
 #include <c10/util/irange.h>
-#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
+#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
 #include <torch/csrc/profiler/util.h>

 using namespace at;
@@ -147,17 +147,17 @@ inline int32_t processId() {
 }

 //******************************************************************************
-// Main ExecutionGraphObserver implementation.
+// Main ExecutionTraceObserver implementation.
 //******************************************************************************

-// ExecutionGraphObserver contains all the states of the observer. Some of them
+// ExecutionTraceObserver contains all the states of the observer. Some of them
 // are shared between the enter and exit RecordFunction call backs, some data
 // like the `op_stack` may be accessed across different threads. So we should be
 // careful about data races. A global mutex `g_mutex` is used avoid these races
 // at the cost of performance in large number of threads situations. We may
 // optimize this further to thread local, fine-grained locking, or use thread
 // safe containers.
-struct TORCH_API ExecutionGraphObserver {
+struct TORCH_API ExecutionTraceObserver {
   using ID = size_t;

   // Mapping of each thread to its own operator stack
@@ -183,7 +183,7 @@ struct TORCH_API ExecutionGraphObserver {
   int32_t pid{-1};
   std::string record_time{};

-  ExecutionGraphObserver() = default;
+  ExecutionTraceObserver() = default;

   // Returns a new unique ID.
   ID getNewID() {
@@ -208,7 +208,7 @@ struct TORCH_API ExecutionGraphObserver {

  private:
   static bool callbackShouldBeEnabled(RunState run_state) {
-    return run_state == ExecutionGraphObserver::RunState::enabled;
+    return run_state == ExecutionTraceObserver::RunState::enabled;
   }

   // Must use accessors to change this so that we can keep the
@@ -224,31 +224,31 @@ struct TORCH_API ExecutionGraphObserver {
 };

 // Using a singleton manager here to allow init and delete the observer object.
-using ObserverManager = GlobalStateManager<ExecutionGraphObserver>;
+using ObserverManager = GlobalStateManager<ExecutionTraceObserver>;

 // Uninitialized node has id = 0
-const ExecutionGraphObserver::ID uninitialized_id{0};
+const ExecutionTraceObserver::ID uninitialized_id{0};
 // Root node has id = 1
-const ExecutionGraphObserver::ID root_id{1};
+const ExecutionTraceObserver::ID root_id{1};

 struct FunctionCallContext : public ObserverContext {
   std::string name;
-  ExecutionGraphObserver::ID op_id{uninitialized_id};
-  ExecutionGraphObserver::ID parent_id{uninitialized_id};
-  ExecutionGraphObserver::ID fw_parent_id{uninitialized_id};
+  ExecutionTraceObserver::ID op_id{uninitialized_id};
+  ExecutionTraceObserver::ID parent_id{uninitialized_id};
+  ExecutionTraceObserver::ID fw_parent_id{uninitialized_id};
   std::vector<std::string> input_types;
   std::vector<std::string> input_shapes;
   std::vector<std::string> input_values;
 };

-// Opens the json file to write the execution graph.
+// Opens the json file to write the execution trace.
 static std::ofstream openOutputFile(const std::string& name) {
   std::ofstream stream;
   stream.open(name, std::ofstream::out | std::ofstream::trunc);
   if (!stream) {
     LOG(ERROR) << "Failed to open '" << name << "'";
   } else {
-    VLOG(1) << "Writing PyTorch execution graph to: " << name;
+    VLOG(1) << "PyTorch Execution Trace: writing to " << name;
   }
   return stream;
 }
@@ -302,7 +302,7 @@ inline std::string timeString(const std::time_t timepoint) {
   return oss.str();
 }

-static bool initExecutionGraphStart(ExecutionGraphObserver& ob) {
+static bool initExecutionTraceStart(ExecutionTraceObserver& ob) {
   ob.out = openOutputFile(ob.file_name);
   // If somehow the output stream failed to open, finish observer here.
   if (!ob.out) {
@@ -330,11 +330,11 @@ static bool initExecutionGraphStart(ExecutionGraphObserver& ob) {
   return true;
 }

-// Write out Execution Graph to file
-static void finalizeExecutionGraphOutput(ExecutionGraphObserver& ob) {
+// Write out Execution Trace to file
+static void finalizeExecutionTraceOutput(ExecutionTraceObserver& ob) {
   writeJsonNode(
       ob.out,
-      "[pytorch|profiler|execution_graph|process]",
+      "[pytorch|profiler|execution_trace|process]",
       root_id,
       0, // rf_id
       root_id, // parent is self
@@ -357,15 +357,15 @@ static void finalizeExecutionGraphOutput(ExecutionGraphObserver& ob) {
       timestamp);

   ob.out.close();
-  VLOG(1) << "PyTorch execution graph is written to file: " << ob.file_name;
+  VLOG(1) << "PyTorch Execution Trace: written to file " << ob.file_name;
 }

-inline ExecutionGraphObserver::ID getObjectID(
-    ExecutionGraphObserver& ob,
+inline ExecutionTraceObserver::ID getObjectID(
+    ExecutionTraceObserver& ob,
     const void* t) {
   auto iter = ob.object_id.find(t);
   if (iter == ob.object_id.end()) {
-    ExecutionGraphObserver::ID object_id = ob.getNewID();
+    ExecutionTraceObserver::ID object_id = ob.getNewID();
     ob.object_id[t] = object_id;
     return object_id;
   }
@@ -374,13 +374,13 @@ inline ExecutionGraphObserver::ID getObjectID(
 }

 inline std::string convertIValue(
-    ExecutionGraphObserver& ob,
+    ExecutionTraceObserver& ob,
     const c10::IValue& val,
     const size_t maxArrayLen = maxNumElements) {
   if (val.isTensor()) {
     const auto t = val.toTensor().unsafeGetTensorImpl();
-    ExecutionGraphObserver::ID tensor_id = getObjectID(ob, t);
-    ExecutionGraphObserver::ID storage_id = 0;
+    ExecutionTraceObserver::ID tensor_id = getObjectID(ob, t);
+    ExecutionTraceObserver::ID storage_id = 0;
     size_t offset = 0;
     size_t numel = 0;
     size_t itemsize = 0;
@@ -427,7 +427,7 @@ inline std::string convertIValue(
 }

 inline void appendValueInfo(
-    ExecutionGraphObserver& ob,
+    ExecutionTraceObserver& ob,
     const c10::IValue& val,
     std::vector<std::string>& values,
     std::vector<std::string>& types,
@@ -438,7 +438,7 @@ inline void appendValueInfo(
 }

 static void recordOperatorStart(
-    ExecutionGraphObserver& ob,
+    ExecutionTraceObserver& ob,
     FunctionCallContext& fc,
     const RecordFunction& fn) {
   auto tid = fn.threadId();
@@ -452,7 +452,7 @@ static void recordOperatorStart(
   ob.op_stack[tid].push(thread_node_id);
   writeJsonNode(
       ob.out,
-      "[pytorch|profiler|execution_graph|thread]",
+      "[pytorch|profiler|execution_trace|thread]",
       thread_node_id,
       0, // rf_id
       root_id,
@@ -499,13 +499,13 @@ static void recordOperatorStart(
     ob.op_stack[tid].push(fc.op_id);

   } catch (const std::exception& e) {
-    LOG(WARNING) << "Exception in execution graph observer: " << e.what();
+    LOG(WARNING) << "Exception in execution trace observer: " << e.what();
   }
 }

 static std::unique_ptr<ObserverContext> onFunctionEnter(
     const RecordFunction& fn) {
-  using RunState = ExecutionGraphObserver::RunState;
+  using RunState = ExecutionTraceObserver::RunState;
   auto ob = ObserverManager::get();
   if (ob != nullptr && ob->getState() == RunState::enabled) {
     // record op
@@ -544,7 +544,7 @@ inline std::string json_str_escape(const std::string& str) {
 }

 static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
-  using RunState = ExecutionGraphObserver::RunState;
+  using RunState = ExecutionTraceObserver::RunState;
   auto ob = ObserverManager::get();
   if (ob == nullptr || ctx_ptr == nullptr) {
     return;
@@ -613,23 +613,23 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
          op_schema_str);
      ob->out << ",";
    } catch (const std::exception& e) {
-      LOG(WARNING) << "Exception in execution graph observer: [" << fc.name
+      LOG(WARNING) << "Exception in execution trace observer: [" << fc.name
                    << " (" << fc.op_id << ")] " << e.what();
    }
  }
 }

-// Add execution graph observer callback functions to the RecordFunction global
+// Add execution trace observer callback functions to the RecordFunction global
 // observers.
-bool addExecutionGraphObserver(const std::string& output_file_path) {
+bool addExecutionTraceObserver(const std::string& output_file_path) {
   // Check if the observer is already initialized.
   if (ObserverManager::get() == nullptr) {
-    ObserverManager::push(std::make_shared<ExecutionGraphObserver>());
+    ObserverManager::push(std::make_shared<ExecutionTraceObserver>());
     auto& ob = *ObserverManager::get();
     ob.pid = processId();
     // Set output
     ob.file_name = output_file_path;
-    if (!initExecutionGraphStart(ob)) {
+    if (!initExecutionTraceStart(ob)) {
       return false;
     }

@@ -639,60 +639,60 @@ bool addExecutionGraphObserver(const std::string& output_file_path) {
         .needsOutputs(true)
         .needsIds(true));
     // Default to disabled.
-    ob.setState(ExecutionGraphObserver::RunState::disabled);
+    ob.setState(ExecutionTraceObserver::RunState::disabled);

-    VLOG(1) << "Added PyTorch execution graph observer, output="
+    VLOG(1) << "PyTorch Execution Trace: added observer, output="
             << output_file_path;
   } else if (ObserverManager::get()->cb_handle != INVALID_CALLBACK_HANDLE) {
-    LOG(WARNING) << "Execution graph observer is already registered.";
+    LOG(WARNING) << "Execution trace observer is already registered.";
   }
   return true;
 }

-void removeExecutionGraphObserver() {
+void removeExecutionTraceObserver() {
   auto ob = ObserverManager::get();
   if (ob != nullptr) {
-    if (ob->getState() != ExecutionGraphObserver::RunState::disabled) {
-      disableExecutionGraphObserver();
+    if (ob->getState() != ExecutionTraceObserver::RunState::disabled) {
+      disableExecutionTraceObserver();
     }

     if (ob->cb_handle != INVALID_CALLBACK_HANDLE) {
-      finalizeExecutionGraphOutput(*ob);
+      finalizeExecutionTraceOutput(*ob);
       removeCallback(ob->cb_handle);
       ob->cb_handle = INVALID_CALLBACK_HANDLE;
-      // Release the current EG observer object and reset.
+      // Release the current ET observer object and reset.
       TORCH_INTERNAL_ASSERT(
           ObserverManager::pop() != nullptr,
           "Global state ptr cannot be null before resetting");
-      VLOG(1) << "Removed PyTorch execution graph observer";
+      VLOG(1) << "PyTorch Execution Trace: removed observer";
     } else {
-      LOG(WARNING) << "Execution graph observer was not registered.";
+      LOG(WARNING) << "Execution trace observer was not registered.";
     }
   } else {
-    LOG(WARNING) << "Execution graph observer was not initialized.";
+    LOG(WARNING) << "Execution trace observer was not initialized.";
   }
 }

-void enableExecutionGraphObserver() {
-  VLOG(1) << "enableExecutionGraphObserver() ";
+void enableExecutionTraceObserver() {
+  VLOG(1) << "enableExecutionTraceObserver() ";
   auto& ob = *ObserverManager::get();
   // Make sure we are not already enabled.
-  if (ob.getState() == ExecutionGraphObserver::RunState::enabled) {
+  if (ob.getState() == ExecutionTraceObserver::RunState::enabled) {
     LOG(WARNING)
-        << "Trying to enable Execution Graph Observer when it's already enabled.";
+        << "Trying to enable Execution Trace Observer when it's already enabled.";
   } else {
-    ob.setState(ExecutionGraphObserver::RunState::enabled);
+    ob.setState(ExecutionTraceObserver::RunState::enabled);
   }
 }

-void disableExecutionGraphObserver() {
-  VLOG(1) << "disableExecutionGraphObserver()";
+void disableExecutionTraceObserver() {
+  VLOG(1) << "disableExecutionTraceObserver()";
   auto& ob = *ObserverManager::get();
-  if (ob.getState() != ExecutionGraphObserver::RunState::disabled) {
-    ob.setState(ExecutionGraphObserver::RunState::disabled);
+  if (ob.getState() != ExecutionTraceObserver::RunState::disabled) {
+    ob.setState(ExecutionTraceObserver::RunState::disabled);
   } else {
     LOG(WARNING)
-        << "Trying to disable Execution Graph Observer when it's already disabled.";
+        << "Trying to disable Execution Trace Observer when it's already disabled.";
   }
 }
 } // namespace impl
torch/csrc/profiler/standalone/execution_trace_observer.h (new file, 25 lines)
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+#include <string>
+
+namespace torch {
+namespace profiler {
+namespace impl {
+
+// Adds the execution trace observer as a global callback function, the data
+// will be written to output file path.
+TORCH_API bool addExecutionTraceObserver(const std::string& output_file_path);
+
+// Remove the execution trace observer from the global callback functions.
+TORCH_API void removeExecutionTraceObserver();
+
+// Enables execution trace observer.
+TORCH_API void enableExecutionTraceObserver();
+
+// Disables execution trace observer.
+TORCH_API void disableExecutionTraceObserver();
+
+} // namespace impl
+} // namespace profiler
+} // namespace torch
@@ -16,7 +16,7 @@ from torch.optim.optimizer import register_optimizer_step_post_hook

 from .profiler import (
     _KinetoProfile,
-    ExecutionGraphObserver,
+    ExecutionTraceObserver,
     profile,
     ProfilerAction,
     schedule,
@@ -34,7 +34,7 @@ __all__ = [
     "kineto_available",
     "DeviceType",
     "record_function",
-    "ExecutionGraphObserver",
+    "ExecutionTraceObserver",
 ]

 from . import itt
@@ -10,11 +10,11 @@ from warnings import warn
 import torch
 import torch.autograd.profiler as prof
 from torch._C._profiler import (
-    _add_execution_graph_observer,
-    _disable_execution_graph_observer,
-    _enable_execution_graph_observer,
+    _add_execution_trace_observer,
+    _disable_execution_trace_observer,
+    _enable_execution_trace_observer,
     _ExperimentalConfig,
-    _remove_execution_graph_observer,
+    _remove_execution_trace_observer,
 )
 from torch.autograd import kineto_available, ProfilerActivity
 from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
@@ -26,7 +26,7 @@ __all__ = [
     "schedule",
     "tensorboard_trace_handler",
     "profile",
-    "ExecutionGraphObserver",
+    "ExecutionTraceObserver",
 ]
 PROFILER_STEP_NAME = "ProfilerStep"
@@ -601,14 +601,14 @@ class profile(_KinetoProfile):



-class ExecutionGraphObserver:
-    """Execution Graph Observer
+class ExecutionTraceObserver:
+    """Execution Trace Observer

-    Each process can have a single ExecutionGraphObserver instance. The observer
+    Each process can have a single ExecutionTraceObserver instance. The observer
     can be added to record function callbacks via calling register_callback()
     explicitly. Without calling unregister_callback(), repeated calls to
     register_callback() will not add additional observers to record function
-    callbacks. Once an ExecutionGraphObserver is created, the start() and stop()
+    callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
     methods control when the event data is recorded.

     Deleting or calling unregister_callback() will remove the observer from the
@@ -620,7 +620,7 @@ class ExecutionGraphObserver:
         Initializes the default states.
         """
         self._registered = False
-        self._execution_graph_running = False
+        self._execution_trace_running = False

     def __del__(self):
         """
@@ -630,44 +630,50 @@ class ExecutionGraphObserver:

     def register_callback(self, output_file_path: str):
         """
-        Adds EG observer to record function callbacks. The the data will be
+        Adds ET observer to record function callbacks. The the data will be
         written to output_file_path.
         """
         if not self._registered:
             self._output_file_path = output_file_path
-            self._registered = _add_execution_graph_observer(output_file_path)
+            self._registered = _add_execution_trace_observer(output_file_path)

     def unregister_callback(self):
         """
-        Removes EG observer from record function callbacks.
+        Removes ET observer from record function callbacks.
         """
         if self._registered:
             self.stop()
-            _remove_execution_graph_observer()
+            _remove_execution_trace_observer()
             self._registered = False

     @property
     def is_registered(self):
         """
-        Return if the execution graph observer is registered.
+        Returns True if the execution trace observer is registered, otherwise False.
         """
         return self._registered

+    def is_running(self):
+        """
+        Returns True if the observer is running, otherwise False.
+        """
+        return self._execution_trace_running
+
     def start(self):
         """
         Starts to capture.
         """
-        if self._registered and not self._execution_graph_running:
-            _enable_execution_graph_observer()
-            self._execution_graph_running = True
+        if self._registered and not self._execution_trace_running:
+            _enable_execution_trace_observer()
+            self._execution_trace_running = True

     def stop(self):
         """
         Stops to capture.
         """
-        if self._execution_graph_running:
-            _disable_execution_graph_observer()
-            self._execution_graph_running = False
+        if self._execution_trace_running:
+            _disable_execution_trace_observer()
+            self._execution_trace_running = False

     def get_output_file_path(self) -> str:
         """
@@ -677,6 +683,6 @@ class ExecutionGraphObserver:
             return self._output_file_path
         else:
             raise RuntimeError(
-                "A callback to the EG profiler needs to be registered "
+                "A callback to the ET profiler needs to be registered "
                 "first before getting the output file path"
             )
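For completeness, a sketch of combining the renamed observer with the Kineto profiler, adapted loosely from test_execution_trace_with_kineto above; the schedule values and workload are illustrative, not prescribed by the commit:

import tempfile
import torch
from torch.profiler import ExecutionTraceObserver, profile, supported_activities

fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()

et = ExecutionTraceObserver()
et.register_callback(fp.name)
with profile(
    activities=supported_activities(),
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
) as p:
    et.start()
    for _ in range(10):
        torch.mm(torch.randn(8, 8), torch.randn(8, 8))  # traced workload
        p.step()  # advance the Kineto profiler schedule
    et.stop()
et.unregister_callback()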