mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73368 The debug_pkl file inside PyTorch's .pt file consists of a list of SourceRanges. Each SourceRange points to a Source, which is a stack trace, a filename, and start and end numbers. Those are emitted in the debug_pkl file as strings. Since many SourceRanges share the same source, the string for the trace can be deduplicated. The newer format saves a set of unique traces in a tuple, and each SourceRange then saves the offset of its trace, i.e. its position in that tuple (manually applying dictionary compression). This helps reduce file size. On loading, if we copied each trace into a Source as a string, runtime memory would still blow up. To mitigate this, we use SourceView directly instead of Source; it takes a reference to the string held inside the Deserializer and wraps it in a string_view. This is safe because the Deserializer is held by the Unpickler via shared_ptr, and the Unpickler is in turn held via shared_ptr by another Source object. That Source object stays alive during model construction. Test Plan: unit test Took original file (312271638_930.predictor.disagg.local); loaded it with `torch.jit.load` and saved it again with `torch.jit.save`. 
Unzip both, look at contents: ``` [qihan@devvm5585.vll0 ~]$ du archive -h 4.0K archive/xl_model_weights 3.7M archive/extra 8.0K archive/code/__torch__/caffe2/torch/fb/model_transform/splitting 8.0K archive/code/__torch__/caffe2/torch/fb/model_transform 8.0K archive/code/__torch__/caffe2/torch/fb 8.0K archive/code/__torch__/caffe2/torch 8.0K archive/code/__torch__/caffe2 20M archive/code/__torch__/torch/fx/graph_module 20M archive/code/__torch__/torch/fx 8.0K archive/code/__torch__/torch/classes 20M archive/code/__torch__/torch 20M archive/code/__torch__ 20M archive/code 2.7M archive/constants 35M archive [qihan@devvm5585.vll0 ~]$ du resaved -h 4.0K resaved/extra 8.0K resaved/code/__torch__/caffe2/torch/fb/model_transform/splitting 8.0K resaved/code/__torch__/caffe2/torch/fb/model_transform 8.0K resaved/code/__torch__/caffe2/torch/fb 8.0K resaved/code/__torch__/caffe2/torch 8.0K resaved/code/__torch__/caffe2 1.3M resaved/code/__torch__/torch/fx/graph_module 1.3M resaved/code/__torch__/torch/fx 8.0K resaved/code/__torch__/torch/classes 1.4M resaved/code/__torch__/torch 1.4M resaved/code/__torch__ 1.4M resaved/code 2.7M resaved/constants 13M resaved [qihan@devvm5585.vll0 ~]$ ``` Reviewed By: gmagogsfm Differential Revision: D34455360 fbshipit-source-id: 8cc716f9bba7183746b1b4ecc33a2de34ac503b9 (cherry picked from commit f1a04730fc9ac8fdab6c8e4c44cb5529e42090e4)
335 lines
10 KiB
C++
335 lines
10 KiB
C++
#include <c10/util/irange.h>
|
|
#include <torch/csrc/jit/frontend/source_range.h>
|
|
#include <torch/csrc/jit/serialization/source_range_serialization.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// A string-like class backed by a vector of string_views.
// The string it represents is logically the concatenation of those
// string_views, which has the advantage of not requiring contiguous memory.
|
|
StringCordView::StringCordView() {
|
|
accumulated_sizes_.push_back(0);
|
|
}
|
|
|
|
// Builds a cord whose logical text is the concatenation of `inputs`.
// `ownerships` is stored alongside the views (presumably the backing storage
// some of them point into) and is shared with any cord produced by substr().
//
// accumulated_sizes_ gets exactly one entry per piece plus a leading 0:
// accumulated_sizes_[i] is the offset at which pieces_[i] begins, and the last
// entry is the sum of all piece lengths. An entry is recorded for EVERY piece,
// including empty ones, so that an index into accumulated_sizes_ is always the
// matching index into pieces_ (iter_for_pos relies on that). Empty pieces just
// produce repeated offsets, which std::upper_bound in iter_for_pos skips over.
StringCordView::StringCordView(
    std::vector<c10::string_view> inputs,
    std::vector<std::shared_ptr<std::string>> ownerships)
    : pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
  accumulated_sizes_.reserve(pieces_.size() + 1);
  accumulated_sizes_.push_back(0);
  size_t running_sum = 0;
  for (const auto& piece : pieces_) {
    running_sum += piece.size();
    accumulated_sizes_.push_back(running_sum);
  }
}
|
|
|
|
size_t StringCordView::find(const std::string& tok, size_t start) const {
|
|
if (tok.size() == 0) {
|
|
return 0;
|
|
}
|
|
|
|
if ((size() - start) < tok.size()) {
|
|
return std::string::npos;
|
|
}
|
|
|
|
Iterator begin = iter_for_pos(start);
|
|
Iterator end_iter = end();
|
|
size_t offset = start;
|
|
for (; begin != end_iter; ++begin, ++offset) {
|
|
if (*begin == tok[0]) {
|
|
auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
|
|
if (mis.second == tok.end()) {
|
|
// no mismatch, and second string (tok) is exhausted.
|
|
return offset;
|
|
}
|
|
if (mis.first == end_iter) {
|
|
// this str is exhausted but tok is not
|
|
return std::string::npos;
|
|
}
|
|
}
|
|
}
|
|
return std::string::npos;
|
|
}
|
|
|
|
// Returns a cord viewing characters [start, start + size) of this one; `size`
// is clamped to the remaining length. The result shares this cord's owned
// strings so its views stay valid. An out-of-range `start` yields an empty cord.
StringCordView StringCordView::substr(size_t start, size_t size) const {
  const size_t total = this->size();
  if (start >= total) {
    // out of bounds
    return StringCordView();
  }
  // Clamp against `total - start` (safe: start < total) rather than testing
  // `start + size`, which wraps for large sizes such as std::string::npos.
  if (size > total - start) {
    size = total - start;
  }

  Iterator first = iter_for_pos(start);
  Iterator last = iter_for_pos(start + size);

  std::vector<c10::string_view> pieces;
  if (first.line_ == last.line_) {
    // The whole substring lives inside a single piece.
    pieces.push_back(pieces_[first.line_].substr(first.pos_, size));
  } else {
    // Tail of the first piece.
    pieces.push_back(pieces_[first.line_].substr(first.pos_));

    const bool ends_inside = last != this->end();
    size_t last_full = pieces_.size();
    if (ends_inside && last.line_ < last_full) {
      last_full = last.line_;
    }
    // Every piece strictly between the first and the last one, whole.
    for (size_t i = first.line_ + 1; i < last_full; ++i) {
      pieces.push_back(pieces_[i]);
    }
    // Head of the last piece; skipped when the cut falls exactly on a piece
    // boundary, since it would only add an empty view.
    if (ends_inside && last.pos_ > 0) {
      pieces.push_back(pieces_[last.line_].substr(0, last.pos_));
    }
  }

  // share ownership (the constructor takes its argument by value, so this
  // copies owned_strings_)
  return StringCordView(std::move(pieces), owned_strings_);
}
|
|
|
|
// Character-wise equality with a std::string. The lengths are compared first;
// only when they agree are both sequences walked in lockstep.
bool StringCordView::operator==(const std::string& rhs) {
  if (size() == rhs.size()) {
    auto diff = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
    // Equal only when both sides were consumed without a difference.
    return diff.first == end() && diff.second == rhs.end();
  }
  return false;
}
|
|
|
|
// Character-wise equality with another cord. Piece boundaries do not matter,
// only the concatenated text. Lengths are compared before characters.
bool StringCordView::operator==(const StringCordView& rhs) {
  if (size() == rhs.size()) {
    auto diff = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
    // Equal only when both sides were consumed without a difference.
    return diff.first == end() && diff.second == rhs.end();
  }
  return false;
}
|
|
|
|
// Returns an iterator positioned at logical character `pos`. Position 0 maps
// to begin(), and any position at or past size() maps to end().
StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
  if (pos == 0) {
    return begin();
  }
  if (pos >= size()) {
    return end();
  }
  // The first starting offset strictly greater than `pos` belongs to the piece
  // after the one containing `pos`, so the containing piece is one earlier.
  const auto next_start = std::upper_bound(
      accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
  if (next_start == accumulated_sizes_.end()) {
    return end();
  }
  const size_t piece_index =
      static_cast<size_t>(next_start - accumulated_sizes_.begin()) - 1;
  const size_t piece_start = accumulated_sizes_[piece_index];
  assert(piece_start <= pos);
  assert(*next_start > pos);
  return Iterator(this, piece_index, pos - piece_start, size() - pos);
}
|
|
|
|
// Hashes a SourceRange by the identity of its source object and its
// [start, end) offsets.
//
// The components are mixed boost::hash_combine-style rather than XOR-ed.
// std::hash on integers is the identity in common implementations (e.g.
// libstdc++), so plain XOR made distinct ranges in the same source collide
// trivially. For example, [0, 3) and [1, 2) both hashed to pointer ^ 3.
size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
  size_t seed =
      std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get()));
  auto combine = [&seed](size_t value) {
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  };
  combine(std::hash<size_t>()(key.start()));
  combine(std::hash<size_t>()(key.end()));
  return seed;
}
|
|
|
|
// Maps `range` back to the range that generated it by delegating to
// gen_ranges_. Returns nullopt when this Source has no gen_ranges_.
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
    const SourceRange& range) {
  if (gen_ranges_) {
    return gen_ranges_->findSourceRangeThatGenerated(range);
  }
  return c10::nullopt;
}
|
|
|
|
// Writes this range, highlighted, with SourceRange::CONTEXT as the context
// argument. If a generating range exists, that range is highlighted first
// (recursively) and this one is then labelled "Serialized ".
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
  auto generating_range = findSourceRangeThatGenerated();
  if (generating_range) {
    generating_range->highlight(out);
    out << "Serialized ";
  }
  print_with_context(out, CONTEXT, true, "");
}
|
|
|
|
// Writes a TorchScript traceback for `entries` to `out`.
//
// The first traceback prints every frame's own range. If at least one frame has
// a generating range, a second "original code" traceback follows. It prints the
// generating range where one exists and falls back to the frame's own range
// otherwise, so both tracebacks list the same number of frames.
C10_EXPORT void format_stack_trace(
    std::ostream& out,
    const std::vector<StackEntry>& entries) {
  std::vector<SourceRange> orig_ranges;
  bool has_orig_ranges = false;
  for (const auto& entry : entries) {
    auto generating = entry.range.findSourceRangeThatGenerated();
    if (generating) {
      has_orig_ranges = true;
      orig_ranges.emplace_back(std::move(*generating));
    } else {
      orig_ranges.emplace_back(entry.range);
    }
  }

  out << "Traceback of TorchScript";
  if (has_orig_ranges) {
    out << ", serialized code";
  }
  out << " (most recent call last):\n";
  for (const auto& entry : entries) {
    entry.range.print_with_context(
        out, SourceRange::CONTEXT, true, entry.filename);
  }

  if (!has_orig_ranges) {
    return;
  }
  out << "\nTraceback of TorchScript, original code (most recent call last):\n";
  // orig_ranges[i] corresponds to entries[i].
  for (size_t i = 0; i < orig_ranges.size(); ++i) {
    orig_ranges[i].print_with_context(
        out, SourceRange::CONTEXT, true, entries[i].filename);
  }
}
|
|
|
|
// Writes this range to `out`, together with up to `context` lines of
// surrounding text as counted by the loops below. The output is preceded by a
// "File ..., line ..." header when file_line_col() is available.
//
// When `highlight` is true, each line of the range is printed followed by a
// marker line of '~' under the range. The last marker line ends in
// " <--- HERE". `funcname`, if non-empty, is appended to the header as
// ", in <funcname>".
C10_EXPORT void SourceRange::print_with_context(
    std::ostream& out,
    size_t context,
    bool highlight,
    const std::string& funcname) const {
  // This is an empty SourceRange, used as a sentinel value.
  if (!source_view_) {
    return;
  }

  // Materialize the source text once so it can be indexed directly below.
  auto str = source_view_->text_str().str();
  if (size() == str.size()) {
    // this is just the entire file, not a subset, so print it out.
    // primarily used to print out python stack traces
    out << str;
    return;
  }

  // Some ranges extend past the length of the source, so clamp both ends to
  // str.size(). Clamping the start too keeps the backwards line scan below from
  // indexing past the end of `str`. For in-bounds ranges, range_start == start().
  const size_t range_end = end() > str.size() ? str.size() : end();
  const size_t range_start = start() > range_end ? range_end : start();

  // Widen [range_start, range_end) to whole lines. begin_line is the first
  // character of the first highlighted line. end_line is the '\n' (or the end
  // of the text) that terminates the last one.
  size_t begin_line = range_start;
  size_t end_line = range_end;
  while (begin_line > 0 && str[begin_line - 1] != '\n') {
    --begin_line;
  }
  while (end_line < str.size() && str[end_line] != '\n') {
    ++end_line;
  }
  AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
  AT_ASSERT(end_line == str.size() || str[end_line] == '\n');

  // Walk backwards from begin_line, counting newlines, until `context` of them
  // have been seen. That is where the leading context starts.
  size_t begin_context = begin_line;
  for (size_t i = 0; begin_context > 0; --begin_context) {
    if (str[begin_context - 1] == '\n') {
      ++i;
    }
    if (i >= context) {
      break;
    }
  }
  AT_ASSERT(begin_context == 0 || str[begin_context - 1] == '\n');

  // Likewise walk forwards from end_line to find where trailing context ends.
  size_t end_context = end_line;
  for (size_t i = 0; end_context < str.size(); ++end_context) {
    if (str[end_context] == '\n') {
      ++i;
    }
    if (i >= context) {
      break;
    }
  }
  AT_ASSERT(end_context == str.size() || str[end_context] == '\n');

  // print out location information
  if (auto flc = file_line_col()) {
    std::string filename;
    size_t line = 0;
    std::tie(filename, line, std::ignore) = *flc;
    out << " File \"" << filename << "\", line " << line;
    if (!funcname.empty()) {
      out << ", in " << funcname;
    }
    out << "\n";
  }

  // print out initial context
  out << str.substr(begin_context, range_start - begin_context);
  size_t line_start = range_start;
  size_t line_end = range_end;
  if (highlight) {
    line_end = range_start;
    while (line_start < range_end) {
      // move line_end to end of line
      while (line_end < str.size() && str[line_end] != '\n') {
        ++line_end;
      }
      // print line of code
      auto actual_line = str.substr(line_start, (line_end - line_start) + 1);
      out << actual_line;
      if (actual_line.back() != '\n') {
        out << "\n";
      }

      size_t empty_space = 0;
      size_t highlight_space = 0;
      size_t highlight_begin = line_start;
      size_t highlight_end = line_start;
      // determine length of line which is being highlighted
      while (highlight_begin > 0 && str[highlight_begin - 1] != '\n') {
        --highlight_begin;
      }
      while (highlight_end < range_end && str[highlight_end] != '\n') {
        ++highlight_end;
      }
      AT_ASSERT(highlight_begin == 0 || str[highlight_begin - 1] == '\n');
      AT_ASSERT(highlight_end == range_end || str[highlight_end] == '\n');
      // determine amount of empty space vs highlighted space
      for (const auto i : c10::irange(highlight_begin, highlight_end)) {
        if (str[i] == ' ' || i < range_start) {
          empty_space++;
        } else {
          break;
        }
      }
      highlight_space = highlight_end - highlight_begin - empty_space;
      if (highlight_space > 0) {
        // Some ranges are off and include empty white space on new lines which
        // don't need to be printed. Only defer the HERE marker if something
        // other than whitespace remains *inside* the range. The bound is
        // exclusive: str[range_end] is the first character after the range.
        bool more_lines = false;
        for (size_t i = line_end; i < range_end; i++) {
          if (str[i] != '\n' && str[i] != ' ') {
            more_lines = true;
          }
        }
        out << std::string(empty_space, ' ');
        out << std::string(highlight_space, '~');
        out << (more_lines && line_end != range_end ? "\n" : " <--- HERE\n");
      }
      ++line_end;
      line_start = line_end;
    }
  } else {
    // print out code with no highlight
    out << str.substr(range_start, range_end - range_start);
  }

  // print out ending context.
  // line_end can already be past end_context: with context <= 1, end_context
  // equals end_line while the highlight loop leaves line_end one past it, and
  // when the last line has no trailing newline, line_end is str.size() + 1.
  // In those cases there is nothing left to print. Checking first also stops
  // `end_context - line_end` from wrapping around and printing the rest of
  // the file.
  if (line_end < end_context) {
    auto line_substr = str.substr(line_end, end_context - line_end);
    out << line_substr;
    if (line_substr.back() != '\n') {
      out << "\n";
    }
  }
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|