Fast standalone symbolize for unwinding (#123966)

We've had issues using addr2line. On certain versions of
CentOS it is on a version that has a performance regression making it very slow,
and even normallly it is not that fast, taking several seconds even when parallelized
for a typical memory trace dump.

Folly Symbolize or LLVMSymbolize are fast but it requires PyTorch take a dependency on those libraries to do this, and given the number of environments we run stuff in, we end up hitting cases where we fallback to slow addr2line behavior.

This adds a standalone symbolizer to PyTorch similar to the unwinder which has
no external dependencies and is ~20x faster than addr2line for unwinding PyTorch frames.

I've tested this on some memory profiling runs using all combinations of {gcc, clang} x {dwarf4, dwarf5} and it seems to do a good job at getting line numbers and function names right. It is also careful to route all reads of library data through the `CheckedLexer` object, which ensure it is not reading out of bounds of the section. Errors are routed through UnwindError so that those exceptions get caught and we produce a ?? frame rather than crash. I also added a fuzz test which gives all our symbolizer options random addresses in the process to make sure they do not crash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123966
Approved by: https://github.com/ezyang
This commit is contained in:
zdevito
2024-04-22 20:08:36 +00:00
committed by PyTorch MergeBot
parent cf98cab1b6
commit 772ae6da1e
24 changed files with 1596 additions and 111 deletions

View File

@ -16,8 +16,11 @@ except ImportError:
import collections
import gc
import json
import mmap
import os
import random
import re
import struct
import subprocess
import sys
import tempfile
@ -70,7 +73,9 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import skipCUDAVersionIn
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_ARM64,
IS_JETSON,
IS_LINUX,
IS_WINDOWS,
parametrize,
run_tests,
@ -3579,6 +3584,70 @@ aten::mm""",
finally:
os.remove("torchtidy_report.json")
@unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
def test_fuzz_symbolize(self):
# generate some random addresses in the text section and make sure the
# symbolizers do not throw exceptions/crash
def get_text_sections():
text_sections = []
seen = set()
for filename in os.listdir("/proc/self/map_files"):
library = os.readlink("/proc/self/map_files/" + filename)
if ".so" not in library or library in seen:
continue
seen.add(library)
with open(os.path.join("/proc/self/map_files", library), "rb") as f:
mm = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ)
def unpack(fmt, offset):
return struct.unpack(
fmt, mm[offset : offset + struct.calcsize(fmt)]
)
if mm[:4] != b"\x7fELF":
continue
(section_headers_start,) = unpack("Q", 40)
(section_header_size,) = unpack("H", 58)
(num_section_headers,) = unpack("H", 60)
(shstrndx,) = unpack("H", 62)
(shstrtab_offset,) = unpack(
"Q", section_headers_start + shstrndx * section_header_size + 24
)
for i in range(num_section_headers):
(section_name_offset,) = unpack(
"I", section_headers_start + i * section_header_size
)
name_start = shstrtab_offset + section_name_offset
section_name = mm[name_start : name_start + 6]
if section_name != b".text\0":
continue
(section_offset,) = unpack(
"Q", section_headers_start + i * section_header_size + 24
)
(section_size,) = unpack(
"Q", section_headers_start + i * section_header_size + 32
)
start = int(filename.split("-")[0], 16) + section_offset
text_sections.append((start, section_size))
break
mm.close()
return text_sections
r = random.Random()
r.seed(1)
text_sections = get_text_sections()
addrs = []
for i in range(200):
s = r.randrange(0, len(text_sections))
start, size = text_sections[s]
addr = r.randrange(start, start + size)
addrs.append(addr)
fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr")
addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
self.assertEqual(len(fast), len(addrs))
self.assertEqual(len(addr2line), len(fast))
if __name__ == "__main__":
run_tests()

View File

@ -16,7 +16,6 @@ import subprocess
import random
from random import randint
import json
import torch
import torch.cuda
from torch.cuda._memory_viz import profile_plot, _profile_to_snapshot

View File

@ -163,12 +163,14 @@ static PyObject* THPModule_initExtension(
PyObject* shm_manager_path) {
HANDLE_TH_ERRORS
#if !defined(FBCODE_CAFFE2)
if (torch::get_cpp_stacktraces_enabled() && !torch::get_disable_addr2line()) {
if (torch::get_cpp_stacktraces_enabled()) {
c10::SetStackTraceFetcher([]() -> std::string {
auto tb = torch::CapturedTraceback::gather(false, false, true);
LOG(WARNING)
<< "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
<< std::endl;
if (torch::get_symbolize_mode() == torch::unwind::Mode::addr2line) {
LOG(WARNING)
<< "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
<< std::endl;
}
auto s_tbs = torch::symbolize({tb.get()});
std::stringstream oss;
oss << "C++ CapturedTraceback:" << std::endl;

View File

@ -1,4 +1,5 @@
#include <torch/csrc/profiler/combined_traceback.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
namespace torch {
@ -77,7 +78,7 @@ SymbolizedTracebacks symbolize(
}
// gather symbol names for C++ frames
if (!all_cpp_ips.empty()) {
r.all_frames = unwind::symbolize(all_cpp_ips);
r.all_frames = unwind::symbolize(all_cpp_ips, torch::get_symbolize_mode());
}
// batch symbolization requests so we dedup frame objects

View File

@ -79,8 +79,7 @@ PyTypeObject THPCapturedTracebackType = {
nullptr, /* tp_new */
};
namespace pybind11 {
namespace detail {
namespace pybind11::detail {
template <>
struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
@ -107,11 +106,9 @@ struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
}
};
} // namespace detail
} // namespace pybind11
} // namespace pybind11::detail
namespace torch {
namespace profiler {
namespace torch::profiler {
/* [NOTE: RecordFunctionFast]
* This is an alternate way to call record_function from python.
@ -606,6 +603,33 @@ void initPythonBindings(PyObject* module) {
}
return py_symbolize(tb_ptrs);
});
// directly convert address pointers to frames, used for testing symbolize
m.def(
"symbolize_addresses",
[](const std::vector<uint64_t>& frames, const std::string& mode_s) {
std::vector<std::tuple<std::string, int64_t, std::string>> frames_out;
torch::unwind::Mode mode = torch::unwind::Mode::addr2line;
if (mode_s == "fast") {
mode = torch::unwind::Mode::fast;
} else if (mode_s == "addr2line") {
mode = torch::unwind::Mode::addr2line;
} else if (mode_s == "dladdr") {
mode = torch::unwind::Mode::dladdr;
} else {
TORCH_CHECK(false, "unexpected mode ", mode_s);
}
std::vector<void*> frames_p;
frames_p.reserve(frames.size());
for (auto f : frames) {
frames_p.push_back((void*)f); // NOLINT
}
auto frame_objects = unwind::symbolize(frames_p, mode);
frames_out.reserve(frame_objects.size());
for (auto& frame : frame_objects) {
frames_out.emplace_back(frame.filename, frame.lineno, frame.funcname);
}
return frames_out;
});
installCapturedTracebackPython();
// NOLINTNEXTLINE(*-c-arrays*)
@ -639,5 +663,4 @@ void initPythonBindings(PyObject* module) {
throw python_error();
}
}
} // namespace profiler
} // namespace torch
} // namespace torch::profiler

View File

@ -2,6 +2,8 @@
#include <stdint.h>
#include <ostream>
namespace torch::unwind {
enum {
A_UNDEFINED = 0x0,
A_REG_PLUS_DATA = 0x1, // exp = REG[reg] + data0
@ -53,3 +55,5 @@ struct Action {
return out;
}
};
} // namespace torch::unwind

View File

@ -5,6 +5,7 @@
#include <unistd.h>
#include <memory>
namespace torch::unwind {
// helper to open a process with stdin/stdout/stderr streams.
struct Communicate {
Communicate(const char* command, const char** args) {
@ -63,3 +64,5 @@ struct Communicate {
std::unique_ptr<std::ostream> out_;
std::unique_ptr<std::ostream> err_;
};
} // namespace torch::unwind

View File

@ -0,0 +1,279 @@
#pragma once
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
#include <torch/csrc/profiler/unwind/lexer.h>
#include <torch/csrc/profiler/unwind/sections.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstdint>
#include <optional>
namespace torch::unwind {
struct DebugInfo {
DebugInfo(Sections& s) : s_(s) {}
void parse(uint64_t offset) {
auto L = parseHeader(offset);
parseCompileUnit(L);
}
unwind::optional<uint64_t> lineNumberProgramOffset() {
return line_number_program_offset_;
}
uint64_t nextOffset() {
return end_ - s_.debug_info.data;
}
std::vector<std::pair<uint64_t, uint64_t>> ranges() {
if (range_ptr_) {
auto offset = range_ptr_->first;
if (range_ptr_->second == DW_FORM_rnglistx) {
UNWIND_CHECK(rnglists_base_, "rnglistx but not rnglists_base_ set");
LOG_INFO("index for rnglistx {:x} + {:x}\n", *rnglists_base_, offset);
CheckedLexer L = s_.debug_rnglists.lexer(
*rnglists_base_ + offset * sec_offset_size_);
auto read = readSegmentOffset(L);
offset = *rnglists_base_ + read;
}
return version_ == 4 ? readRanges4(offset) : readRanges5(offset);
}
if (!highpc_) {
return {};
}
return {{lowpc_, lowpc_ + *highpc_}};
}
bool is64bit() {
return is_64bit_;
}
private:
CheckedLexer parseHeader(uint64_t offset) {
offset_ = offset;
CheckedLexer L = s_.debug_info.lexer(offset_);
std::tie(length_, is_64bit_) = L.readSectionLength();
sec_offset_size_ = is_64bit_ ? 8 : 4;
end_ = (const char*)L.loc() + length_;
version_ = L.read<uint16_t>();
UNWIND_CHECK(
version_ == 5 || version_ == 4,
"unexpected dwarf version {}",
version_);
uint8_t address_size = 0;
if (version_ == 5) {
auto unit_type = L.read<uint8_t>();
UNWIND_CHECK(unit_type == 0x1, "unexpected unit type {}", unit_type);
address_size = L.read<uint8_t>();
debug_abbrev_offset_ =
is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
} else {
debug_abbrev_offset_ =
is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
address_size = L.read<uint8_t>();
}
LOG_INFO(
"compilation unit at offset {:x} with length {:x} and debug_abbrev_offset {:x}\n",
offset,
length_,
debug_abbrev_offset_);
UNWIND_CHECK(
address_size == 8,
"expected 64-bit dwarf but found address size {}",
address_size);
return L;
}
uint64_t readSegmentOffset(CheckedLexer& L) {
return s_.readSegmentOffset(L, is_64bit_);
}
uint64_t readEncoded(CheckedLexer& L, uint64_t encoding) {
switch (encoding) {
case DW_FORM_data8:
case DW_FORM_addr:
return L.read<uint64_t>();
case DW_FORM_data4:
return L.read<uint32_t>();
case DW_FORM_addrx: {
auto idx = L.readULEB128();
return s_.debug_addr.lexer(address_base_ + sizeof(uint64_t) * idx)
.read<uint64_t>();
}
case DW_FORM_sec_offset:
return readSegmentOffset(L);
case DW_FORM_rnglistx: {
return L.readULEB128();
}
default:
UNWIND_CHECK(false, "unexpected encoding");
}
}
void parseCompileUnit(CheckedLexer& L) {
auto entry = L.readULEB128();
auto A = findAbbrev(debug_abbrev_offset_, entry);
while (true) {
auto attr = A.readULEB128();
auto form = A.readULEB128();
if (attr == 0 && form == 0) {
break;
}
if (form == DW_FORM_implicit_const) {
A.readSLEB128();
}
if (attr == DW_AT_low_pc) {
lowpc_ = readEncoded(L, form);
LOG_INFO(" lowpc {:x}\n", lowpc_);
} else if (attr == DW_AT_high_pc) {
highpc_ = readEncoded(L, form);
range_ptr_ = std::nullopt;
LOG_INFO(" highpc {:x}\n", *highpc_);
} else if (attr == DW_AT_addr_base) {
UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected addr_base form");
address_base_ = readSegmentOffset(L);
LOG_INFO(" address base {:x}\n", address_base_);
} else if (attr == DW_AT_rnglists_base) {
UNWIND_CHECK(
form == DW_FORM_sec_offset, "unexpected rnglists_base form");
rnglists_base_ = readSegmentOffset(L);
LOG_INFO(" range base {:x}\n", *rnglists_base_);
} else if (form == DW_FORM_string) {
L.readCString();
} else if (attr == DW_AT_stmt_list) {
UNWIND_CHECK(form == DW_FORM_sec_offset, "unexpected stmt_list form");
LOG_INFO(" program table offset {:x}\n", *line_number_program_offset_);
line_number_program_offset_ = readSegmentOffset(L);
} else if (form == DW_FORM_exprloc) {
auto sz = L.readULEB128();
L.skip(int64_t(sz));
} else if (form == DW_FORM_block1) {
auto sz = L.read<uint8_t>();
L.skip(int64_t(sz));
} else if (attr == DW_AT_ranges) {
auto range_offset = readEncoded(L, form);
LOG_INFO("setting range_ptr to {:x} {:x}\n", range_offset, form);
range_ptr_.emplace(range_offset, form);
} else if (
form == DW_FORM_udata || form == DW_FORM_rnglistx ||
form == DW_FORM_strx || form == DW_FORM_loclistx ||
form == DW_FORM_addrx) {
L.readULEB128();
} else if (form == DW_FORM_sdata) {
L.readSLEB128();
} else {
auto sz = formSize(form, sec_offset_size_);
UNWIND_CHECK(sz, "unsupported form in compilation unit {:x}", form);
L.skip(int64_t(*sz));
}
}
}
std::vector<std::pair<uint64_t, uint64_t>> readRanges4(uint64_t offset) {
CheckedLexer L = s_.debug_ranges.lexer(offset);
std::vector<std::pair<uint64_t, uint64_t>> ranges;
uint64_t base = lowpc_;
while (true) {
auto start = L.read<uint64_t>();
auto end = L.read<uint64_t>();
if (start == 0 && end == 0) {
break;
}
if (start == std::numeric_limits<uint64_t>::max()) {
base = end;
} else {
ranges.emplace_back(base + start, base + end);
}
}
return ranges;
}
std::vector<std::pair<uint64_t, uint64_t>> readRanges5(uint64_t offset) {
CheckedLexer L = s_.debug_rnglists.lexer(offset);
uint64_t base = 0;
LOG_INFO("BEGIN RANGES {:x}\n", offset);
std::vector<std::pair<uint64_t, uint64_t>> ranges;
while (true) {
auto op = L.read<uint8_t>();
switch (op) {
case DW_RLE_end_of_list:
LOG_INFO("END RANGES\n");
return ranges;
case DW_RLE_base_addressx: {
base = readEncoded(L, DW_FORM_addrx);
LOG_INFO("BASE ADDRX {:x}\n", base);
} break;
case DW_RLE_startx_length: {
auto s = readEncoded(L, DW_FORM_addrx);
auto e = L.readULEB128();
LOG_INFO("startx_length {:x} {:x}\n", s, e);
ranges.emplace_back(s, s + e);
} break;
case DW_RLE_base_address:
base = L.read<uint64_t>();
LOG_INFO("BASE ADDR {:x}\n", base);
break;
case DW_RLE_offset_pair: {
auto s = L.readULEB128();
auto e = L.readULEB128();
LOG_INFO("offset_pair {:x} {:x}\n", s, e);
ranges.emplace_back(base + s, base + e);
} break;
case DW_RLE_start_length: {
auto s = L.read<uint64_t>();
auto e = L.readULEB128();
LOG_INFO("start_length {:x} {:x}\n", s, e);
ranges.emplace_back(s, s + e);
} break;
default:
UNWIND_CHECK(false, "unknown range op: {}", op);
}
}
}
CheckedLexer findAbbrev(uint64_t offset, uint64_t entry) {
CheckedLexer L = s_.debug_abbrev.lexer(offset);
while (true) {
auto abbrev_code = L.readULEB128();
UNWIND_CHECK(
abbrev_code != 0,
"could not find entry {} at offset {:x}",
entry,
offset);
auto tag = L.readULEB128();
L.read<uint8_t>(); // has children
if (abbrev_code == entry) {
UNWIND_CHECK(
tag == DW_TAG_compile_unit,
"first entry was not a compile unit but {}",
tag);
return L;
}
while (true) {
auto attr = L.readULEB128();
auto form = L.readULEB128();
if (attr == 0 && form == 0) {
break;
}
if (form == DW_FORM_implicit_const) {
L.readSLEB128();
}
}
}
}
Sections& s_;
optional<uint64_t> line_number_program_offset_;
uint64_t offset_ = 0;
uint8_t sec_offset_size_ = 0;
uint64_t length_ = 0;
const char* end_ = nullptr;
uint64_t debug_abbrev_offset_ = 0;
bool is_64bit_ = false;
std::optional<std::pair<uint64_t, uint8_t>> range_ptr_;
uint64_t lowpc_ = 0;
optional<uint64_t> highpc_;
uint16_t version_ = 0;
uint64_t address_base_ = 0;
optional<uint64_t> rnglists_base_;
};
} // namespace torch::unwind

View File

@ -0,0 +1,181 @@
#pragma once
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstdint>
#include <optional>
enum {
DW_TAG_subprogram = 0x2e,
DW_TAG_inlined_subroutine = 0x1d,
DW_TAG_compile_unit = 0x11,
DW_AT_sibling = 0x1, // reference
DW_AT_name = 0x3, // string
DW_AT_stmt_list = 0x10, // lineptr
DW_AT_addr_base = 0x73, // sec_offset
DW_AT_rnglists_base = 0x74, // sec_offset
DW_AT_low_pc = 0x11, // address
DW_AT_high_pc = 0x12, // address
DW_AT_specification = 0x47, // reference
DW_AT_abstract_origin = 0x31, // reference
DW_AT_linkage_name = 0x6e, // string
DW_AT_ranges = 0x55, // rnglist
DW_AT_str_offsets_base = 0x72, // sec_offset
DW_FORM_addr = 0x01,
DW_FORM_block2 = 0x03,
DW_FORM_block4 = 0x04,
DW_FORM_data2 = 0x05,
DW_FORM_data4 = 0x06,
DW_FORM_data8 = 0x07,
DW_FORM_string = 0x08,
DW_FORM_block = 0x09,
DW_FORM_block1 = 0x0a,
DW_FORM_data1 = 0x0b,
DW_FORM_flag = 0x0c,
DW_FORM_sdata = 0x0d,
DW_FORM_strp = 0x0e,
DW_FORM_udata = 0x0f,
DW_FORM_ref_addr = 0x10,
DW_FORM_ref1 = 0x11,
DW_FORM_ref2 = 0x12,
DW_FORM_ref4 = 0x13,
DW_FORM_ref8 = 0x14,
DW_FORM_ref_udata = 0x15,
DW_FORM_indirect = 0x16,
DW_FORM_sec_offset = 0x17,
DW_FORM_exprloc = 0x18,
DW_FORM_flag_present = 0x19,
DW_FORM_strx = 0x1a,
DW_FORM_addrx = 0x1b,
DW_FORM_ref_sup4 = 0x1c,
DW_FORM_strp_sup = 0x1d,
DW_FORM_data16 = 0x1e,
DW_FORM_line_strp = 0x1f,
DW_FORM_ref_sig8 = 0x20,
DW_FORM_implicit_const = 0x21,
DW_FORM_loclistx = 0x22,
DW_FORM_rnglistx = 0x23,
DW_FORM_ref_sup8 = 0x24,
DW_FORM_strx1 = 0x25,
DW_FORM_strx2 = 0x26,
DW_FORM_strx3 = 0x27,
DW_FORM_strx4 = 0x28,
DW_FORM_addrx1 = 0x29,
DW_FORM_addrx2 = 0x2a,
DW_FORM_addrx3 = 0x2b,
DW_FORM_addrx4 = 0x2c,
/* GNU Debug Fission extensions. */
DW_FORM_GNU_addr_index = 0x1f01,
DW_FORM_GNU_str_index = 0x1f02,
DW_FORM_GNU_ref_alt = 0x1f20, /* offset in alternate .debuginfo. */
DW_FORM_GNU_strp_alt = 0x1f21, /* offset in alternate .debug_str. */
DW_LNCT_path = 0x1,
DW_LNCT_directory_index = 0x2,
DW_LNS_extended_op = 0x00,
DW_LNE_end_sequence = 0x01,
DW_LNE_set_address = 0x02,
DW_LNS_copy = 0x01,
DW_LNS_advance_pc = 0x02,
DW_LNS_advance_line = 0x03,
DW_LNS_set_file = 0x04,
DW_LNS_const_add_pc = 0x08,
DW_LNS_fixed_advance_pc = 0x09,
DW_RLE_end_of_list = 0x0,
DW_RLE_base_addressx = 0x1,
DW_RLE_startx_endx = 0x2,
DW_RLE_startx_length = 0x3,
DW_RLE_offset_pair = 0x4,
DW_RLE_base_address = 0x5,
DW_RLE_start_end = 0x6,
DW_RLE_start_length = 0x7
};
static torch::unwind::optional<size_t> formSize(
uint64_t form,
uint8_t sec_offset_size) {
switch (form) {
case DW_FORM_addr:
return sizeof(void*);
case DW_FORM_block2:
case DW_FORM_block4:
return std::nullopt;
case DW_FORM_data2:
return 2;
case DW_FORM_data4:
return 4;
case DW_FORM_data8:
return 8;
case DW_FORM_string:
case DW_FORM_block:
case DW_FORM_block1:
return std::nullopt;
case DW_FORM_data1:
case DW_FORM_flag:
return 1;
case DW_FORM_sdata:
return std::nullopt;
case DW_FORM_strp:
return sec_offset_size;
case DW_FORM_udata:
return std::nullopt;
case DW_FORM_ref_addr:
return sec_offset_size;
case DW_FORM_ref1:
return 1;
case DW_FORM_ref2:
return 2;
case DW_FORM_ref4:
return 4;
case DW_FORM_ref8:
return 8;
case DW_FORM_ref_udata:
case DW_FORM_indirect:
return std::nullopt;
case DW_FORM_sec_offset:
return sec_offset_size;
case DW_FORM_exprloc:
return std::nullopt;
case DW_FORM_flag_present:
return 0;
case DW_FORM_strx:
case DW_FORM_addrx:
return std::nullopt;
case DW_FORM_ref_sup4:
return 4;
case DW_FORM_strp_sup:
return sec_offset_size;
case DW_FORM_data16:
return 16;
case DW_FORM_line_strp:
return sec_offset_size;
case DW_FORM_ref_sig8:
return 8;
case DW_FORM_implicit_const:
return 0;
case DW_FORM_loclistx:
case DW_FORM_rnglistx:
return std::nullopt;
case DW_FORM_ref_sup8:
return 8;
case DW_FORM_strx1:
return 1;
case DW_FORM_strx2:
return 2;
case DW_FORM_strx3:
return 3;
case DW_FORM_strx4:
return 4;
case DW_FORM_addrx1:
return 1;
case DW_FORM_addrx2:
return 2;
case DW_FORM_addrx3:
return 3;
case DW_FORM_addrx4:
return 4;
case DW_FORM_GNU_addr_index:
case DW_FORM_GNU_str_index:
case DW_FORM_GNU_ref_alt:
case DW_FORM_GNU_strp_alt:
default:
return std::nullopt;
}
}

View File

@ -7,6 +7,7 @@
// Overview of the format described in
// https://refspecs.linuxfoundation.org/LSB_1.3.0/gLSB/gLSB/ehframehdr.html
namespace torch::unwind {
struct EHFrameHdr {
EHFrameHdr(void* base) : base_(base) {
@ -93,3 +94,5 @@ struct EHFrameHdr {
int64_t fde_count_;
uint32_t table_size_;
};
} // namespace torch::unwind

View File

@ -0,0 +1,108 @@
#pragma once
#include <fmt/format.h>
#include <sys/types.h>
#include <torch/csrc/profiler/unwind/debug_info.h>
#include <torch/csrc/profiler/unwind/line_number_program.h>
#include <torch/csrc/profiler/unwind/sections.h>
#include <torch/csrc/profiler/unwind/unwind.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstddef>
#include <memory>
namespace torch::unwind {
#define UNWIND_WARN(w, ...) \
do { \
w.emplace_back(fmt::format(__VA_ARGS__)); \
LOG_INFO("WARNING: {}\n", w.back()); \
} while (0);
struct FastSymbolizer {
FastSymbolizer() = default;
Frame symbolize(const std::string& library, uint64_t offset) {
LOG_INFO("symbolizing {} + 0x{:x}\n", library, offset);
Frame frame;
frame.funcname = "??";
frame.filename = library;
frame.lineno = offset;
auto s = getOrCreateSections(library);
if (auto e = s->findSubprogramName(offset)) {
frame.funcname = *e;
} else {
UNWIND_WARN(
warnings_,
"failed to find subprogram name for {} 0x{:x}",
library,
offset);
}
if (auto e = findLine(s, offset)) {
frame.filename = e->first;
frame.lineno = e->second;
} else {
UNWIND_WARN(
warnings_, "failed to find file/line for {} 0x{:x}", library, offset);
}
return frame;
}
const std::vector<std::string>& warnings() {
return warnings_;
}
private:
void parseDebugInfo(Sections* s) {
uint64_t offset = 0;
while (offset < s->debug_info.size) {
DebugInfo info(*s);
info.parse(offset);
if (auto lnp_offset = info.lineNumberProgramOffset()) {
for (auto r : info.ranges()) {
s->addDebugInfoRange(r.first, r.second, line_number_programs_.size());
}
line_number_programs_.emplace_back(
std::make_unique<LineNumberProgram>(*s, *lnp_offset));
}
offset = info.nextOffset();
}
}
Sections* getOrCreateSections(const std::string& library) {
auto it = libraries_.find(library);
if (it == libraries_.end()) {
it = libraries_.insert({library, std::make_unique<Sections>()}).first;
try {
Sections* s = it->second.get();
s->parse(library.c_str());
parseDebugInfo(s);
} catch (UnwindError& err) {
UNWIND_WARN(
warnings_, "failed to parse library {}: {}", library, err.what());
}
}
return it->second.get();
}
optional<std::pair<std::string, int64_t>> findLine(
Sections* s,
uint64_t offset) {
if (auto idx = s->findDebugInfoOffset(offset)) {
auto r = line_number_programs_.at(*idx).get();
try {
r->parse();
} catch (UnwindError& err) {
UNWIND_WARN(
warnings_,
"failed to read line number program [{:x}] {}",
r->offset(),
err.what());
}
if (auto e = r->find(offset)) {
return std::make_pair(r->filename(e->file), e->line);
}
}
return std::nullopt;
}
std::unordered_map<std::string, std::unique_ptr<Sections>> libraries_;
std::vector<std::unique_ptr<LineNumberProgram>> line_number_programs_;
std::vector<std::string> warnings_;
};
} // namespace torch::unwind

View File

@ -7,6 +7,8 @@
#include <sstream>
#include <vector>
namespace torch::unwind {
struct TableState {
Action cfa;
std::array<Action, D_REG_SIZE> registers;
@ -398,3 +400,5 @@ struct FDE {
return strstr(augmentation_string_, s) != nullptr;
}
};
} // namespace torch::unwind

View File

@ -1,19 +1,31 @@
#pragma once
#include <stdint.h>
#include <string.h>
#include <cstdint>
#include <cstring>
#include <utility>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
struct Lexer {
Lexer(void* data, void* base = nullptr)
: next_((const char*)data), base_((int64_t)base) {}
namespace torch::unwind {
template <bool checked>
struct LexerImpl {
LexerImpl(void* data, void* base = nullptr, void* end = nullptr)
: next_((const char*)data),
base_((int64_t)base),
end_((const char*)end) {}
template <typename T>
T read() {
T result;
auto end = next_ + sizeof(T);
UNWIND_CHECK(
!checked || end <= end_,
"read out of bounds {} >= {}",
(void*)end,
(void*)end_);
memcpy(&result, next_, sizeof(T));
next_ += sizeof(T);
next_ = end;
return result;
}
@ -21,7 +33,7 @@ struct Lexer {
int64_t readSLEB128() {
int64_t Value = 0;
unsigned Shift = 0;
uint8_t Byte;
uint8_t Byte = 0;
do {
Byte = read<uint8_t>();
uint64_t Slice = Byte & 0x7f;
@ -29,12 +41,12 @@ struct Lexer {
(Shift == 63 && Slice != 0 && Slice != 0x7f)) {
throw UnwindError("sleb128 too big for int64");
}
Value |= Slice << Shift;
Value |= int64_t(Slice << Shift);
Shift += 7;
} while (Byte >= 128);
// Sign extend negative numbers if needed.
if (Shift < 64 && (Byte & 0x40)) {
Value |= (-1ULL) << Shift;
Value |= int64_t((-1ULL) << Shift);
}
return Value;
}
@ -42,7 +54,7 @@ struct Lexer {
uint64_t readULEB128() {
uint64_t Value = 0;
unsigned Shift = 0;
uint8_t p;
uint8_t p = 0;
do {
p = read<uint8_t>();
uint64_t Slice = p & 0x7f;
@ -56,8 +68,17 @@ struct Lexer {
}
const char* readCString() {
auto result = next_;
next_ += strlen(next_) + 1;
return result;
if (!checked) {
next_ += strlen(next_) + 1;
return result;
}
while (next_ < end_) {
if (*next_++ == '\0') {
return result;
}
}
UNWIND_CHECK(
false, "string is out of bounds {} >= {}", (void*)next_, (void*)end_);
}
int64_t readEncoded(uint8_t enc) {
int64_t r = 0;
@ -81,20 +102,27 @@ struct Lexer {
}
return readEncoded(enc);
}
int64_t read4or8Length() {
return readSectionLength().first;
}
std::pair<int64_t, bool> readSectionLength() {
int64_t length = read<uint32_t>();
if (length == 0xFFFFFFFF) {
length = read<int64_t>();
return std::make_pair(read<int64_t>(), true);
}
return length;
return std::make_pair(length, false);
}
void* loc() const {
return (void*)next_;
}
Lexer& skip(int64_t bytes) {
LexerImpl& skip(int64_t bytes) {
next_ += bytes;
return *this;
}
int64_t readEncodedValue(uint8_t enc) {
switch (enc & 0xF) {
case DW_EH_PE_udata2:
@ -121,4 +149,11 @@ struct Lexer {
private:
const char* next_;
int64_t base_;
const char* end_;
};
// using Lexer = LexerImpl<false>;
using CheckedLexer = LexerImpl<true>;
using Lexer = LexerImpl<false>;
} // namespace torch::unwind

View File

@ -0,0 +1,325 @@
#include <c10/util/irange.h>
#include <torch/csrc/profiler/unwind/debug_info.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
#include <torch/csrc/profiler/unwind/lexer.h>
#include <torch/csrc/profiler/unwind/sections.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <tuple>
namespace torch::unwind {
struct LineNumberProgram {
LineNumberProgram(Sections& s, uint64_t offset) : s_(s), offset_(offset) {}
uint64_t offset() {
return offset_;
}
void parse() {
if (parsed_) {
return;
}
parsed_ = true;
CheckedLexer L = s_.debug_line.lexer(offset_);
std::tie(length_, is_64bit_) = L.readSectionLength();
program_end_ = (char*)L.loc() + length_;
auto version = L.read<uint16_t>();
UNWIND_CHECK(
version == 5 || version == 4,
"expected version 4 or 5 but found {}",
version);
if (version == 5) {
auto address_size = L.read<uint8_t>();
UNWIND_CHECK(
address_size == 8,
"expected 64-bit dwarf but found address size {}",
address_size);
segment_selector_size_ = L.read<uint8_t>();
}
header_length_ = is_64bit_ ? L.read<uint64_t>() : L.read<uint32_t>();
program_ = L;
program_.skip(int64_t(header_length_));
minimum_instruction_length_ = L.read<uint8_t>();
maximum_operations_per_instruction_ = L.read<uint8_t>();
default_is_stmt_ = L.read<uint8_t>();
line_base_ = L.read<int8_t>();
line_range_ = L.read<uint8_t>();
opcode_base_ = L.read<uint8_t>();
UNWIND_CHECK(line_range_ != 0, "line_range_ must be non-zero");
standard_opcode_lengths_.resize(opcode_base_);
for (size_t i = 1; i < opcode_base_; i++) {
standard_opcode_lengths_[i] = L.read<uint8_t>();
}
// fmt::print("{:x} {:x} {} {} {} {} {}\n", offset_, header_length_,
// minimum_instruction_length_, maximum_operations_per_instruction_,
// line_base_, line_range_, opcode_base_);
uint8_t directory_entry_format_count = L.read<uint8_t>();
if (version == 5) {
struct Member {
uint64_t content_type;
uint64_t form;
};
std::vector<Member> directory_members;
for (size_t i = 0; i < directory_entry_format_count; i++) {
directory_members.push_back({L.readULEB128(), L.readULEB128()});
}
uint64_t directories_count = L.readULEB128();
for (size_t i = 0; i < directories_count; i++) {
for (auto& member : directory_members) {
switch (member.content_type) {
case DW_LNCT_path: {
include_directories_.emplace_back(
s_.readString(L, member.form, is_64bit_, 0));
} break;
default: {
skipForm(L, member.form);
} break;
}
}
}
for (auto i : c10::irange(directories_count)) {
(void)i;
LOG_INFO("{} {}\n", i, include_directories_[i]);
}
auto file_name_entry_format_count = L.read<uint8_t>();
std::vector<Member> file_members;
for (size_t i = 0; i < file_name_entry_format_count; i++) {
file_members.push_back({L.readULEB128(), L.readULEB128()});
}
auto files_count = L.readULEB128();
for (size_t i = 0; i < files_count; i++) {
for (auto& member : file_members) {
switch (member.content_type) {
case DW_LNCT_path: {
file_names_.emplace_back(
s_.readString(L, member.form, is_64bit_, 0));
} break;
case DW_LNCT_directory_index: {
file_directory_index_.emplace_back(readData(L, member.form));
UNWIND_CHECK(
file_directory_index_.back() < include_directories_.size(),
"directory index out of range");
} break;
default: {
skipForm(L, member.form);
} break;
}
}
}
for (auto i : c10::irange(files_count)) {
(void)i;
LOG_INFO("{} {} {}\n", i, file_names_[i], file_directory_index_[i]);
}
} else {
include_directories_.emplace_back(""); // implicit cwd
while (true) {
auto str = L.readCString();
if (*str == '\0') {
break;
}
include_directories_.emplace_back(str);
}
file_names_.emplace_back("");
file_directory_index_.emplace_back(0);
while (true) {
auto str = L.readCString();
if (*str == '\0') {
break;
}
auto directory_index = L.readULEB128();
L.readULEB128(); // mod_time
L.readULEB128(); // file_length
file_names_.emplace_back(str);
file_directory_index_.push_back(directory_index);
}
}
UNWIND_CHECK(
maximum_operations_per_instruction_ == 1,
"maximum_operations_per_instruction_ must be 1");
UNWIND_CHECK(
minimum_instruction_length_ == 1,
"minimum_instruction_length_ must be 1");
readProgram();
}
struct Entry {
uint32_t file = 1;
int64_t line = 1;
};
unwind::optional<Entry> find(uint64_t address) {
auto e = program_index_.find(address);
if (!e) {
return std::nullopt;
}
return all_programs_.at(*e).find(address);
}
std::string filename(uint64_t index) {
return fmt::format(
"{}/{}",
include_directories_.at(file_directory_index_.at(index)),
file_names_.at(index));
}
private:
void skipForm(CheckedLexer& L, uint64_t form) {
auto sz = formSize(form, is_64bit_ ? 8 : 4);
UNWIND_CHECK(sz, "unsupported form {}", form);
L.skip(int64_t(*sz));
}
uint64_t readData(CheckedLexer& L, uint64_t encoding) {
switch (encoding) {
case DW_FORM_data1:
return L.read<uint8_t>();
case DW_FORM_data2:
return L.read<uint16_t>();
case DW_FORM_data4:
return L.read<uint32_t>();
case DW_FORM_data8:
return L.read<uint64_t>();
case DW_FORM_udata:
return L.readULEB128();
default:
UNWIND_CHECK(false, "unsupported data encoding {}", encoding);
}
}
void produceEntry() {
if (shadow_) {
return;
}
if (ranges_.size() == 1) {
start_address_ = address_;
}
PRINT_LINE_TABLE(
"{:x}\t{}\t{}\n", address_, filename(entry_.file), entry_.line);
UNWIND_CHECK(
entry_.file < file_names_.size(),
"file index {} > {} entries",
entry_.file,
file_names_.size());
ranges_.add(address_, entry_, true);
}
void endSequence() {
if (shadow_) {
return;
}
PRINT_LINE_TABLE(
"{:x}\tEND\n", address_, filename(entry_.file), entry_.line);
program_index_.add(start_address_, all_programs_.size(), false);
program_index_.add(address_, std::nullopt, false);
all_programs_.emplace_back(std::move(ranges_));
ranges_ = RangeTable<Entry>();
}
void readProgram() {
while (program_.loc() < program_end_) {
PRINT_INST("{:x}: ", (char*)program_.loc() - (s_.debug_line.data));
uint8_t op = program_.read<uint8_t>();
if (op >= opcode_base_) {
auto op2 = int64_t(op - opcode_base_);
address_ += op2 / line_range_;
entry_.line += line_base_ + (op2 % line_range_);
PRINT_INST(
"address += {}, line += {}\n",
op2 / line_range_,
line_base_ + (op2 % line_range_));
produceEntry();
} else {
switch (op) {
case DW_LNS_extended_op: {
auto len = program_.readULEB128();
auto extended_op = program_.read<uint8_t>();
switch (extended_op) {
case DW_LNE_end_sequence: {
PRINT_INST("end_sequence\n");
endSequence();
entry_ = Entry{};
} break;
case DW_LNE_set_address: {
address_ = program_.read<uint64_t>();
if (!shadow_) {
PRINT_INST(
"set address {:x} {:x} {:x}\n",
address_,
min_address_,
max_address_);
}
shadow_ = address_ == 0;
} break;
default: {
PRINT_INST("skip extended op {}\n", extended_op);
program_.skip(int64_t(len - 1));
} break;
}
} break;
case DW_LNS_copy: {
PRINT_INST("copy\n");
produceEntry();
} break;
case DW_LNS_advance_pc: {
PRINT_INST("advance pc\n");
address_ += program_.readULEB128();
} break;
case DW_LNS_advance_line: {
entry_.line += program_.readSLEB128();
PRINT_INST("advance line {}\n", entry_.line);
} break;
case DW_LNS_set_file: {
PRINT_INST("set file\n");
entry_.file = program_.readULEB128();
} break;
case DW_LNS_const_add_pc: {
PRINT_INST("const add pc\n");
address_ += (255 - opcode_base_) / line_range_;
} break;
case DW_LNS_fixed_advance_pc: {
PRINT_INST("fixed advance pc\n");
address_ += program_.read<uint16_t>();
} break;
default: {
PRINT_INST("other {}\n", op);
auto n = standard_opcode_lengths_[op];
for (int i = 0; i < n; ++i) {
program_.readULEB128();
}
} break;
}
}
}
PRINT_INST(
"{:x}: end {:x}\n",
((char*)program_.loc() - s_.debug_line.data),
program_end_ - s_.debug_line.data);
}
uint64_t address_ = 0;
bool shadow_ = false;
bool parsed_ = false;
Entry entry_ = {};
std::vector<std::string> include_directories_;
std::vector<std::string> file_names_;
std::vector<uint64_t> file_directory_index_;
uint8_t segment_selector_size_ = 0;
uint8_t minimum_instruction_length_ = 0;
uint8_t maximum_operations_per_instruction_ = 0;
int8_t line_base_ = 0;
uint8_t line_range_ = 0;
uint8_t opcode_base_ = 0;
bool default_is_stmt_ = false;
CheckedLexer program_ = {nullptr};
char* program_end_ = nullptr;
uint64_t header_length_ = 0;
uint64_t length_ = 0;
bool is_64bit_ = false;
std::vector<uint8_t> standard_opcode_lengths_;
Sections& s_;
uint64_t offset_;
uint64_t start_address_ = 0;
RangeTable<uint64_t> program_index_;
std::vector<RangeTable<Entry>> all_programs_;
RangeTable<Entry> ranges_;
};
} // namespace torch::unwind

View File

@ -0,0 +1,150 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <elf.h>
#include <fcntl.h>
#include <fmt/format.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <torch/csrc/profiler/unwind/lexer.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <unistd.h>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>
namespace torch::unwind {
struct Section {
char* data = nullptr;
size_t size = 0;
const char* string(size_t offset) {
return lexer(offset).readCString();
}
CheckedLexer lexer(size_t offset) {
return CheckedLexer(data + offset, data, data + size);
}
};
/// Memory maps a file into the address space read-only, and manages the
/// lifetime of the mapping. Here are a few use cases:
/// 1. Used in the loader to read in initial image, and to inspect
// ELF files for dependencies before callling dlopen.
///
/// 2. Used in unity to load the elf file.
struct MemFile {
explicit MemFile(const char* filename_)
: fd_(open(filename_, O_RDONLY)),
mem_(nullptr),
n_bytes_(0),
name_(filename_) {
UNWIND_CHECK(
fd_ != -1, "failed to open {}: {}", filename_, strerror(errno));
// NOLINTNEXTLINE
struct stat s;
if (-1 == fstat(fd_, &s)) {
close(fd_); // destructors don't run during exceptions
UNWIND_CHECK(false, "failed to stat {}: {}", filename_, strerror(errno));
}
n_bytes_ = s.st_size;
UNWIND_CHECK(
n_bytes_ > sizeof(Elf64_Ehdr), "empty shared library: {}", filename_);
mem_ = (char*)mmap(nullptr, n_bytes_, PROT_READ, MAP_SHARED, fd_, 0);
if (MAP_FAILED == mem_) {
close(fd_);
UNWIND_CHECK(false, "failed to mmap {}: {}", filename_, strerror(errno));
}
ehdr_ = (Elf64_Ehdr*)mem_;
#define ELF_CHECK(cond) UNWIND_CHECK(cond, "not an ELF file: {}", filename_)
ELF_CHECK(ehdr_->e_ident[EI_MAG0] == ELFMAG0);
ELF_CHECK(ehdr_->e_ident[EI_MAG1] == ELFMAG1);
ELF_CHECK(ehdr_->e_ident[EI_MAG2] == ELFMAG2);
ELF_CHECK(ehdr_->e_ident[EI_MAG3] == ELFMAG3);
ELF_CHECK(ehdr_->e_ident[EI_CLASS] == ELFCLASS64);
ELF_CHECK(ehdr_->e_ident[EI_VERSION] == EV_CURRENT);
ELF_CHECK(ehdr_->e_version == EV_CURRENT);
ELF_CHECK(ehdr_->e_machine == EM_X86_64);
#undef ELF_CHECK
UNWIND_CHECK(
ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum <= n_bytes_,
"invalid section header table {} {} {}",
ehdr_->e_shoff + sizeof(Elf64_Shdr) * ehdr_->e_shnum,
n_bytes_,
ehdr_->e_shnum);
shdr_ = (Elf64_Shdr*)(mem_ + ehdr_->e_shoff);
UNWIND_CHECK(
ehdr_->e_shstrndx < ehdr_->e_shnum, "invalid strtab section offset");
auto& strtab_hdr = shdr_[ehdr_->e_shstrndx];
strtab_ = getSection(strtab_hdr);
}
MemFile(const MemFile&) = delete;
MemFile& operator=(const MemFile&) = delete;
[[nodiscard]] const char* data() const {
return (const char*)mem_;
}
/// Returns whether or not the file descriptor
/// of the underlying file is valid.
int valid() {
return fcntl(fd_, F_GETFD) != -1 || errno != EBADF;
}
~MemFile() {
if (mem_) {
munmap((void*)mem_, n_bytes_);
}
if (fd_) {
close(fd_);
}
}
/// Returns the size of the underlying file defined by the `MemFile`
size_t size() {
return n_bytes_;
}
[[nodiscard]] int fd() const {
return fd_;
}
Section getSection(const Elf64_Shdr& shdr) {
UNWIND_CHECK(shdr.sh_offset + shdr.sh_size <= n_bytes_, "invalid section");
return Section{mem_ + shdr.sh_offset, shdr.sh_size};
}
Section getSection(const char* name, bool optional) {
for (int i = 0; i < ehdr_->e_shnum; i++) {
if (strcmp(strtab_.string(shdr_[i].sh_name), name) == 0) {
return getSection(shdr_[i]);
}
}
UNWIND_CHECK(optional, "{} has no section {}", name_, name);
return Section{nullptr, 0};
}
Section strtab() {
return strtab_;
}
private:
template <typename T>
T* load(size_t offset) {
UNWIND_CHECK(offset < n_bytes_, "out of range");
return (T*)(mem_ + offset);
}
int fd_;
char* mem_;
size_t n_bytes_;
std::string name_;
Elf64_Ehdr* ehdr_;
Elf64_Shdr* shdr_;
Section strtab_ = {nullptr, 0};
};
} // namespace torch::unwind

View File

@ -0,0 +1,74 @@
#pragma once
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <algorithm>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>
namespace torch::unwind {
template <typename T>
struct RangeTable {
RangeTable() {
// guarentee that lower_bound[-1] is always valid
addresses_.push_back(0);
payloads_.emplace_back(std::nullopt);
}
void add(uint64_t address, unwind::optional<T> payload, bool sorted) {
if (addresses_.back() > address) {
UNWIND_CHECK(!sorted, "expected addresses to be sorted");
sorted_ = false;
}
addresses_.push_back(address);
payloads_.emplace_back(std::move(payload));
}
unwind::optional<T> find(uint64_t address) {
maybeSort();
auto it = std::upper_bound(addresses_.begin(), addresses_.end(), address);
return payloads_.at(it - addresses_.begin() - 1);
}
void dump() {
for (size_t i = 0; i < addresses_.size(); i++) {
fmt::print("{} {:x}: {}\n", i, addresses_[i], payloads_[i] ? "" : "END");
}
}
size_t size() const {
return addresses_.size();
}
uint64_t back() {
maybeSort();
return addresses_.back();
}
private:
void maybeSort() {
if (sorted_) {
return;
}
std::vector<uint64_t> indices;
indices.reserve(addresses_.size());
for (size_t i = 0; i < addresses_.size(); i++) {
indices.push_back(i);
}
std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) {
return addresses_[a] < addresses_[b] ||
(addresses_[a] == addresses_[b] &&
bool(payloads_[a]) < bool(payloads_[b]));
});
std::vector<uint64_t> addresses;
std::vector<unwind::optional<T>> payloads;
addresses.reserve(addresses_.size());
payloads.reserve(addresses_.size());
for (auto i : indices) {
addresses.push_back(addresses_[i]);
payloads.push_back(payloads_[i]);
}
addresses_ = std::move(addresses);
payloads_ = std::move(payloads);
sorted_ = true;
}
bool sorted_ = true;
std::vector<uint64_t> addresses_;
std::vector<unwind::optional<T>> payloads_;
};
} // namespace torch::unwind

View File

@ -0,0 +1,124 @@
#pragma once
#include <cxxabi.h>
#include <elf.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/dwarf_symbolize_enums.h>
#include <torch/csrc/profiler/unwind/mem_file.h>
#include <torch/csrc/profiler/unwind/range_table.h>
#include <torch/csrc/profiler/unwind/unwind_error.h>
#include <cstdint>
namespace torch::unwind {
static std::string demangle(const std::string& mangled_name) {
int status = 0;
char* realname =
abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status);
if (status == 0) {
std::string demangled_name(realname);
// NOLINTNEXTLINE
free(realname);
return demangled_name;
} else {
return mangled_name;
}
}
struct Sections {
Sections() = default;
void parse(const char* name) {
library_ = std::make_unique<MemFile>(name);
strtab = library_->getSection(".strtab", false);
symtab = library_->getSection(".symtab", true);
debug_info = library_->getSection(".debug_info", true);
if (debug_info.size > 0) {
debug_abbrev = library_->getSection(".debug_abbrev", false);
debug_str = library_->getSection(".debug_str", false);
debug_line = library_->getSection(".debug_line", false);
// dwarf 5
debug_line_str = library_->getSection(".debug_line_str", true);
debug_rnglists = library_->getSection(".debug_rnglists", true);
debug_addr = library_->getSection(".debug_addr", true);
// dwarf 4
debug_ranges = library_->getSection(".debug_ranges", true);
}
parseSymtab();
}
Section debug_info;
Section debug_abbrev;
Section debug_str;
Section debug_line;
Section debug_line_str;
Section debug_rnglists;
Section debug_ranges;
Section debug_addr;
Section symtab;
Section strtab;
const char* readString(
CheckedLexer& data,
uint64_t encoding,
bool is_64bit,
uint64_t str_offsets_base) {
switch (encoding) {
case DW_FORM_string: {
return data.readCString();
}
case DW_FORM_strp: {
return debug_str.string(readSegmentOffset(data, is_64bit));
}
case DW_FORM_line_strp: {
return debug_line_str.string(readSegmentOffset(data, is_64bit));
}
default:
UNWIND_CHECK(false, "unsupported string encoding {:x}", encoding);
}
}
uint64_t readSegmentOffset(CheckedLexer& data, bool is_64bit) {
return is_64bit ? data.read<uint64_t>() : data.read<uint32_t>();
}
unwind::optional<uint64_t> findDebugInfoOffset(uint64_t address) {
return debug_info_offsets_.find(address);
}
size_t compilationUnitCount() {
return debug_info_offsets_.size() / 2;
}
void addDebugInfoRange(
uint64_t start,
uint64_t end,
uint64_t debug_info_offset) {
debug_info_offsets_.add(start, debug_info_offset, false);
debug_info_offsets_.add(end, std::nullopt, false);
}
optional<std::string> findSubprogramName(uint64_t address) {
if (auto e = symbol_table_.find(address)) {
return demangle(strtab.string(*e));
}
return std::nullopt;
}
private:
void parseSymtab() {
auto L = symtab.lexer(0);
char* end = symtab.data + symtab.size;
while (L.loc() < end) {
auto symbol = L.read<Elf64_Sym>();
if (symbol.st_shndx == SHN_UNDEF ||
ELF64_ST_TYPE(symbol.st_info) != STT_FUNC) {
continue;
}
symbol_table_.add(symbol.st_value, symbol.st_name, false);
symbol_table_.add(symbol.st_value + symbol.st_size, std::nullopt, false);
}
}
std::unique_ptr<MemFile> library_;
RangeTable<uint64_t> debug_info_offsets_;
RangeTable<uint64_t> symbol_table_;
};
} // namespace torch::unwind

View File

@ -1,6 +1,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/profiler/unwind/unwind.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <unordered_map>
#if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \
!__has_include("ext/stdio_filebuf.h")
@ -18,7 +19,7 @@ c10::optional<std::pair<std::string, uint64_t>> libraryFor(void* addr) {
}
#ifndef FBCODE_CAFFE2
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
TORCH_CHECK(
false,
"record_context_cpp is not support on non-linux non-x86_64 platforms");
@ -48,10 +49,15 @@ Stats stats() {
#include <torch/csrc/profiler/unwind/communicate.h>
#include <torch/csrc/profiler/unwind/dwarf_enums.h>
#include <torch/csrc/profiler/unwind/eh_frame_hdr.h>
#include <torch/csrc/profiler/unwind/fast_symbolizer.h>
#include <torch/csrc/profiler/unwind/fde.h>
#include <torch/csrc/profiler/unwind/unwinder.h>
#include <shared_mutex>
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
extern "C" void unwind_entry(std::vector<void*>* result);
namespace torch::unwind {
struct UpgradeExclusive {
UpgradeExclusive(std::shared_lock<std::shared_timed_mutex>& rdlock)
: rdlock_(rdlock) {
@ -197,7 +203,7 @@ struct UnwindCache {
Unwinder unwinder = Unwinder::unknown();
try {
unwinder = libraryFor(addr).unwinderFor(addr);
} catch (UnwindError& err) {
} catch (unwind::UnwindError& err) {
// because unwinders are cached this will only print
// once per frame that cannot be unwound.
TORCH_WARN("Unsupported unwinding pattern: ", err.what());
@ -276,46 +282,6 @@ struct UnwindCache {
static UnwindCache unwind_cache;
static std::shared_timed_mutex cache_mutex_;
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
std::shared_lock lock(cache_mutex_);
UnwindState state{};
// NOLINTNEXTLINE(performance-no-int-to-ptr)
state.rip = *(int64_t*)(rsp);
// +8 because we saved rsp after the return address was already pushed
// to the stack
state.rsp = rsp + 8;
state.rbp = rbp;
unwind_cache.checkRefresh(lock);
while (true) { // unwind for _start sets rip as being undefined
// NOLINTNEXTLINE(performance-no-int-to-ptr)
result->push_back((void*)state.rip);
const Unwinder& uw = unwind_cache.unwinderFor(state.rip, lock);
if (uw.terminator()) {
if (uw.isUnknown()) {
result->push_back(nullptr);
}
break;
}
state = uw.run(state);
}
}
extern "C" void unwind_entry(std::vector<void*>* result);
// calling convention puts the first three pointer/int64_t arguments in
// rdi rsi rdx (all caller-saved)
// rdi already holds the pointer to the result vector
// we add arguments for current rsp and rbp and then tail call
// into unwind_c
__asm__(
".global unwind_entry\n"
"unwind_entry:\n"
"mov %rsp, %rsi;\n"
"mov %rbp, %rdx;\n"
"jmp unwind_c;\n");
namespace torch::unwind {
std::vector<void*> unwind() {
std::vector<void*> frames;
unwind_entry(&frames);
@ -335,6 +301,15 @@ c10::optional<std::pair<std::string, uint64_t>> libraryFor(void* addr) {
library_info->name(), (uint64_t)addr - library_info->load_bias());
}
static std::string dladdr_lookup(void* addr) {
Dl_info dlinfo;
std::string funcname = "??";
if (dladdr(addr, &dlinfo) && dlinfo.dli_sname) {
funcname = demangle(dlinfo.dli_sname);
}
return funcname;
}
struct Symbolizer {
Symbolizer() {
auto envar = std::getenv("TORCH_ADDR2LINE_BINARY");
@ -345,9 +320,6 @@ struct Symbolizer {
} else {
addr2line_binary_ = "addr2line"; // default
}
if (torch::get_disable_addr2line()) {
addr2line_binary_ = nullptr;
}
}
static std::lock_guard<std::mutex> guard() {
static std::mutex mutex;
@ -367,16 +339,6 @@ struct Symbolizer {
frame_map_[addr] = Frame{"??", "<unwind unsupported>", 0};
return;
}
if (addr2line_binary_ == nullptr) {
Dl_info dlinfo;
std::string funcname = "??";
if (dladdr(addr, &dlinfo) && dlinfo.dli_sname) {
funcname = demangle(dlinfo.dli_sname);
}
frame_map_[addr] = Frame{
maybe_library->first, std::move(funcname), maybe_library->second - 1};
return;
}
has_pending_results_ = true;
auto& entry = getOrCreate(maybe_library->first);
entry.queried.push_back(addr);
@ -448,23 +410,59 @@ struct Symbolizer {
frame_map_[e.queried[e.completed]] = std::move(frame);
}
}
std::string demangle(const std::string& mangled_name) {
int status = 0;
char* realname =
abi::__cxa_demangle(mangled_name.c_str(), nullptr, nullptr, &status);
if (status == 0) {
std::string demangled_name(realname);
// NOLINTNEXTLINE
free(realname);
return demangled_name;
} else {
return mangled_name;
}
}
};
#ifndef FBCODE_CAFFE2
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
static std::vector<Frame> symbolize_fast(
const std::vector<void*>& frames,
Mode mode) {
static std::mutex cache_mutex;
static std::array<ska::flat_hash_map<void*, Frame>, 2> frame_maps;
auto& frame_map = frame_maps[mode == Mode::fast ? 0 : 1];
std::vector<uint32_t> indices_to_lookup;
std::vector<Frame> results;
results.reserve(frames.size());
{
std::lock_guard<std::mutex> lock(cache_mutex);
for (auto i : c10::irange(frames.size())) {
void* f = frames.at(i);
auto it = frame_map.find(f);
if (it == frame_map.end()) {
indices_to_lookup.push_back(i);
results.emplace_back(Frame{"??", "??", 0});
} else {
results.emplace_back(it->second);
}
}
}
if (!indices_to_lookup.empty()) {
// do symbolizer work
FastSymbolizer symbolizer;
for (auto i : indices_to_lookup) {
void* addr = frames.at(i);
Frame& f = results.at(i);
auto library = libraryFor(frames.at(i));
if (library) {
if (mode == Mode::fast) {
f = symbolizer.symbolize(library->first, library->second - 1);
} else {
f = Frame{library->first, "??", library->second - 1};
}
}
if (f.funcname == "??") {
f.funcname = dladdr_lookup(addr);
}
}
std::lock_guard<std::mutex> lock(cache_mutex);
for (auto i : indices_to_lookup) {
frame_map.emplace(frames.at(i), results.at(i));
}
}
return results;
}
static std::vector<Frame> symbolize_addr2line(
const std::vector<void*>& frames) {
auto guard = Symbolizer::guard();
Symbolizer& s = Symbolizer::get();
for (auto f : frames) {
@ -477,6 +475,16 @@ std::vector<Frame> symbolize(const std::vector<void*>& frames) {
}
return results;
}
// fbcode will use llvm symbolize since there is an llvm dependency already
#ifndef FBCODE_CAFFE2
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
if (mode == Mode::addr2line) {
return symbolize_addr2line(frames);
} else {
return symbolize_fast(frames, mode);
}
}
#endif
Stats stats() {
@ -484,4 +492,42 @@ Stats stats() {
}
} // namespace torch::unwind
extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
std::shared_lock lock(torch::unwind::cache_mutex_);
torch::unwind::UnwindState state{};
// NOLINTNEXTLINE(performance-no-int-to-ptr)
state.rip = *(int64_t*)(rsp);
// +8 because we saved rsp after the return address was already pushed
// to the stack
state.rsp = rsp + 8;
state.rbp = rbp;
torch::unwind::unwind_cache.checkRefresh(lock);
while (true) { // unwind for _start sets rip as being undefined
// NOLINTNEXTLINE(performance-no-int-to-ptr)
result->push_back((void*)state.rip);
const torch::unwind::Unwinder& uw =
torch::unwind::unwind_cache.unwinderFor(state.rip, lock);
if (uw.terminator()) {
if (uw.isUnknown()) {
result->push_back(nullptr);
}
break;
}
state = uw.run(state);
}
}
// calling convention puts the first three pointer/int64_t arguments in
// rdi rsi rdx (all caller-saved)
// rdi already holds the pointer to the result vector
// we add arguments for current rsp and rbp and then tail call
// into unwind_c
__asm__(
".global unwind_entry\n"
"unwind_entry:\n"
"mov %rsp, %rsi;\n"
"mov %rbp, %rdx;\n"
"jmp unwind_c;\n");
#endif

View File

@ -1,11 +1,11 @@
#pragma once
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <cstdint>
#include <string>
#include <vector>
namespace torch {
namespace unwind {
namespace torch::unwind {
// gather current stack, relatively fast.
// gets faster once the cache of program counter locations is warm.
TORCH_API std::vector<void*> unwind();
@ -16,13 +16,17 @@ struct Frame {
uint64_t lineno;
};
enum class Mode { addr2line, fast, dladdr };
// note: symbolize is really slow
// it will launch an addr2line process that has to parse dwarf
// information from the libraries that frames point into.
// Callers should first batch up all the unique void* pointers
// across a number of unwind states and make a single call to
// symbolize.
TORCH_API std::vector<Frame> symbolize(const std::vector<void*>& frames);
TORCH_API std::vector<Frame> symbolize(
const std::vector<void*>& frames,
Mode mode);
// returns path to the library, and the offset of the addr inside the library
TORCH_API c10::optional<std::pair<std::string, uint64_t>> libraryFor(
@ -36,5 +40,4 @@ struct Stats {
};
Stats stats();
} // namespace unwind
} // namespace torch
} // namespace torch::unwind

View File

@ -1,6 +1,31 @@
#pragma once
#include <c10/util/Optional.h>
#include <fmt/format.h>
#include <stdexcept>
namespace torch::unwind {
struct UnwindError : public std::runtime_error {
using std::runtime_error::runtime_error;
};
#define UNWIND_CHECK(cond, fmtstring, ...) \
do { \
if (!(cond)) { \
throw unwind::UnwindError(fmt::format( \
"{}:{}: " fmtstring, __FILE__, __LINE__, ##__VA_ARGS__)); \
} \
} while (0)
// #define LOG_INFO(...) fmt::print(__VA_ARGS__)
#define LOG_INFO(...)
// #define PRINT_INST(...) LOG_INFO(__VA_ARGS__)
#define PRINT_INST(...)
// #define PRINT_LINE_TABLE(...) LOG_INFO(__VA_ARGS__)
#define PRINT_LINE_TABLE(...)
using c10::optional; // NOLINT
} // namespace torch::unwind

View File

@ -6,9 +6,9 @@
#include <torch/csrc/profiler/unwind/unwind.h>
namespace torch {
namespace unwind {
namespace torch::unwind {
std::vector<Frame> symbolize(const std::vector<void*>& frames) {
std::vector<Frame> symbolize(const std::vector<void*>& frames, Mode mode) {
static std::mutex symbolize_mutex;
static llvm::symbolize::LLVMSymbolizer symbolizer;
static ska::flat_hash_map<void*, Frame> frame_map_;
@ -38,7 +38,7 @@ std::vector<Frame> symbolize(const std::vector<void*>& frames) {
return results;
}
} // namespace unwind
} // namespace torch::unwind
} // namespace torch
#endif

View File

@ -4,6 +4,8 @@
#include <cstdint>
#include <limits>
namespace torch::unwind {
struct UnwindState {
int64_t rip, rbp, rsp;
};
@ -75,3 +77,5 @@ struct Unwinder {
int64_t rbp_off_;
bool deref_{false};
};
} // namespace torch::unwind

View File

@ -47,9 +47,31 @@ bool get_cpp_stacktraces_enabled() {
return enabled;
}
bool get_disable_addr2line() {
static bool disabled = compute_disable_addr2line();
return disabled;
static torch::unwind::Mode compute_symbolize_mode() {
auto envar_c = std::getenv("TORCH_SYMBOLIZE_MODE");
if (envar_c) {
std::string envar = envar_c;
if (envar == "dladdr") {
return unwind::Mode::dladdr;
} else if (envar == "addr2line") {
return unwind::Mode::addr2line;
} else if (envar == "fast") {
return unwind::Mode::fast;
} else {
TORCH_CHECK(
false,
"expected {dladdr, addr2line, fast} for TORCH_SYMBOLIZE_MODE, got ",
envar);
}
} else {
return compute_disable_addr2line() ? unwind::Mode::dladdr
: unwind::Mode::addr2line;
}
}
unwind::Mode get_symbolize_mode() {
static unwind::Mode mode = compute_symbolize_mode();
return mode;
}
} // namespace torch

View File

@ -1,8 +1,9 @@
#pragma once
#include <torch/csrc/Export.h>
#include <torch/csrc/profiler/unwind/unwind.h>
namespace torch {
TORCH_API bool get_cpp_stacktraces_enabled();
TORCH_API bool get_disable_addr2line();
TORCH_API torch::unwind::Mode get_symbolize_mode();
} // namespace torch