mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Revert "[memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)"" (#96878)
This reverts commit e1ea584b1caf9c50de25ce69396dfeb523a452c0. Adds __has_include check to fix fbcode build. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96878 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f340df33c
commit
e74f70d212
@ -402,6 +402,8 @@ core_sources_full_mobile_no_backend_interface_xplat = [
|
||||
"torch/csrc/jit/tensorexpr/types.cpp",
|
||||
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
|
||||
"torch/csrc/jit/testing/file_check.cpp",
|
||||
"torch/csrc/profiler/unwind/unwind.cpp",
|
||||
"torch/csrc/profiler/combined_traceback.cpp",
|
||||
"torch/csrc/jit/testing/hooks_for_testing.cpp",
|
||||
"torch/csrc/utils/cpp_stacktraces.cpp",
|
||||
"torch/csrc/utils/schema_info.cpp",
|
||||
@ -769,7 +771,6 @@ torch_cpp_srcs = [
|
||||
|
||||
libtorch_python_cuda_core_sources = [
|
||||
"torch/csrc/cuda/Event.cpp",
|
||||
"torch/csrc/profiler/unwind/unwind.cpp",
|
||||
"torch/csrc/cuda/Module.cpp",
|
||||
"torch/csrc/cuda/python_comm.cpp",
|
||||
"torch/csrc/cuda/Stream.cpp",
|
||||
@ -871,6 +872,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/multiprocessing/init.cpp",
|
||||
"torch/csrc/onnx/init.cpp",
|
||||
"torch/csrc/profiler/python/init.cpp",
|
||||
"torch/csrc/profiler/python/combined_traceback.cpp",
|
||||
"torch/csrc/serialization.cpp",
|
||||
"torch/csrc/tensor/python_tensor.cpp",
|
||||
"torch/csrc/utils/init.cpp",
|
||||
|
@ -4983,6 +4983,15 @@ class TestCudaComm(TestCase):
|
||||
finally:
|
||||
torch.cuda.memory._record_memory_history(False)
|
||||
|
||||
@unittest.skipIf(not IS_LINUX, "linux only cpp unwinding")
|
||||
def test_direct_traceback(self):
|
||||
from torch._C._profiler import gather_traceback, symbolize_tracebacks
|
||||
c = gather_traceback(True, True, True)
|
||||
r, = symbolize_tracebacks([c])
|
||||
r = str(r)
|
||||
self.assertTrue("test_cuda.py" in r)
|
||||
self.assertTrue("unwind" in r)
|
||||
|
||||
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
|
||||
@unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
|
||||
def test_memory_snapshot_with_cpp(self):
|
||||
@ -5165,7 +5174,7 @@ class TestCudaComm(TestCase):
|
||||
with self.assertRaises(torch.cuda.OutOfMemoryError):
|
||||
torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, 'Windows CI does not like the load_inline')
|
||||
@unittest.skipIf(not IS_LINUX, 'cpp traces only on linux')
|
||||
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
|
||||
def test_cpp_memory_snapshot_pickle(self):
|
||||
from torch.utils.cpp_extension import load_inline
|
||||
@ -5175,28 +5184,45 @@ class TestCudaComm(TestCase):
|
||||
std::string data = torch::cuda::_memory_snapshot_pickled();
|
||||
return py::bytes(data);
|
||||
}
|
||||
void record(bool e) {
|
||||
torch::cuda::_record_memory_history(e);
|
||||
void record(bool e, bool ctx) {
|
||||
torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
|
||||
}
|
||||
"""
|
||||
m = load_inline(name='snapshot', cpp_sources=[source], functions=['do_snapshot', 'record'])
|
||||
try:
|
||||
m.record(True)
|
||||
t = torch.rand(311, 411, device='cuda')
|
||||
mem = pickle.loads(m.do_snapshot())
|
||||
found = False
|
||||
for s in mem['segments']:
|
||||
for b in s['blocks']:
|
||||
if b['state'] == 'active_allocated' and 'history' in b:
|
||||
history = b['history']
|
||||
if history and history[0]['real_size'] == 311 * 411 * 4:
|
||||
found = True
|
||||
last_action = mem['device_traces'][0][-1]
|
||||
self.assertTrue(last_action['action'] == 'alloc')
|
||||
self.assertTrue(last_action['size'] == 311 * 411 * 4)
|
||||
self.assertTrue(found)
|
||||
finally:
|
||||
m.record(False)
|
||||
for ctx in (False, True):
|
||||
try:
|
||||
m.record(True, ctx)
|
||||
|
||||
@torch.jit.script
|
||||
def the_script_fn():
|
||||
return torch.rand(311, 411, device='cuda')
|
||||
|
||||
def run():
|
||||
t = the_script_fn()
|
||||
return pickle.loads(m.do_snapshot())
|
||||
|
||||
mem = run()
|
||||
found = False
|
||||
for s in mem['segments']:
|
||||
for b in s['blocks']:
|
||||
if b['state'] == 'active_allocated' and 'history' in b:
|
||||
history = b['history']
|
||||
if history and history[0]['real_size'] == 311 * 411 * 4:
|
||||
if ctx:
|
||||
frame_text = str(history[0]['frames'])
|
||||
# C++ frame
|
||||
self.assertTrue('::rand' in frame_text)
|
||||
# script frame
|
||||
self.assertTrue('the_script_fn' in frame_text)
|
||||
# python frame
|
||||
self.assertTrue('case.py' in frame_text)
|
||||
found = True
|
||||
last_action = mem['device_traces'][0][-1]
|
||||
self.assertTrue(last_action['action'] == 'alloc')
|
||||
self.assertTrue(last_action['size'] == 311 * 411 * 4)
|
||||
self.assertTrue(found)
|
||||
finally:
|
||||
m.record(False, False)
|
||||
|
||||
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
|
||||
def test_notifies_oom(self):
|
||||
|
@ -1830,3 +1830,9 @@ def _current_autograd_node() -> _Node: ...
|
||||
|
||||
class _OutOfMemoryError: ...
|
||||
class _DistBackendError(RuntimeError): ...
|
||||
|
||||
# Defined in torch/csrc/profiler/init.cpp
|
||||
class CapturedTraceback:
|
||||
pass
|
||||
def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
|
||||
def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...
|
||||
|
@ -28,8 +28,7 @@
|
||||
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||
#include <torch/csrc/cuda/THCP.h>
|
||||
#include <torch/csrc/cuda/python_comm.h>
|
||||
#include <torch/csrc/jit/runtime/interpreter.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
#include <torch/csrc/profiler/python/combined_traceback.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/cuda_lazy_init.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
@ -599,186 +598,14 @@ PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
struct Frame {
|
||||
PyCodeObject* code;
|
||||
int lasti;
|
||||
};
|
||||
|
||||
static std::mutex to_free_frames_mutex;
|
||||
static std::vector<Frame> to_free_frames;
|
||||
|
||||
struct StackContext : public c10::GatheredContext {
|
||||
// Locking:
|
||||
// We need to free PyCodeObjects when ~StackContext runs, but
|
||||
// CUDACachingAllocator may hold its device lock when ~StackContext runs.
|
||||
|
||||
// Because the thread calling the allocator _may_ hold the GIL,
|
||||
// attempting to lock the GIL in ~StackContext can deadlock:
|
||||
// T0: GIL Lock -> Call Allocator ->| Waiting Device Lock
|
||||
// T1: Call Allocator -> Device Lock ->| Waiting GIL Lock
|
||||
// Instead the destructor defers freeing stack frames by putting them in
|
||||
// to_free_frames. We still need a lock to manage this vector, but
|
||||
// we can ensure an overall lock ordering of GIL -> device_lock ->
|
||||
// to_free_frames_mutex because ::gather is called outside of the device lock.
|
||||
std::vector<Frame> frames;
|
||||
std::vector<void*> cpp_frames;
|
||||
std::vector<jit::StackEntry> script_frames;
|
||||
|
||||
~StackContext() {
|
||||
std::lock_guard lock(to_free_frames_mutex);
|
||||
to_free_frames.insert(to_free_frames.end(), frames.begin(), frames.end());
|
||||
}
|
||||
static std::shared_ptr<StackContext> _gather(
|
||||
bool python,
|
||||
bool script,
|
||||
bool cpp) {
|
||||
auto r = std::make_shared<StackContext>();
|
||||
if (python) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
{
|
||||
std::lock_guard lock(to_free_frames_mutex);
|
||||
for (Frame f : to_free_frames) {
|
||||
Py_XDECREF(f.code);
|
||||
}
|
||||
to_free_frames.clear();
|
||||
}
|
||||
PyFrameObject* f = PyEval_GetFrame();
|
||||
Py_XINCREF(f);
|
||||
while (f) {
|
||||
r->frames.emplace_back(Frame{PyFrame_GetCode(f), PyFrame_GetLasti(f)});
|
||||
auto f_back = PyFrame_GetBack(f);
|
||||
Py_XDECREF(f);
|
||||
f = f_back;
|
||||
}
|
||||
}
|
||||
if (script) {
|
||||
r->script_frames = torch::jit::currentCallstack();
|
||||
}
|
||||
if (cpp) {
|
||||
r->cpp_frames = unwind::unwind();
|
||||
}
|
||||
return r;
|
||||
}
|
||||
static std::shared_ptr<c10::GatheredContext> gather() {
|
||||
return _gather(true, true, false);
|
||||
}
|
||||
static std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
|
||||
return _gather(true, true, true);
|
||||
}
|
||||
};
|
||||
|
||||
void gatherFrames(
|
||||
const std::vector<std::pair<StackContext*, py::dict>>& to_gather) {
|
||||
py::str frames_s = "frames";
|
||||
py::str filename_s = "filename";
|
||||
py::str name_s = "name";
|
||||
py::str line_s = "line";
|
||||
|
||||
std::unordered_map<void*, size_t> ip_to_frame_offset; // in all_cpp_frames
|
||||
std::vector<void*> all_cpp_ips;
|
||||
struct CPPFrame {
|
||||
enum Kind { PYTHON, JIT, REPORT } kind;
|
||||
py::object frame;
|
||||
};
|
||||
std::vector<CPPFrame> all_cpp_frames;
|
||||
|
||||
// dedup and collect any C++ frames that need symbols for
|
||||
for (const auto& e : to_gather) {
|
||||
for (void* f : e.first->cpp_frames) {
|
||||
if (!ip_to_frame_offset.count(f)) {
|
||||
ip_to_frame_offset[f] = all_cpp_ips.size();
|
||||
all_cpp_ips.push_back(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gather symbol names for C++ frames
|
||||
if (all_cpp_ips.size() > 0) {
|
||||
auto all_frames = unwind::symbolize(all_cpp_ips);
|
||||
for (auto& f : all_frames) {
|
||||
py::dict frame;
|
||||
frame[filename_s] = f.filename;
|
||||
frame[name_s] = f.funcname;
|
||||
frame[line_s] = f.lineno;
|
||||
CPPFrame::Kind kind = CPPFrame::REPORT;
|
||||
if (f.funcname.find("PyEval_EvalFrame") != std::string::npos) {
|
||||
kind = CPPFrame::PYTHON;
|
||||
} else if (
|
||||
f.funcname.rfind("torch::jit::InterpreterStateImpl::run", 0) !=
|
||||
std::string::npos) {
|
||||
kind = CPPFrame::JIT;
|
||||
}
|
||||
all_cpp_frames.emplace_back(CPPFrame{kind, frame});
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<StackContext*, py::list> cached_frames;
|
||||
for (const auto& e : to_gather) {
|
||||
auto sc = e.first;
|
||||
auto it = cached_frames.find(sc);
|
||||
if (it == cached_frames.end()) {
|
||||
py::list frames;
|
||||
auto py_it = sc->frames.begin();
|
||||
auto py_end = sc->frames.end();
|
||||
|
||||
bool jit_appended = false;
|
||||
|
||||
auto append_python = [&](const Frame& f) {
|
||||
py::dict frame;
|
||||
frame[filename_s] =
|
||||
py::reinterpret_borrow<py::object>(f.code->co_filename);
|
||||
frame[name_s] = py::reinterpret_borrow<py::object>(f.code->co_name);
|
||||
frame[line_s] = PyCode_Addr2Line(f.code, f.lasti);
|
||||
frames.append(std::move(frame));
|
||||
};
|
||||
|
||||
auto append_jit = [&]() {
|
||||
if (jit_appended) {
|
||||
return;
|
||||
}
|
||||
jit_appended = true;
|
||||
for (const auto& f : sc->script_frames) {
|
||||
py::dict frame;
|
||||
frame[name_s] = f.filename;
|
||||
auto flc = f.range.file_line_col();
|
||||
if (flc) {
|
||||
std::string filename;
|
||||
size_t line;
|
||||
size_t col;
|
||||
std::tie(filename, line, col) = *flc;
|
||||
frame[filename_s] = filename;
|
||||
frame[line_s] = line;
|
||||
} else {
|
||||
frame[filename_s] = "??";
|
||||
frame[line_s] = 0;
|
||||
}
|
||||
frames.append(std::move(frame));
|
||||
}
|
||||
};
|
||||
|
||||
for (void* f : sc->cpp_frames) {
|
||||
const CPPFrame& wf = all_cpp_frames.at(ip_to_frame_offset.at(f));
|
||||
if (wf.kind == CPPFrame::PYTHON) {
|
||||
if (py_it != py_end) {
|
||||
append_python(*py_it++);
|
||||
}
|
||||
} else if (wf.kind == CPPFrame::JIT) {
|
||||
append_jit();
|
||||
}
|
||||
frames.append(wf.frame);
|
||||
}
|
||||
|
||||
// add frames if we otherwise haven't seen the C++ frame indicating where
|
||||
// it should go
|
||||
append_jit();
|
||||
|
||||
for (; py_it != py_end; ++py_it) {
|
||||
append_python(*py_it);
|
||||
}
|
||||
it = cached_frames.insert({sc, frames}).first;
|
||||
}
|
||||
e.second[frames_s] = it->second;
|
||||
CapturedTraceback* getFromContext(
|
||||
const std::shared_ptr<c10::GatheredContext>& x) {
|
||||
if (CapturedTraceback* sc = dynamic_cast<CapturedTraceback*>(x.get())) {
|
||||
return sc;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attempting to gather stack context from the wrong StackContext type.");
|
||||
}
|
||||
|
||||
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
|
||||
@ -810,7 +637,8 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
|
||||
py::str history_s = "history";
|
||||
py::str blocks_s = "blocks";
|
||||
|
||||
std::vector<std::pair<StackContext*, py::dict>> frames_to_gather;
|
||||
std::vector<CapturedTraceback*> to_gather_frames;
|
||||
std::vector<py::dict> to_gather_dest;
|
||||
|
||||
const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
|
||||
py::dict segmentDict;
|
||||
@ -842,8 +670,9 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
|
||||
history_entry[addr_s] = (int64_t)h.addr;
|
||||
history_entry[real_size_s] = h.real_size;
|
||||
if (h.context) {
|
||||
auto sc = (StackContext*)h.context.get();
|
||||
frames_to_gather.emplace_back(sc, history_entry);
|
||||
auto sc = getFromContext(h.context);
|
||||
to_gather_frames.emplace_back(sc);
|
||||
to_gather_dest.emplace_back(history_entry);
|
||||
}
|
||||
history.append(std::move(history_entry));
|
||||
}
|
||||
@ -903,8 +732,9 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
|
||||
py::dict trace_entry;
|
||||
if (te.context_) {
|
||||
// without further compression frames can get really large on dump
|
||||
auto sc = (StackContext*)te.context_.get();
|
||||
frames_to_gather.emplace_back(sc, trace_entry);
|
||||
auto sc = getFromContext(te.context_);
|
||||
to_gather_frames.emplace_back(sc);
|
||||
to_gather_dest.emplace_back(trace_entry);
|
||||
}
|
||||
trace_entry[action_s] = action_to_str(te.action_);
|
||||
trace_entry[TraceEntry::OOM == te.action_ ? device_free_s : addr_s] =
|
||||
@ -920,7 +750,11 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
|
||||
result["segments"] = segments;
|
||||
result["device_traces"] = traces;
|
||||
|
||||
gatherFrames(frames_to_gather);
|
||||
py::str frames_s = "frames";
|
||||
auto frames = py_symbolize(to_gather_frames);
|
||||
for (auto i : c10::irange(frames.size())) {
|
||||
to_gather_dest.at(i)[frames_s] = frames.at(i);
|
||||
}
|
||||
|
||||
return result.release().ptr();
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -996,6 +830,14 @@ PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static std::shared_ptr<c10::GatheredContext> gather() {
|
||||
return CapturedTraceback::gather(true, true, false);
|
||||
}
|
||||
|
||||
static std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
|
||||
return CapturedTraceback::gather(true, true, true);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Cuda module initialization
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1034,8 +876,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
|
||||
}
|
||||
c10::cuda::CUDACachingAllocator::recordHistory(
|
||||
enabled,
|
||||
record_context ? (record_context_cpp ? StackContext::gather_with_cpp
|
||||
: StackContext::gather)
|
||||
record_context ? (record_context_cpp ? gather_with_cpp : gather)
|
||||
: nullptr,
|
||||
alloc_trace_max_entries,
|
||||
alloc_trace_record_context);
|
||||
|
@ -1,6 +1,9 @@
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <torch/csrc/cuda/memory_snapshot.h>
|
||||
#include <torch/csrc/jit/runtime/interpreter.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/profiler/combined_traceback.h>
|
||||
|
||||
namespace torch {
|
||||
namespace cuda {
|
||||
|
||||
@ -32,10 +35,86 @@ Dict<IValue, IValue> new_dict() {
|
||||
c10::List<IValue> new_list() {
|
||||
return List<IValue>(c10::AnyType::get());
|
||||
}
|
||||
|
||||
std::vector<IValue> ivalue_symbolize(
|
||||
std::vector<CapturedTraceback*>& to_symbolize) {
|
||||
// we dedup repeated to_symbolize objects to prevent
|
||||
// creating a bunch of duplicated frame objects
|
||||
std::unordered_map<CapturedTraceback*, uint64_t> cached_frames;
|
||||
std::vector<CapturedTraceback*> unique_frames;
|
||||
for (const auto& sc : to_symbolize) {
|
||||
auto it = cached_frames.find(sc);
|
||||
if (it == cached_frames.end()) {
|
||||
cached_frames.insert({sc, unique_frames.size()});
|
||||
unique_frames.push_back(sc);
|
||||
}
|
||||
}
|
||||
auto s = symbolize(unique_frames);
|
||||
|
||||
IValue line_s = "line";
|
||||
IValue name_s = "name";
|
||||
IValue filename_s = "filename";
|
||||
std::vector<IValue> all_frames;
|
||||
for (const auto& f : s.all_frames) {
|
||||
auto d = new_dict();
|
||||
d.insert(name_s, f.funcname);
|
||||
d.insert(filename_s, f.filename);
|
||||
d.insert(line_s, int64_t(f.lineno));
|
||||
all_frames.emplace_back(std::move(d));
|
||||
}
|
||||
|
||||
std::vector<IValue> py_unique_frames;
|
||||
for (const auto& t : s.tracebacks) {
|
||||
auto l = new_list();
|
||||
for (const auto& e : t) {
|
||||
l.push_back(all_frames.at(e));
|
||||
}
|
||||
py_unique_frames.push_back(std::move(l));
|
||||
}
|
||||
|
||||
std::vector<IValue> result;
|
||||
for (const auto& sc : to_symbolize) {
|
||||
result.push_back(py_unique_frames.at(cached_frames.at(sc)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::shared_ptr<c10::GatheredContext> gather() {
|
||||
return CapturedTraceback::gather(true, true, false);
|
||||
}
|
||||
|
||||
std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
|
||||
return CapturedTraceback::gather(true, true, true);
|
||||
}
|
||||
|
||||
CapturedTraceback* getFromContext(
|
||||
const std::shared_ptr<c10::GatheredContext>& x) {
|
||||
if (CapturedTraceback* sc = dynamic_cast<CapturedTraceback*>(x.get())) {
|
||||
return sc;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attempting to gather stack context from the wrong StackContext type.");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
void _record_memory_history(bool enabled, int64_t alloc_trace_max_entries) {
|
||||
|
||||
void _record_memory_history(
|
||||
bool enabled,
|
||||
bool record_context,
|
||||
int64_t trace_alloc_max_entries,
|
||||
bool trace_alloc_record_context,
|
||||
bool record_cpp_context) {
|
||||
c10::cuda::CUDACachingAllocator::CreateContextFn recorder = nullptr;
|
||||
if (record_context) {
|
||||
if (record_cpp_context) {
|
||||
recorder = gather_with_cpp;
|
||||
} else {
|
||||
recorder = gather;
|
||||
}
|
||||
}
|
||||
c10::cuda::CUDACachingAllocator::recordHistory(
|
||||
enabled, nullptr, alloc_trace_max_entries, false);
|
||||
enabled, recorder, trace_alloc_max_entries, trace_alloc_record_context);
|
||||
}
|
||||
|
||||
std::string _memory_snapshot_pickled() {
|
||||
@ -66,6 +145,9 @@ std::string _memory_snapshot_pickled() {
|
||||
|
||||
auto empty_frames = new_list();
|
||||
|
||||
std::vector<CapturedTraceback*> frame_tracebacks;
|
||||
std::vector<Dict<IValue, IValue>> frame_dict;
|
||||
|
||||
const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
|
||||
auto segmentDict = new_dict();
|
||||
segmentDict.insert(device_s, segmentInfo.device);
|
||||
@ -98,7 +180,8 @@ std::string _memory_snapshot_pickled() {
|
||||
history_entry.insert(addr_s, (int64_t)h.addr);
|
||||
history_entry.insert(real_size_s, (int64_t)h.real_size);
|
||||
if (h.context) {
|
||||
history_entry.insert(frames_s, empty_frames);
|
||||
frame_tracebacks.push_back(getFromContext(h.context));
|
||||
frame_dict.push_back(history_entry);
|
||||
}
|
||||
history.push_back(std::move(history_entry));
|
||||
}
|
||||
@ -160,6 +243,11 @@ std::string _memory_snapshot_pickled() {
|
||||
TraceEntry::OOM == te.action_ ? device_free_s : addr_s, te.addr_);
|
||||
trace_entry.insert(size_s, (int64_t)te.size_);
|
||||
trace_entry.insert(stream_s, int64_t(te.stream_));
|
||||
if (te.context_) {
|
||||
auto sc = getFromContext(te.context_);
|
||||
frame_tracebacks.push_back(sc);
|
||||
frame_dict.push_back(trace_entry);
|
||||
}
|
||||
trace.push_back(trace_entry);
|
||||
}
|
||||
traces.push_back(trace);
|
||||
@ -168,6 +256,12 @@ std::string _memory_snapshot_pickled() {
|
||||
auto result = new_dict();
|
||||
result.insert("segments", segments);
|
||||
result.insert("device_traces", traces);
|
||||
|
||||
auto frames = ivalue_symbolize(frame_tracebacks);
|
||||
for (auto i : c10::irange(frames.size())) {
|
||||
frame_dict.at(i).insert(frames_s, frames.at(i));
|
||||
}
|
||||
|
||||
return write_pickle(result);
|
||||
}
|
||||
} // namespace cuda
|
||||
|
@ -10,7 +10,11 @@ namespace cuda {
|
||||
// those defined in cuda/Module.cpp which also record python state.
|
||||
TORCH_CUDA_CU_API void _record_memory_history(
|
||||
bool enabled,
|
||||
int64_t alloc_trace_max_entries = 1);
|
||||
bool record_context = true,
|
||||
int64_t trace_alloc_max_entries = 1,
|
||||
bool trace_alloc_record_context = false,
|
||||
bool record_cpp_context = false);
|
||||
|
||||
TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();
|
||||
|
||||
} // namespace cuda
|
||||
|
171
torch/csrc/profiler/combined_traceback.cpp
Normal file
171
torch/csrc/profiler/combined_traceback.cpp
Normal file
@ -0,0 +1,171 @@
|
||||
#include <torch/csrc/profiler/combined_traceback.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
static std::atomic<CapturedTraceback::Python*> python_support_ = nullptr;
|
||||
|
||||
std::shared_ptr<CapturedTraceback> CapturedTraceback::gather(
|
||||
bool python,
|
||||
bool script,
|
||||
bool cpp) {
|
||||
auto r = std::make_shared<CapturedTraceback>();
|
||||
if (python) {
|
||||
auto p = python_support_.load();
|
||||
while (p && r->frames_.size() == 0) {
|
||||
r->frames_ = p->gather();
|
||||
r->python_ = p;
|
||||
p = p->next_;
|
||||
}
|
||||
}
|
||||
if (script) {
|
||||
r->script_frames_ = torch::jit::currentCallstack();
|
||||
}
|
||||
if (cpp) {
|
||||
r->cpp_frames_ = unwind::unwind();
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
CapturedTraceback::~CapturedTraceback() {
|
||||
if (frames_.size() > 0) {
|
||||
TORCH_INTERNAL_ASSERT(python_);
|
||||
python_->release(frames_);
|
||||
}
|
||||
}
|
||||
|
||||
struct PyFrameHash {
|
||||
std::size_t operator()(const CapturedTraceback::PyFrame& f) const {
|
||||
return std::hash<void*>()(f.code) ^ std::hash<int>()(f.lasti);
|
||||
}
|
||||
};
|
||||
|
||||
struct PyFrameEq {
|
||||
std::size_t operator()(
|
||||
const CapturedTraceback::PyFrame& lhs,
|
||||
const CapturedTraceback::PyFrame& rhs) const {
|
||||
return lhs.code == rhs.code && lhs.lasti == rhs.lasti;
|
||||
}
|
||||
};
|
||||
|
||||
SymbolizedTracebacks symbolize(
|
||||
const std::vector<CapturedTraceback*>& to_symbolize) {
|
||||
SymbolizedTracebacks r;
|
||||
|
||||
std::unordered_map<void*, size_t> ip_to_frame_offset;
|
||||
std::unordered_map<CapturedTraceback::PyFrame, size_t, PyFrameHash, PyFrameEq>
|
||||
py_to_frame_offset;
|
||||
std::vector<void*> all_cpp_ips;
|
||||
|
||||
// dedup and collect any C++ frames that need symbols for
|
||||
for (const auto& e : to_symbolize) {
|
||||
for (void* f : e->cpp_frames_) {
|
||||
if (!ip_to_frame_offset.count(f)) {
|
||||
ip_to_frame_offset[f] = all_cpp_ips.size();
|
||||
all_cpp_ips.push_back(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// gather symbol names for C++ frames
|
||||
if (all_cpp_ips.size() > 0) {
|
||||
r.all_frames = unwind::symbolize(all_cpp_ips);
|
||||
}
|
||||
|
||||
// batch symbolization requests so we dedup frame objects
|
||||
// however, we might have to request from different python interpreters
|
||||
// make sure we flush requests before switching interpreters;
|
||||
CapturedTraceback::Python* cur_python = nullptr;
|
||||
std::vector<CapturedTraceback::PyFrame> cur_py_frames;
|
||||
size_t py_frames_size_ = 0;
|
||||
|
||||
for (const auto& e : to_symbolize) {
|
||||
if (e->python_) {
|
||||
if (cur_python != e->python_ && cur_py_frames.size() > 0) {
|
||||
cur_python->appendSymbolized(cur_py_frames, r);
|
||||
cur_py_frames.clear();
|
||||
}
|
||||
cur_python = e->python_;
|
||||
for (const auto& f : e->frames_) {
|
||||
if (!py_to_frame_offset.count(f)) {
|
||||
py_to_frame_offset[f] = py_frames_size_++;
|
||||
cur_py_frames.push_back(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cur_py_frames.size() > 0) {
|
||||
cur_python->appendSymbolized(cur_py_frames, r);
|
||||
cur_py_frames.clear();
|
||||
}
|
||||
std::vector<std::vector<uint64_t>> python_frame_fragments =
|
||||
std::move(r.tracebacks);
|
||||
|
||||
for (const auto& sc : to_symbolize) {
|
||||
r.tracebacks.emplace_back();
|
||||
auto py_it = sc->frames_.begin();
|
||||
auto py_end = sc->frames_.end();
|
||||
|
||||
bool jit_appended = false;
|
||||
|
||||
auto append_python = [&](const CapturedTraceback::PyFrame& f) {
|
||||
const auto& fragment =
|
||||
python_frame_fragments.at(py_to_frame_offset.at(f));
|
||||
r.tracebacks.back().insert(
|
||||
r.tracebacks.back().end(), fragment.begin(), fragment.end());
|
||||
};
|
||||
|
||||
auto append_jit = [&]() {
|
||||
if (jit_appended) {
|
||||
return;
|
||||
}
|
||||
jit_appended = true;
|
||||
for (const auto& f : sc->script_frames_) {
|
||||
unwind::Frame frame;
|
||||
frame.funcname =
|
||||
f.filename; // sic: torchscript puts funcname in filename field
|
||||
auto flc = f.range.file_line_col();
|
||||
if (flc) {
|
||||
size_t col;
|
||||
std::tie(frame.filename, frame.lineno, col) = *flc;
|
||||
} else {
|
||||
frame.filename = "??";
|
||||
frame.lineno = 0;
|
||||
}
|
||||
r.tracebacks.back().push_back(r.all_frames.size());
|
||||
r.all_frames.emplace_back(std::move(frame));
|
||||
}
|
||||
};
|
||||
|
||||
for (void* f : sc->cpp_frames_) {
|
||||
uint64_t cpp_frame = ip_to_frame_offset.at(f);
|
||||
const unwind::Frame& uf = r.all_frames.at(cpp_frame);
|
||||
if (uf.funcname.find("PyEval_EvalFrame") != std::string::npos) {
|
||||
if (py_it != py_end) {
|
||||
append_python(*py_it++);
|
||||
}
|
||||
} else if (
|
||||
uf.funcname.rfind("torch::jit::InterpreterStateImpl::run", 0) !=
|
||||
std::string::npos) {
|
||||
append_jit();
|
||||
}
|
||||
r.tracebacks.back().push_back(cpp_frame);
|
||||
}
|
||||
|
||||
// add frames if we otherwise haven't seen the C++ frame indicating where
|
||||
// it should go
|
||||
append_jit();
|
||||
|
||||
for (; py_it != py_end; ++py_it) {
|
||||
append_python(*py_it);
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
void CapturedTraceback::addPythonUnwinder(CapturedTraceback::Python* p) {
|
||||
CapturedTraceback::Python* old_unwinder = python_support_.load();
|
||||
do {
|
||||
p->next_ = old_unwinder;
|
||||
} while (!python_support_.compare_exchange_strong(old_unwinder, p));
|
||||
}
|
||||
|
||||
} // namespace torch
|
62
torch/csrc/profiler/combined_traceback.h
Normal file
62
torch/csrc/profiler/combined_traceback.h
Normal file
@ -0,0 +1,62 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/runtime/interpreter.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
// struct that holds the result of symbolizing multiple tracebacks
|
||||
// each traceback is a list of indices into all_frames
|
||||
// (lots of Frames get duplicated across traces)
|
||||
struct TORCH_API SymbolizedTracebacks {
|
||||
std::vector<unwind::Frame> all_frames;
|
||||
// index into all_frames, so that
|
||||
// it is possible to dedupe frame objects in
|
||||
// construction of python objects
|
||||
std::vector<std::vector<uint64_t>> tracebacks;
|
||||
};
|
||||
|
||||
struct TORCH_API CapturedTraceback : public c10::GatheredContext {
|
||||
struct PyFrame {
|
||||
void* code; // PyCodeObject*, but python headers not present
|
||||
int lasti;
|
||||
};
|
||||
|
||||
static std::shared_ptr<CapturedTraceback> gather(
|
||||
bool python,
|
||||
bool script,
|
||||
bool cpp);
|
||||
CapturedTraceback() = default;
|
||||
CapturedTraceback(const CapturedTraceback&) = delete;
|
||||
CapturedTraceback& operator=(const CapturedTraceback&) = delete;
|
||||
~CapturedTraceback();
|
||||
struct Python {
|
||||
virtual std::vector<PyFrame> gather() = 0;
|
||||
virtual void release(std::vector<PyFrame>& frames) = 0;
|
||||
virtual void appendSymbolized(
|
||||
const std::vector<PyFrame>& to_symbolize,
|
||||
SymbolizedTracebacks& st) = 0;
|
||||
virtual ~Python() = default;
|
||||
Python* next_ = nullptr;
|
||||
};
|
||||
// called once by each python interpreter to
|
||||
// register python stack recording functionality
|
||||
// p cannot be deleted once added.
|
||||
static void addPythonUnwinder(Python* p);
|
||||
|
||||
private:
|
||||
std::vector<PyFrame> frames_;
|
||||
std::vector<void*> cpp_frames_;
|
||||
std::vector<jit::StackEntry> script_frames_;
|
||||
friend TORCH_API SymbolizedTracebacks
|
||||
symbolize(const std::vector<CapturedTraceback*>& to_symbolize);
|
||||
|
||||
// non-owning reference to one of the immortal Python* objects
|
||||
// registered above.
|
||||
Python* python_ = nullptr;
|
||||
};
|
||||
|
||||
TORCH_API SymbolizedTracebacks
|
||||
symbolize(const std::vector<CapturedTraceback*>& to_symbolize);
|
||||
|
||||
} // namespace torch
|
123
torch/csrc/profiler/python/combined_traceback.cpp
Normal file
123
torch/csrc/profiler/python/combined_traceback.cpp
Normal file
@ -0,0 +1,123 @@
|
||||
#include <torch/csrc/profiler/python/combined_traceback.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace torch {
|
||||
// Locking:
|
||||
// We need to free PyCodeObjects when ~StackContext runs, but
|
||||
// CUDACachingAllocator may hold its device lock when ~StackContext runs.
|
||||
|
||||
// Because the thread calling the allocator _may_ hold the GIL,
|
||||
// attempting to lock the GIL in ~StackContext can deadlock:
|
||||
// T0: GIL Lock -> Call Allocator ->| Waiting Device Lock
|
||||
// T1: Call Allocator -> Device Lock ->| Waiting GIL Lock
|
||||
// Instead the destructor defers freeing stack frames by putting them in
|
||||
// to_free_frames. We still need a lock to manage this vector, but
|
||||
// we can ensure an overall lock ordering of GIL -> device_lock ->
|
||||
// to_free_frames_mutex because ::gather is called outside of the device lock.
|
||||
|
||||
namespace {
|
||||
static std::mutex to_free_frames_mutex;
|
||||
static std::vector<CapturedTraceback::PyFrame> to_free_frames;
|
||||
struct PythonTraceback : public CapturedTraceback::Python {
|
||||
std::vector<CapturedTraceback::PyFrame> gather() override {
|
||||
std::vector<CapturedTraceback::PyFrame> frames;
|
||||
py::gil_scoped_acquire acquire;
|
||||
{
|
||||
std::lock_guard lock(to_free_frames_mutex);
|
||||
for (CapturedTraceback::PyFrame f : to_free_frames) {
|
||||
Py_XDECREF(f.code);
|
||||
}
|
||||
to_free_frames.clear();
|
||||
}
|
||||
PyFrameObject* f = PyEval_GetFrame();
|
||||
Py_XINCREF(f);
|
||||
while (f) {
|
||||
frames.emplace_back(
|
||||
CapturedTraceback::PyFrame{PyFrame_GetCode(f), PyFrame_GetLasti(f)});
|
||||
auto f_back = PyFrame_GetBack(f);
|
||||
Py_XDECREF(f);
|
||||
f = f_back;
|
||||
}
|
||||
return frames;
|
||||
}
|
||||
void release(std::vector<CapturedTraceback::PyFrame>& frames) override {
|
||||
std::lock_guard lock(to_free_frames_mutex);
|
||||
to_free_frames.insert(to_free_frames.end(), frames.begin(), frames.end());
|
||||
}
|
||||
void appendSymbolized(
|
||||
const std::vector<CapturedTraceback::PyFrame>& to_symbolize,
|
||||
SymbolizedTracebacks& result) override {
|
||||
py::str line_s = "line";
|
||||
py::str name_s = "name";
|
||||
py::str filename_s = "filename";
|
||||
|
||||
for (const auto& f : to_symbolize) {
|
||||
auto f_code = (PyCodeObject*)f.code;
|
||||
py::handle filename = f_code->co_filename;
|
||||
py::handle funcname = f_code->co_name;
|
||||
auto lineno = PyCode_Addr2Line(f_code, f.lasti);
|
||||
result.tracebacks.emplace_back();
|
||||
result.tracebacks.back().push_back(result.all_frames.size());
|
||||
result.all_frames.emplace_back(unwind::Frame{
|
||||
py::cast<std::string>(filename),
|
||||
py::cast<std::string>(funcname),
|
||||
(uint64_t)lineno});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<py::object> py_symbolize(
|
||||
std::vector<CapturedTraceback*>& to_symbolize) {
|
||||
// we dedup repeated to_symbolize objects to prevent
|
||||
// creating a bunch of duplicated frame objects
|
||||
std::unordered_map<CapturedTraceback*, uint64_t> cached_frames;
|
||||
std::vector<CapturedTraceback*> unique_frames;
|
||||
for (const auto& sc : to_symbolize) {
|
||||
auto it = cached_frames.find(sc);
|
||||
if (it == cached_frames.end()) {
|
||||
cached_frames.insert({sc, unique_frames.size()});
|
||||
unique_frames.push_back(sc);
|
||||
}
|
||||
}
|
||||
auto s = symbolize(unique_frames);
|
||||
|
||||
py::str line_s = "line";
|
||||
py::str name_s = "name";
|
||||
py::str filename_s = "filename";
|
||||
std::vector<py::dict> all_frames;
|
||||
for (const auto& f : s.all_frames) {
|
||||
py::dict d;
|
||||
d[name_s] = f.funcname;
|
||||
d[filename_s] = f.filename;
|
||||
d[line_s] = f.lineno;
|
||||
all_frames.emplace_back(std::move(d));
|
||||
}
|
||||
|
||||
std::vector<py::object> py_unique_frames;
|
||||
for (const auto& t : s.tracebacks) {
|
||||
py::list l;
|
||||
for (const auto& e : t) {
|
||||
l.append(all_frames.at(e));
|
||||
}
|
||||
py_unique_frames.push_back(std::move(l));
|
||||
}
|
||||
|
||||
std::vector<py::object> result;
|
||||
for (const auto& sc : to_symbolize) {
|
||||
result.push_back(py_unique_frames.at(cached_frames.at(sc)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void installCapturedTracebackPython() {
|
||||
CapturedTraceback::addPythonUnwinder(new PythonTraceback());
|
||||
}
|
||||
|
||||
} // namespace torch
|
19
torch/csrc/profiler/python/combined_traceback.h
Normal file
19
torch/csrc/profiler/python/combined_traceback.h
Normal file
@ -0,0 +1,19 @@
|
||||
#include <torch/csrc/profiler/combined_traceback.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
// symbolize combined traceback objects, converting them into lists of
|
||||
// dictionaries that are easily consumed in python.
|
||||
|
||||
// returns std::vector because one use is to call it with a batch of
|
||||
// tracebacks that come from a larger datastructure (e.g. a memory snapshot)
|
||||
// and then have more c++ code to put those objects in the right place.
|
||||
std::vector<pybind11::object> py_symbolize(
|
||||
std::vector<CapturedTraceback*>& to_symbolize);
|
||||
|
||||
void installCapturedTracebackPython();
|
||||
|
||||
} // namespace torch
|
@ -6,6 +6,7 @@
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/profiler/collection.h>
|
||||
#include <torch/csrc/profiler/python/combined_traceback.h>
|
||||
#include <torch/csrc/profiler/standalone/execution_graph_observer.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
@ -292,6 +293,12 @@ void initPythonBindings(PyObject* module) {
|
||||
m.def(
|
||||
"_disable_execution_graph_observer",
|
||||
&torch::profiler::impl::disableExecutionGraphObserver);
|
||||
|
||||
py::class_<CapturedTraceback, std::shared_ptr<CapturedTraceback>>(
|
||||
m, "CapturedTraceback");
|
||||
m.def("gather_traceback", CapturedTraceback::gather);
|
||||
m.def("symbolize_tracebacks", py_symbolize);
|
||||
installCapturedTracebackPython();
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
@ -1,7 +1,8 @@
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/profiler/unwind/unwind.h>
|
||||
|
||||
#if !defined(__linux__) || !defined(__x86_64__)
|
||||
#if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \
|
||||
!__has_include("ext/stdio_filebuf.h")
|
||||
namespace torch {
|
||||
namespace unwind {
|
||||
std::vector<void*> unwind() {
|
||||
|
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
#include <c10/macros/Export.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -5,7 +7,7 @@ namespace torch {
|
||||
namespace unwind {
|
||||
// gather current stack, relatively fast.
|
||||
// gets faster once the cache of program counter locations is warm.
|
||||
std::vector<void*> unwind();
|
||||
TORCH_API std::vector<void*> unwind();
|
||||
|
||||
struct Frame {
|
||||
std::string filename;
|
||||
@ -19,7 +21,7 @@ struct Frame {
|
||||
// Callers should first batch up all the unique void* pointers
|
||||
// across a number of unwind states and make a single call to
|
||||
// symbolize.
|
||||
std::vector<Frame> symbolize(const std::vector<void*>& frames);
|
||||
TORCH_API std::vector<Frame> symbolize(const std::vector<void*>& frames);
|
||||
|
||||
struct Stats {
|
||||
size_t hits = 0;
|
||||
|
@ -19,7 +19,8 @@ def _frame_fmt(f, full_filename=False):
|
||||
func = f['name']
|
||||
return f'{fname}:{i}:{func}'
|
||||
|
||||
def _frame_filter(f):
|
||||
@cache
|
||||
def _frame_filter(name, filename):
|
||||
omit_functions = [
|
||||
"unwind::unwind",
|
||||
"StackContext::gather",
|
||||
@ -40,17 +41,17 @@ def _frame_filter(f):
|
||||
"cpython/abstract.h",
|
||||
]
|
||||
for of in omit_functions:
|
||||
if of in f['name']:
|
||||
if of in name:
|
||||
return False
|
||||
for of in omit_filenames:
|
||||
if of in f['filename']:
|
||||
if of in filename:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _frames_fmt(frames, full_filename=False, reverse=False):
|
||||
if reverse:
|
||||
frames = reversed(frames)
|
||||
return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f)]
|
||||
return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])]
|
||||
|
||||
def format_flamegraph(flamegraph_lines, flamegraph_script=None):
|
||||
if flamegraph_script is None:
|
||||
|
Reference in New Issue
Block a user