Revert "Revert "[memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)"" (#96878)

This reverts commit e1ea584b1caf9c50de25ce69396dfeb523a452c0.
Adds __has_include check to fix fbcode build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96878
Approved by: https://github.com/ezyang
commit e74f70d212 (parent 1f340df33c)
Author: Zachary DeVito
Date:   2023-03-15 15:14:10 -07:00
Committed-by: PyTorch MergeBot

14 changed files with 580 additions and 221 deletions
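
Before the per-file diffs, a minimal usage sketch of the facility being
re-landed (not part of the diff; it mirrors test_direct_traceback in
test/test_cuda.py below, and the frame-dict keys follow py_symbolize in
torch/csrc/profiler/python/combined_traceback.cpp):

    from torch._C._profiler import gather_traceback, symbolize_tracebacks

    # Capture a combined Python/TorchScript/C++ stack; capture is cheap,
    # symbolization is deferred and batched.
    tb = gather_traceback(True, True, True)  # python, script, cpp

    # One symbolized result per input traceback: a list of frame dicts.
    frames, = symbolize_tracebacks([tb])
    for f in frames:
        print(f"{f['filename']}:{f['line']}: {f['name']}")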

View File: build_variables.bzl

@@ -402,6 +402,8 @@ core_sources_full_mobile_no_backend_interface_xplat = [
     "torch/csrc/jit/tensorexpr/types.cpp",
     "torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
     "torch/csrc/jit/testing/file_check.cpp",
+    "torch/csrc/profiler/unwind/unwind.cpp",
+    "torch/csrc/profiler/combined_traceback.cpp",
     "torch/csrc/jit/testing/hooks_for_testing.cpp",
     "torch/csrc/utils/cpp_stacktraces.cpp",
     "torch/csrc/utils/schema_info.cpp",
@@ -769,7 +771,6 @@ torch_cpp_srcs = [
 libtorch_python_cuda_core_sources = [
     "torch/csrc/cuda/Event.cpp",
-    "torch/csrc/profiler/unwind/unwind.cpp",
     "torch/csrc/cuda/Module.cpp",
     "torch/csrc/cuda/python_comm.cpp",
     "torch/csrc/cuda/Stream.cpp",
@@ -871,6 +872,7 @@ libtorch_python_core_sources = [
     "torch/csrc/multiprocessing/init.cpp",
     "torch/csrc/onnx/init.cpp",
     "torch/csrc/profiler/python/init.cpp",
+    "torch/csrc/profiler/python/combined_traceback.cpp",
     "torch/csrc/serialization.cpp",
     "torch/csrc/tensor/python_tensor.cpp",
     "torch/csrc/utils/init.cpp",

View File: test/test_cuda.py

@@ -4983,6 +4983,15 @@ class TestCudaComm(TestCase):
         finally:
             torch.cuda.memory._record_memory_history(False)
 
+    @unittest.skipIf(not IS_LINUX, "linux only cpp unwinding")
+    def test_direct_traceback(self):
+        from torch._C._profiler import gather_traceback, symbolize_tracebacks
+        c = gather_traceback(True, True, True)
+        r, = symbolize_tracebacks([c])
+        r = str(r)
+        self.assertTrue("test_cuda.py" in r)
+        self.assertTrue("unwind" in r)
+
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
     @unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
     def test_memory_snapshot_with_cpp(self):
@@ -5165,7 +5174,7 @@ class TestCudaComm(TestCase):
         with self.assertRaises(torch.cuda.OutOfMemoryError):
             torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')
 
-    @unittest.skipIf(IS_WINDOWS, 'Windows CI does not like the load_inline')
+    @unittest.skipIf(not IS_LINUX, 'cpp traces only on linux')
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
     def test_cpp_memory_snapshot_pickle(self):
         from torch.utils.cpp_extension import load_inline
@@ -5175,28 +5184,45 @@ class TestCudaComm(TestCase):
             std::string data = torch::cuda::_memory_snapshot_pickled();
             return py::bytes(data);
         }
-        void record(bool e) {
-            torch::cuda::_record_memory_history(e);
+        void record(bool e, bool ctx) {
+            torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx);
         }
         """
         m = load_inline(name='snapshot', cpp_sources=[source], functions=['do_snapshot', 'record'])
-        try:
-            m.record(True)
-            t = torch.rand(311, 411, device='cuda')
-            mem = pickle.loads(m.do_snapshot())
-            found = False
-            for s in mem['segments']:
-                for b in s['blocks']:
-                    if b['state'] == 'active_allocated' and 'history' in b:
-                        history = b['history']
-                        if history and history[0]['real_size'] == 311 * 411 * 4:
-                            found = True
-            last_action = mem['device_traces'][0][-1]
-            self.assertTrue(last_action['action'] == 'alloc')
-            self.assertTrue(last_action['size'] == 311 * 411 * 4)
-            self.assertTrue(found)
-        finally:
-            m.record(False)
+        for ctx in (False, True):
+            try:
+                m.record(True, ctx)
+
+                @torch.jit.script
+                def the_script_fn():
+                    return torch.rand(311, 411, device='cuda')
+
+                def run():
+                    t = the_script_fn()
+                    return pickle.loads(m.do_snapshot())
+
+                mem = run()
+                found = False
+                for s in mem['segments']:
+                    for b in s['blocks']:
+                        if b['state'] == 'active_allocated' and 'history' in b:
+                            history = b['history']
+                            if history and history[0]['real_size'] == 311 * 411 * 4:
+                                if ctx:
+                                    frame_text = str(history[0]['frames'])
+                                    # C++ frame
+                                    self.assertTrue('::rand' in frame_text)
+                                    # script frame
+                                    self.assertTrue('the_script_fn' in frame_text)
+                                    # python frame
+                                    self.assertTrue('case.py' in frame_text)
+                                found = True
+                last_action = mem['device_traces'][0][-1]
+                self.assertTrue(last_action['action'] == 'alloc')
+                self.assertTrue(last_action['size'] == 311 * 411 * 4)
+                self.assertTrue(found)
+            finally:
+                m.record(False, False)
 
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled")
     def test_notifies_oom(self):

View File: torch/_C/__init__.pyi.in

@@ -1830,3 +1830,9 @@ def _current_autograd_node() -> _Node: ...
 class _OutOfMemoryError: ...
 class _DistBackendError(RuntimeError): ...
+
+# Defined in torch/csrc/profiler/init.cpp
+class CapturedTraceback:
+    pass
+def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
+def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...

View File: torch/csrc/cuda/Module.cpp

@@ -28,8 +28,7 @@
 #include <torch/csrc/cuda/CUDAPluggableAllocator.h>
 #include <torch/csrc/cuda/THCP.h>
 #include <torch/csrc/cuda/python_comm.h>
-#include <torch/csrc/jit/runtime/interpreter.h>
-#include <torch/csrc/profiler/unwind/unwind.h>
+#include <torch/csrc/profiler/python/combined_traceback.h>
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/cuda_lazy_init.h>
 #include <torch/csrc/utils/pybind.h>
@@ -599,186 +598,14 @@ PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) {
   Py_RETURN_NONE;
 }
 
-struct Frame {
-  PyCodeObject* code;
-  int lasti;
-};
-
-static std::mutex to_free_frames_mutex;
-static std::vector<Frame> to_free_frames;
-
-struct StackContext : public c10::GatheredContext {
-  // Locking:
-  // We need to free PyCodeObjects when ~StackContext runs, but
-  // CUDACachingAllocator may hold its device lock when ~StackContext runs.
-  // Because the thread calling the allocator _may_ hold the GIL,
-  // attempting to lock the GIL in ~StackContext can deadlock:
-  // T0: GIL Lock -> Call Allocator ->| Waiting Device Lock
-  // T1: Call Allocator -> Device Lock ->| Waiting GIL Lock
-  // Instead the destructor defers freeing stack frames by putting them in
-  // to_free_frames. We still need a lock to manage this vector, but
-  // we can ensure an overall lock ordering of GIL -> device_lock ->
-  // to_free_frames_mutex because ::gather is called outside of the device lock.
-  std::vector<Frame> frames;
-  std::vector<void*> cpp_frames;
-  std::vector<jit::StackEntry> script_frames;
-  ~StackContext() {
-    std::lock_guard lock(to_free_frames_mutex);
-    to_free_frames.insert(to_free_frames.end(), frames.begin(), frames.end());
-  }
-  static std::shared_ptr<StackContext> _gather(
-      bool python,
-      bool script,
-      bool cpp) {
-    auto r = std::make_shared<StackContext>();
-    if (python) {
-      py::gil_scoped_acquire acquire;
-      {
-        std::lock_guard lock(to_free_frames_mutex);
-        for (Frame f : to_free_frames) {
-          Py_XDECREF(f.code);
-        }
-        to_free_frames.clear();
-      }
-      PyFrameObject* f = PyEval_GetFrame();
-      Py_XINCREF(f);
-      while (f) {
-        r->frames.emplace_back(Frame{PyFrame_GetCode(f), PyFrame_GetLasti(f)});
-        auto f_back = PyFrame_GetBack(f);
-        Py_XDECREF(f);
-        f = f_back;
-      }
-    }
-    if (script) {
-      r->script_frames = torch::jit::currentCallstack();
-    }
-    if (cpp) {
-      r->cpp_frames = unwind::unwind();
-    }
-    return r;
-  }
-  static std::shared_ptr<c10::GatheredContext> gather() {
-    return _gather(true, true, false);
-  }
-  static std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
-    return _gather(true, true, true);
-  }
-};
-
-void gatherFrames(
-    const std::vector<std::pair<StackContext*, py::dict>>& to_gather) {
-  py::str frames_s = "frames";
-  py::str filename_s = "filename";
-  py::str name_s = "name";
-  py::str line_s = "line";
-
-  std::unordered_map<void*, size_t> ip_to_frame_offset; // in all_cpp_frames
-  std::vector<void*> all_cpp_ips;
-  struct CPPFrame {
-    enum Kind { PYTHON, JIT, REPORT } kind;
-    py::object frame;
-  };
-  std::vector<CPPFrame> all_cpp_frames;
-  // dedup and collect any C++ frames that need symbols for
-  for (const auto& e : to_gather) {
-    for (void* f : e.first->cpp_frames) {
-      if (!ip_to_frame_offset.count(f)) {
-        ip_to_frame_offset[f] = all_cpp_ips.size();
-        all_cpp_ips.push_back(f);
-      }
-    }
-  }
-  // gather symbol names for C++ frames
-  if (all_cpp_ips.size() > 0) {
-    auto all_frames = unwind::symbolize(all_cpp_ips);
-    for (auto& f : all_frames) {
-      py::dict frame;
-      frame[filename_s] = f.filename;
-      frame[name_s] = f.funcname;
-      frame[line_s] = f.lineno;
-      CPPFrame::Kind kind = CPPFrame::REPORT;
-      if (f.funcname.find("PyEval_EvalFrame") != std::string::npos) {
-        kind = CPPFrame::PYTHON;
-      } else if (
-          f.funcname.rfind("torch::jit::InterpreterStateImpl::run", 0) !=
-          std::string::npos) {
-        kind = CPPFrame::JIT;
-      }
-      all_cpp_frames.emplace_back(CPPFrame{kind, frame});
-    }
-  }
-
-  std::unordered_map<StackContext*, py::list> cached_frames;
-  for (const auto& e : to_gather) {
-    auto sc = e.first;
-    auto it = cached_frames.find(sc);
-    if (it == cached_frames.end()) {
-      py::list frames;
-      auto py_it = sc->frames.begin();
-      auto py_end = sc->frames.end();
-      bool jit_appended = false;
-      auto append_python = [&](const Frame& f) {
-        py::dict frame;
-        frame[filename_s] =
-            py::reinterpret_borrow<py::object>(f.code->co_filename);
-        frame[name_s] = py::reinterpret_borrow<py::object>(f.code->co_name);
-        frame[line_s] = PyCode_Addr2Line(f.code, f.lasti);
-        frames.append(std::move(frame));
-      };
-      auto append_jit = [&]() {
-        if (jit_appended) {
-          return;
-        }
-        jit_appended = true;
-        for (const auto& f : sc->script_frames) {
-          py::dict frame;
-          frame[name_s] = f.filename;
-          auto flc = f.range.file_line_col();
-          if (flc) {
-            std::string filename;
-            size_t line;
-            size_t col;
-            std::tie(filename, line, col) = *flc;
-            frame[filename_s] = filename;
-            frame[line_s] = line;
-          } else {
-            frame[filename_s] = "??";
-            frame[line_s] = 0;
-          }
-          frames.append(std::move(frame));
-        }
-      };
-      for (void* f : sc->cpp_frames) {
-        const CPPFrame& wf = all_cpp_frames.at(ip_to_frame_offset.at(f));
-        if (wf.kind == CPPFrame::PYTHON) {
-          if (py_it != py_end) {
-            append_python(*py_it++);
-          }
-        } else if (wf.kind == CPPFrame::JIT) {
-          append_jit();
-        }
-        frames.append(wf.frame);
-      }
-      // add frames if we otherwise haven't seen the C++ frame indicating where
-      // it should go
-      append_jit();
-      for (; py_it != py_end; ++py_it) {
-        append_python(*py_it);
-      }
-      it = cached_frames.insert({sc, frames}).first;
-    }
-    e.second[frames_s] = it->second;
-  }
-}
+CapturedTraceback* getFromContext(
+    const std::shared_ptr<c10::GatheredContext>& x) {
+  if (CapturedTraceback* sc = dynamic_cast<CapturedTraceback*>(x.get())) {
+    return sc;
+  }
+  TORCH_CHECK(
+      false,
+      "attempting to gather stack context from the wrong StackContext type.");
+}
 
 PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
@@ -810,7 +637,8 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
   py::str history_s = "history";
   py::str blocks_s = "blocks";
 
-  std::vector<std::pair<StackContext*, py::dict>> frames_to_gather;
+  std::vector<CapturedTraceback*> to_gather_frames;
+  std::vector<py::dict> to_gather_dest;
 
   const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
     py::dict segmentDict;
@@ -842,8 +670,9 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
       history_entry[addr_s] = (int64_t)h.addr;
       history_entry[real_size_s] = h.real_size;
       if (h.context) {
-        auto sc = (StackContext*)h.context.get();
-        frames_to_gather.emplace_back(sc, history_entry);
+        auto sc = getFromContext(h.context);
+        to_gather_frames.emplace_back(sc);
+        to_gather_dest.emplace_back(history_entry);
       }
       history.append(std::move(history_entry));
     }
@@ -903,8 +732,9 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
       py::dict trace_entry;
       if (te.context_) {
         // without further compression frames can get really large on dump
-        auto sc = (StackContext*)te.context_.get();
-        frames_to_gather.emplace_back(sc, trace_entry);
+        auto sc = getFromContext(te.context_);
+        to_gather_frames.emplace_back(sc);
+        to_gather_dest.emplace_back(trace_entry);
       }
       trace_entry[action_s] = action_to_str(te.action_);
       trace_entry[TraceEntry::OOM == te.action_ ? device_free_s : addr_s] =
@@ -920,7 +750,11 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
   result["segments"] = segments;
   result["device_traces"] = traces;
 
-  gatherFrames(frames_to_gather);
+  py::str frames_s = "frames";
+  auto frames = py_symbolize(to_gather_frames);
+  for (auto i : c10::irange(frames.size())) {
+    to_gather_dest.at(i)[frames_s] = frames.at(i);
+  }
 
   return result.release().ptr();
   END_HANDLE_TH_ERRORS
@@ -996,6 +830,14 @@ PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) {
   END_HANDLE_TH_ERRORS
 }
 
+static std::shared_ptr<c10::GatheredContext> gather() {
+  return CapturedTraceback::gather(true, true, false);
+}
+
+static std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
+  return CapturedTraceback::gather(true, true, true);
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Cuda module initialization
 ////////////////////////////////////////////////////////////////////////////////
@@ -1034,8 +876,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
   }
   c10::cuda::CUDACachingAllocator::recordHistory(
       enabled,
-      record_context ? (record_context_cpp ? StackContext::gather_with_cpp
-                                           : StackContext::gather)
+      record_context ? (record_context_cpp ? gather_with_cpp : gather)
                      : nullptr,
       alloc_trace_max_entries,
       alloc_trace_record_context);

View File: torch/csrc/cuda/memory_snapshot.cpp

@@ -1,6 +1,9 @@
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <torch/csrc/cuda/memory_snapshot.h>
+#include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/serialization/pickler.h>
+#include <torch/csrc/profiler/combined_traceback.h>
 
 namespace torch {
 namespace cuda {
@@ -32,10 +35,86 @@ Dict<IValue, IValue> new_dict() {
 c10::List<IValue> new_list() {
   return List<IValue>(c10::AnyType::get());
 }
 
+std::vector<IValue> ivalue_symbolize(
+    std::vector<CapturedTraceback*>& to_symbolize) {
+  // we dedup repeated to_symbolize objects to prevent
+  // creating a bunch of duplicated frame objects
+  std::unordered_map<CapturedTraceback*, uint64_t> cached_frames;
+  std::vector<CapturedTraceback*> unique_frames;
+  for (const auto& sc : to_symbolize) {
+    auto it = cached_frames.find(sc);
+    if (it == cached_frames.end()) {
+      cached_frames.insert({sc, unique_frames.size()});
+      unique_frames.push_back(sc);
+    }
+  }
+  auto s = symbolize(unique_frames);
+
+  IValue line_s = "line";
+  IValue name_s = "name";
+  IValue filename_s = "filename";
+  std::vector<IValue> all_frames;
+  for (const auto& f : s.all_frames) {
+    auto d = new_dict();
+    d.insert(name_s, f.funcname);
+    d.insert(filename_s, f.filename);
+    d.insert(line_s, int64_t(f.lineno));
+    all_frames.emplace_back(std::move(d));
+  }
+
+  std::vector<IValue> py_unique_frames;
+  for (const auto& t : s.tracebacks) {
+    auto l = new_list();
+    for (const auto& e : t) {
+      l.push_back(all_frames.at(e));
+    }
+    py_unique_frames.push_back(std::move(l));
+  }
+
+  std::vector<IValue> result;
+  for (const auto& sc : to_symbolize) {
+    result.push_back(py_unique_frames.at(cached_frames.at(sc)));
+  }
+  return result;
+}
+
+std::shared_ptr<c10::GatheredContext> gather() {
+  return CapturedTraceback::gather(true, true, false);
+}
+
+std::shared_ptr<c10::GatheredContext> gather_with_cpp() {
+  return CapturedTraceback::gather(true, true, true);
+}
+
+CapturedTraceback* getFromContext(
+    const std::shared_ptr<c10::GatheredContext>& x) {
+  if (CapturedTraceback* sc = dynamic_cast<CapturedTraceback*>(x.get())) {
+    return sc;
+  }
+  TORCH_CHECK(
+      false,
+      "attempting to gather stack context from the wrong StackContext type.");
+}
+
 } // namespace
 
-void _record_memory_history(bool enabled, int64_t alloc_trace_max_entries) {
+void _record_memory_history(
+    bool enabled,
+    bool record_context,
+    int64_t trace_alloc_max_entries,
+    bool trace_alloc_record_context,
+    bool record_cpp_context) {
+  c10::cuda::CUDACachingAllocator::CreateContextFn recorder = nullptr;
+  if (record_context) {
+    if (record_cpp_context) {
+      recorder = gather_with_cpp;
+    } else {
+      recorder = gather;
+    }
+  }
   c10::cuda::CUDACachingAllocator::recordHistory(
-      enabled, nullptr, alloc_trace_max_entries, false);
+      enabled, recorder, trace_alloc_max_entries, trace_alloc_record_context);
 }
 
 std::string _memory_snapshot_pickled() {
@@ -66,6 +145,9 @@ std::string _memory_snapshot_pickled() {
 
   auto empty_frames = new_list();
 
+  std::vector<CapturedTraceback*> frame_tracebacks;
+  std::vector<Dict<IValue, IValue>> frame_dict;
+
   const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
     auto segmentDict = new_dict();
     segmentDict.insert(device_s, segmentInfo.device);
@@ -98,7 +180,8 @@ std::string _memory_snapshot_pickled() {
       history_entry.insert(addr_s, (int64_t)h.addr);
       history_entry.insert(real_size_s, (int64_t)h.real_size);
       if (h.context) {
-        history_entry.insert(frames_s, empty_frames);
+        frame_tracebacks.push_back(getFromContext(h.context));
+        frame_dict.push_back(history_entry);
       }
       history.push_back(std::move(history_entry));
     }
@@ -160,6 +243,11 @@ std::string _memory_snapshot_pickled() {
           TraceEntry::OOM == te.action_ ? device_free_s : addr_s, te.addr_);
       trace_entry.insert(size_s, (int64_t)te.size_);
       trace_entry.insert(stream_s, int64_t(te.stream_));
+      if (te.context_) {
+        auto sc = getFromContext(te.context_);
+        frame_tracebacks.push_back(sc);
+        frame_dict.push_back(trace_entry);
+      }
       trace.push_back(trace_entry);
     }
     traces.push_back(trace);
@@ -168,6 +256,12 @@ std::string _memory_snapshot_pickled() {
   auto result = new_dict();
   result.insert("segments", segments);
   result.insert("device_traces", traces);
+
+  auto frames = ivalue_symbolize(frame_tracebacks);
+  for (auto i : c10::irange(frames.size())) {
+    frame_dict.at(i).insert(frames_s, frames.at(i));
+  }
+
   return write_pickle(result);
 }
 } // namespace cuda

View File: torch/csrc/cuda/memory_snapshot.h

@@ -10,7 +10,11 @@ namespace cuda {
 // those defined in cuda/Module.cpp which also record python state.
 TORCH_CUDA_CU_API void _record_memory_history(
     bool enabled,
-    int64_t alloc_trace_max_entries = 1);
+    bool record_context = true,
+    int64_t trace_alloc_max_entries = 1,
+    bool trace_alloc_record_context = false,
+    bool record_cpp_context = false);
 
 TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();
 } // namespace cuda

View File: torch/csrc/profiler/combined_traceback.cpp (new file)

@@ -0,0 +1,171 @@
+#include <torch/csrc/profiler/combined_traceback.h>
+
+namespace torch {
+
+static std::atomic<CapturedTraceback::Python*> python_support_ = nullptr;
+
+std::shared_ptr<CapturedTraceback> CapturedTraceback::gather(
+    bool python,
+    bool script,
+    bool cpp) {
+  auto r = std::make_shared<CapturedTraceback>();
+  if (python) {
+    auto p = python_support_.load();
+    while (p && r->frames_.size() == 0) {
+      r->frames_ = p->gather();
+      r->python_ = p;
+      p = p->next_;
+    }
+  }
+  if (script) {
+    r->script_frames_ = torch::jit::currentCallstack();
+  }
+  if (cpp) {
+    r->cpp_frames_ = unwind::unwind();
+  }
+  return r;
+}
+
+CapturedTraceback::~CapturedTraceback() {
+  if (frames_.size() > 0) {
+    TORCH_INTERNAL_ASSERT(python_);
+    python_->release(frames_);
+  }
+}
+
+struct PyFrameHash {
+  std::size_t operator()(const CapturedTraceback::PyFrame& f) const {
+    return std::hash<void*>()(f.code) ^ std::hash<int>()(f.lasti);
+  }
+};
+
+struct PyFrameEq {
+  std::size_t operator()(
+      const CapturedTraceback::PyFrame& lhs,
+      const CapturedTraceback::PyFrame& rhs) const {
+    return lhs.code == rhs.code && lhs.lasti == rhs.lasti;
+  }
+};
+
+SymbolizedTracebacks symbolize(
+    const std::vector<CapturedTraceback*>& to_symbolize) {
+  SymbolizedTracebacks r;
+
+  std::unordered_map<void*, size_t> ip_to_frame_offset;
+  std::unordered_map<CapturedTraceback::PyFrame, size_t, PyFrameHash, PyFrameEq>
+      py_to_frame_offset;
+  std::vector<void*> all_cpp_ips;
+
+  // dedup and collect any C++ frames that need symbols for
+  for (const auto& e : to_symbolize) {
+    for (void* f : e->cpp_frames_) {
+      if (!ip_to_frame_offset.count(f)) {
+        ip_to_frame_offset[f] = all_cpp_ips.size();
+        all_cpp_ips.push_back(f);
+      }
+    }
+  }
+
+  // gather symbol names for C++ frames
+  if (all_cpp_ips.size() > 0) {
+    r.all_frames = unwind::symbolize(all_cpp_ips);
+  }
+
+  // batch symbolization requests so we dedup frame objects
+  // however, we might have to request from different python interpreters
+  // make sure we flush requests before switching interpreters;
+  CapturedTraceback::Python* cur_python = nullptr;
+  std::vector<CapturedTraceback::PyFrame> cur_py_frames;
+  size_t py_frames_size_ = 0;
+
+  for (const auto& e : to_symbolize) {
+    if (e->python_) {
+      if (cur_python != e->python_ && cur_py_frames.size() > 0) {
+        cur_python->appendSymbolized(cur_py_frames, r);
+        cur_py_frames.clear();
+      }
+      cur_python = e->python_;
+      for (const auto& f : e->frames_) {
+        if (!py_to_frame_offset.count(f)) {
+          py_to_frame_offset[f] = py_frames_size_++;
+          cur_py_frames.push_back(f);
+        }
+      }
+    }
+  }
+  if (cur_py_frames.size() > 0) {
+    cur_python->appendSymbolized(cur_py_frames, r);
+    cur_py_frames.clear();
+  }
+  std::vector<std::vector<uint64_t>> python_frame_fragments =
+      std::move(r.tracebacks);
+
+  for (const auto& sc : to_symbolize) {
+    r.tracebacks.emplace_back();
+    auto py_it = sc->frames_.begin();
+    auto py_end = sc->frames_.end();
+    bool jit_appended = false;
+
+    auto append_python = [&](const CapturedTraceback::PyFrame& f) {
+      const auto& fragment =
+          python_frame_fragments.at(py_to_frame_offset.at(f));
+      r.tracebacks.back().insert(
+          r.tracebacks.back().end(), fragment.begin(), fragment.end());
+    };
+
+    auto append_jit = [&]() {
+      if (jit_appended) {
+        return;
+      }
+      jit_appended = true;
+      for (const auto& f : sc->script_frames_) {
+        unwind::Frame frame;
+        frame.funcname =
+            f.filename; // sic: torchscript puts funcname in filename field
+        auto flc = f.range.file_line_col();
+        if (flc) {
+          size_t col;
+          std::tie(frame.filename, frame.lineno, col) = *flc;
+        } else {
+          frame.filename = "??";
+          frame.lineno = 0;
+        }
+        r.tracebacks.back().push_back(r.all_frames.size());
+        r.all_frames.emplace_back(std::move(frame));
+      }
+    };
+
+    for (void* f : sc->cpp_frames_) {
+      uint64_t cpp_frame = ip_to_frame_offset.at(f);
+      const unwind::Frame& uf = r.all_frames.at(cpp_frame);
+      if (uf.funcname.find("PyEval_EvalFrame") != std::string::npos) {
+        if (py_it != py_end) {
+          append_python(*py_it++);
+        }
+      } else if (
+          uf.funcname.rfind("torch::jit::InterpreterStateImpl::run", 0) !=
+          std::string::npos) {
+        append_jit();
+      }
+      r.tracebacks.back().push_back(cpp_frame);
+    }
+    // add frames if we otherwise haven't seen the C++ frame indicating where
+    // it should go
+    append_jit();
+    for (; py_it != py_end; ++py_it) {
+      append_python(*py_it);
+    }
+  }
+  return r;
+}
+
+void CapturedTraceback::addPythonUnwinder(CapturedTraceback::Python* p) {
+  CapturedTraceback::Python* old_unwinder = python_support_.load();
+  do {
+    p->next_ = old_unwinder;
+  } while (!python_support_.compare_exchange_strong(old_unwinder, p));
+}
+
+} // namespace torch
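
To illustrate the registration contract implemented above (unwinders are
immortal and prepended to an atomic singly-linked list; gather() falls
through to next_ when an unwinder records no frames), here is a hypothetical
no-op unwinder. NullUnwinder is invented for illustration and is not part of
this commit; the real implementation is PythonTraceback in
torch/csrc/profiler/python/combined_traceback.cpp below:

    #include <torch/csrc/profiler/combined_traceback.h>

    struct NullUnwinder : torch::CapturedTraceback::Python {
      std::vector<torch::CapturedTraceback::PyFrame> gather() override {
        return {};  // no frames: CapturedTraceback::gather tries next_
      }
      void release(std::vector<torch::CapturedTraceback::PyFrame>&) override {}
      void appendSymbolized(
          const std::vector<torch::CapturedTraceback::PyFrame>&,
          torch::SymbolizedTracebacks&) override {}
    };

    void install_null_unwinder() {
      // Per the header: p cannot be deleted once added, so leak on purpose.
      torch::CapturedTraceback::addPythonUnwinder(new NullUnwinder());
    }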

View File: torch/csrc/profiler/combined_traceback.h (new file)

@@ -0,0 +1,62 @@
+#pragma once
+
+#include <torch/csrc/jit/runtime/interpreter.h>
+#include <torch/csrc/profiler/unwind/unwind.h>
+
+namespace torch {
+
+// struct that holds the result of symbolizing multiple tracebacks
+// each traceback is a list of indices into all_frames
+// (lots of Frames get duplicated across traces)
+struct TORCH_API SymbolizedTracebacks {
+  std::vector<unwind::Frame> all_frames;
+  // index into all_frames, so that
+  // it is possible to dedupe frame objects in
+  // construction of python objects
+  std::vector<std::vector<uint64_t>> tracebacks;
+};
+
+struct TORCH_API CapturedTraceback : public c10::GatheredContext {
+  struct PyFrame {
+    void* code; // PyCodeObject*, but python headers not present
+    int lasti;
+  };
+
+  static std::shared_ptr<CapturedTraceback> gather(
+      bool python,
+      bool script,
+      bool cpp);
+  CapturedTraceback() = default;
+  CapturedTraceback(const CapturedTraceback&) = delete;
+  CapturedTraceback& operator=(const CapturedTraceback&) = delete;
+  ~CapturedTraceback();
+
+  struct Python {
+    virtual std::vector<PyFrame> gather() = 0;
+    virtual void release(std::vector<PyFrame>& frames) = 0;
+    virtual void appendSymbolized(
+        const std::vector<PyFrame>& to_symbolize,
+        SymbolizedTracebacks& st) = 0;
+    virtual ~Python() = default;
+    Python* next_ = nullptr;
+  };
+  // called once by each python interpreter to
+  // register python stack recording functionality
+  // p cannot be deleted once added.
+  static void addPythonUnwinder(Python* p);
+
+ private:
+  std::vector<PyFrame> frames_;
+  std::vector<void*> cpp_frames_;
+  std::vector<jit::StackEntry> script_frames_;
+  friend TORCH_API SymbolizedTracebacks
+  symbolize(const std::vector<CapturedTraceback*>& to_symbolize);
+
+  // non-owning reference to one of the immortal Python* objects
+  // registered above.
+  Python* python_ = nullptr;
+};
+
+TORCH_API SymbolizedTracebacks
+symbolize(const std::vector<CapturedTraceback*>& to_symbolize);
+
+} // namespace torch
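
The intended use of this header is capture-now, symbolize-later-in-batch; a
condensed sketch of that pattern (it is what ivalue_symbolize and
py_symbolize elsewhere in this commit do):

    #include <torch/csrc/profiler/combined_traceback.h>

    void example() {
      // Capture is cheap: raw code pointers and instruction addresses only.
      auto tb = torch::CapturedTraceback::gather(
          /*python=*/true, /*script=*/true, /*cpp=*/true);

      // Symbolize many tracebacks at once so shared frames resolve once.
      std::vector<torch::CapturedTraceback*> batch = {tb.get()};
      torch::SymbolizedTracebacks st = torch::symbolize(batch);

      // Each traceback is a list of indices into the deduped frame table.
      for (uint64_t idx : st.tracebacks.at(0)) {
        const torch::unwind::Frame& f = st.all_frames.at(idx);
        (void)f;  // f.filename, f.lineno, f.funcname
      }
    }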

View File: torch/csrc/profiler/python/combined_traceback.cpp (new file)

@@ -0,0 +1,123 @@
+#include <torch/csrc/profiler/python/combined_traceback.h>
+
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/pythoncapi_compat.h>
+
+#include <iostream>
+
+namespace py = pybind11;
+
+namespace torch {
+// Locking:
+// We need to free PyCodeObjects when ~StackContext runs, but
+// CUDACachingAllocator may hold its device lock when ~StackContext runs.
+// Because the thread calling the allocator _may_ hold the GIL,
+// attempting to lock the GIL in ~StackContext can deadlock:
+// T0: GIL Lock -> Call Allocator ->| Waiting Device Lock
+// T1: Call Allocator -> Device Lock ->| Waiting GIL Lock
+// Instead the destructor defers freeing stack frames by putting them in
+// to_free_frames. We still need a lock to manage this vector, but
+// we can ensure an overall lock ordering of GIL -> device_lock ->
+// to_free_frames_mutex because ::gather is called outside of the device lock.
+
+namespace {
+static std::mutex to_free_frames_mutex;
+static std::vector<CapturedTraceback::PyFrame> to_free_frames;
+
+struct PythonTraceback : public CapturedTraceback::Python {
+  std::vector<CapturedTraceback::PyFrame> gather() override {
+    std::vector<CapturedTraceback::PyFrame> frames;
+    py::gil_scoped_acquire acquire;
+    {
+      std::lock_guard lock(to_free_frames_mutex);
+      for (CapturedTraceback::PyFrame f : to_free_frames) {
+        Py_XDECREF(f.code);
+      }
+      to_free_frames.clear();
+    }
+    PyFrameObject* f = PyEval_GetFrame();
+    Py_XINCREF(f);
+    while (f) {
+      frames.emplace_back(
+          CapturedTraceback::PyFrame{PyFrame_GetCode(f), PyFrame_GetLasti(f)});
+      auto f_back = PyFrame_GetBack(f);
+      Py_XDECREF(f);
+      f = f_back;
+    }
+    return frames;
+  }
+
+  void release(std::vector<CapturedTraceback::PyFrame>& frames) override {
+    std::lock_guard lock(to_free_frames_mutex);
+    to_free_frames.insert(to_free_frames.end(), frames.begin(), frames.end());
+  }
+
+  void appendSymbolized(
+      const std::vector<CapturedTraceback::PyFrame>& to_symbolize,
+      SymbolizedTracebacks& result) override {
+    py::str line_s = "line";
+    py::str name_s = "name";
+    py::str filename_s = "filename";
+    for (const auto& f : to_symbolize) {
+      auto f_code = (PyCodeObject*)f.code;
+      py::handle filename = f_code->co_filename;
+      py::handle funcname = f_code->co_name;
+      auto lineno = PyCode_Addr2Line(f_code, f.lasti);
+      result.tracebacks.emplace_back();
+      result.tracebacks.back().push_back(result.all_frames.size());
+      result.all_frames.emplace_back(unwind::Frame{
+          py::cast<std::string>(filename),
+          py::cast<std::string>(funcname),
+          (uint64_t)lineno});
+    }
+  }
+};
+
+} // namespace
+
+std::vector<py::object> py_symbolize(
+    std::vector<CapturedTraceback*>& to_symbolize) {
+  // we dedup repeated to_symbolize objects to prevent
+  // creating a bunch of duplicated frame objects
+  std::unordered_map<CapturedTraceback*, uint64_t> cached_frames;
+  std::vector<CapturedTraceback*> unique_frames;
+  for (const auto& sc : to_symbolize) {
+    auto it = cached_frames.find(sc);
+    if (it == cached_frames.end()) {
+      cached_frames.insert({sc, unique_frames.size()});
+      unique_frames.push_back(sc);
+    }
+  }
+  auto s = symbolize(unique_frames);
+
+  py::str line_s = "line";
+  py::str name_s = "name";
+  py::str filename_s = "filename";
+  std::vector<py::dict> all_frames;
+  for (const auto& f : s.all_frames) {
+    py::dict d;
+    d[name_s] = f.funcname;
+    d[filename_s] = f.filename;
+    d[line_s] = f.lineno;
+    all_frames.emplace_back(std::move(d));
+  }
+
+  std::vector<py::object> py_unique_frames;
+  for (const auto& t : s.tracebacks) {
+    py::list l;
+    for (const auto& e : t) {
+      l.append(all_frames.at(e));
+    }
+    py_unique_frames.push_back(std::move(l));
+  }
+
+  std::vector<py::object> result;
+  for (const auto& sc : to_symbolize) {
+    result.push_back(py_unique_frames.at(cached_frames.at(sc)));
+  }
+  return result;
+}
+
+void installCapturedTracebackPython() {
+  CapturedTraceback::addPythonUnwinder(new PythonTraceback());
+}
+
+} // namespace torch

View File: torch/csrc/profiler/python/combined_traceback.h (new file)

@@ -0,0 +1,19 @@
+#include <torch/csrc/profiler/combined_traceback.h>
+
+#include <pybind11/pybind11.h>
+#include <torch/csrc/utils/pybind.h>
+
+namespace torch {
+
+// symbolize combined traceback objects, converting them into lists of
+// dictionaries that are easily consumed in python.
+// returns std::vector because one use is to call it with a batch of
+// tracebacks that come from a larger datastructure (e.g. a memory snapshot)
+// and then have more c++ code to put those objects in the right place.
+std::vector<pybind11::object> py_symbolize(
+    std::vector<CapturedTraceback*>& to_symbolize);
+
+void installCapturedTracebackPython();
+
+} // namespace torch

View File: torch/csrc/profiler/python/init.cpp

@@ -6,6 +6,7 @@
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/profiler/collection.h>
+#include <torch/csrc/profiler/python/combined_traceback.h>
 #include <torch/csrc/profiler/standalone/execution_graph_observer.h>
 #include <torch/csrc/utils/pybind.h>
@@ -292,6 +293,12 @@ void initPythonBindings(PyObject* module) {
   m.def(
       "_disable_execution_graph_observer",
       &torch::profiler::impl::disableExecutionGraphObserver);
+
+  py::class_<CapturedTraceback, std::shared_ptr<CapturedTraceback>>(
+      m, "CapturedTraceback");
+  m.def("gather_traceback", CapturedTraceback::gather);
+  m.def("symbolize_tracebacks", py_symbolize);
+  installCapturedTracebackPython();
 }
 } // namespace profiler

View File: torch/csrc/profiler/unwind/unwind.cpp

@@ -1,7 +1,8 @@
 #include <c10/util/Exception.h>
 #include <torch/csrc/profiler/unwind/unwind.h>
 
-#if !defined(__linux__) || !defined(__x86_64__)
+#if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \
+    !__has_include("ext/stdio_filebuf.h")
 namespace torch {
 namespace unwind {
 std::vector<void*> unwind() {

View File: torch/csrc/profiler/unwind/unwind.h

@@ -1,3 +1,5 @@
+#pragma once
+#include <c10/macros/Export.h>
 #include <string>
 #include <vector>
@@ -5,7 +7,7 @@ namespace torch {
 namespace unwind {
 // gather current stack, relatively fast.
 // gets faster once the cache of program counter locations is warm.
-std::vector<void*> unwind();
+TORCH_API std::vector<void*> unwind();
 
 struct Frame {
   std::string filename;
@@ -19,7 +21,7 @@ struct Frame {
 // Callers should first batch up all the unique void* pointers
 // across a number of unwind states and make a single call to
 // symbolize.
-std::vector<Frame> symbolize(const std::vector<void*>& frames);
+TORCH_API std::vector<Frame> symbolize(const std::vector<void*>& frames);
 
 struct Stats {
   size_t hits = 0;
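
A sketch of the batching this header's comment asks callers to do (the dedup
table mirrors ip_to_frame_offset in combined_traceback.cpp above;
symbolize_stacks is a hypothetical helper, not part of this commit):

    #include <torch/csrc/profiler/unwind/unwind.h>
    #include <unordered_map>

    // Resolve many captured stacks with a single symbolize() call.
    std::vector<std::vector<torch::unwind::Frame>> symbolize_stacks(
        const std::vector<std::vector<void*>>& stacks) {
      std::unordered_map<void*, size_t> offset_of;  // ip -> index
      std::vector<void*> unique_ips;
      for (const auto& stack : stacks) {
        for (void* ip : stack) {
          if (offset_of.emplace(ip, unique_ips.size()).second) {
            unique_ips.push_back(ip);
          }
        }
      }
      auto frames = torch::unwind::symbolize(unique_ips);  // one batched call
      std::vector<std::vector<torch::unwind::Frame>> out;
      for (const auto& stack : stacks) {
        out.emplace_back();
        for (void* ip : stack) {
          out.back().push_back(frames.at(offset_of.at(ip)));
        }
      }
      return out;
    }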

View File: torch/cuda/_memory_viz.py

@@ -19,7 +19,8 @@ def _frame_fmt(f, full_filename=False):
     func = f['name']
     return f'{fname}:{i}:{func}'
 
-def _frame_filter(f):
+@cache
+def _frame_filter(name, filename):
     omit_functions = [
         "unwind::unwind",
         "StackContext::gather",
@@ -40,17 +41,17 @@ def _frame_filter(f):
         "cpython/abstract.h",
     ]
     for of in omit_functions:
-        if of in f['name']:
+        if of in name:
             return False
     for of in omit_filenames:
-        if of in f['filename']:
+        if of in filename:
             return False
     return True
 
 def _frames_fmt(frames, full_filename=False, reverse=False):
     if reverse:
         frames = reversed(frames)
-    return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f)]
+    return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])]
 
 def format_flamegraph(flamegraph_lines, flamegraph_script=None):
     if flamegraph_script is None:
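
A note on the _frame_filter change above: the filter is now memoized, and a
functools-style cache keys on its arguments, which must be hashable. Frame
dicts are not hashable, so callers now pass the two strings the filter
actually inspects. The import sits outside this hunk; presumably it is the
stdlib decorator:

    # Assumed import for the @cache used above (Python >= 3.9); on older
    # Pythons the equivalent is functools.lru_cache(maxsize=None).
    from functools import cache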