Revert D33994011: Make debug_pkl smaller by only emitting unique traces.

Test Plan: revert-hammer

Differential Revision:
D33994011 (3d37f5b052)

Original commit changeset: 8e6224c6e942

Original Phabricator Diff: D33994011 (3d37f5b052)

fbshipit-source-id: 885e739efa1081382e1fcf9c6cccba92c57e9f7a
(cherry picked from commit a6d98c85a736c2eb321a6f38005dd0f5dc43eb87)
This commit is contained in:
Alban Desmaison
2022-02-24 07:47:48 -08:00
committed by PyTorch MergeBot
parent 5772b1afbc
commit 3bd1507ff2
24 changed files with 203 additions and 576 deletions

View File

@ -1,32 +0,0 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/frontend/source_range.h>
using namespace ::testing;
using namespace ::torch::jit;
TEST(SourceRangeTest, test_find) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));
std::vector<c10::string_view> pieces{*strings[0], *strings[1]};
StringCordView view(pieces, strings);
auto x = view.find("rldni", 0);
EXPECT_EQ(x, 8);
}
TEST(SourceRangeTest, test_substr) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));
std::vector<c10::string_view> pieces{*strings[0], *strings[1]};
StringCordView view(pieces, strings);
auto x = view.substr(4, 10).str();
EXPECT_EQ(x, view.str().substr(4, 10));
EXPECT_EQ(view.substr(0, view.size()).str(), view.str());
}

View File

@ -38,18 +38,16 @@ static inline void trim(std::string& s) {
}
} // namespace
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
try { \
(void)statement; \
FAIL(); \
} catch (const std::exception& e) { \
std::string substring_s(substring); \
trim(substring_s); \
auto exception_string = std::string(e.what()); \
trim(exception_string); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
<< " Error was: \n" \
<< exception_string; \
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
try { \
(void)statement; \
FAIL(); \
} catch (const std::exception& e) { \
std::string substring_s(substring); \
trim(substring_s); \
auto exception_string = std::string(e.what()); \
trim(exception_string); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos); \
}
namespace torch {

View File

@ -4453,8 +4453,7 @@ def foo(xyz):
return list(debug_files)
debug_files = debug_records_from_mod(ft3)
for dfile in debug_files:
_, table, debug_file = dfile
for debug_file in debug_files:
for i in range(len(debug_file) - 1):
offset, source_range_tag, source_range = debug_file[i]
offset2, source_range_tag2, source_range2 = debug_file[i + 1]

View File

@ -2,7 +2,6 @@
#include <ATen/core/Reduction.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Optional.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
@ -28,13 +27,8 @@ namespace jit {
namespace {
struct SchemaParser {
explicit SchemaParser(const std::string& str)
: L(std::make_shared<Source>(
c10::string_view(str),
c10::nullopt,
0,
nullptr,
Source::DONT_COPY)),
SchemaParser(const std::string& str)
: L(std::make_shared<SourceView>(c10::string_view(str))),
type_parser(L, /*parse_complete_tensor_types*/ false) {}
either<OperatorName, FunctionSchema> parseDeclaration() {

View File

@ -190,7 +190,7 @@ struct TORCH_API SharedParserData {
// find the longest match of str.substring(pos) against a token, return true
// if successful filling in kind, start,and len
bool match(
StringCordView str,
c10::string_view str,
size_t pos,
bool continuation, // are we inside a scope where newlines don't count
// (e.g. inside parens)
@ -241,12 +241,12 @@ struct TORCH_API SharedParserData {
// invariant: the next token is not whitespace or newline
*start = pos;
// check for a valid number
if (isNumber(str.piece(0), pos, len)) {
if (isNumber(str, pos, len)) {
*kind = TK_NUMBER;
return true;
}
// check for string
if (isString(str.piece(0), pos, len)) {
if (isString(str, pos, len)) {
*kind = TK_STRINGLITERAL;
return true;
}
@ -368,7 +368,7 @@ struct TORCH_API SharedParserData {
return isspace(n) && n != '\n';
}
// Make an exception ignoring comments for type annotation comments
bool isTypeComment(StringCordView str, size_t pos) {
bool isTypeComment(c10::string_view str, size_t pos) {
const std::string type_string = "# type:";
if (str.size() < pos + type_string.length()) {
return false;
@ -387,7 +387,7 @@ struct Token {
SourceRange range;
Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
std::string text() {
return range.text().str();
return range.text();
}
std::string kindString() const {
return kindToString(kind);
@ -395,7 +395,7 @@ struct Token {
};
struct Lexer {
explicit Lexer(std::shared_ptr<Source> source)
explicit Lexer(std::shared_ptr<SourceView> source)
: source(std::move(source)),
pos(0),
nesting(0),
@ -518,19 +518,25 @@ struct Lexer {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t length;
AT_ASSERT(source);
auto src = source->text_str();
if (!shared.match(
src, pos, nesting > 0, whitespace_token, &kind, &start, &length)) {
source->text(),
pos,
nesting > 0,
whitespace_token,
&kind,
&start,
&length)) {
expected(
"a valid token",
Token(source->char_at(start), SourceRange(source, start, start + 1)));
Token(
(source->text())[start], SourceRange(source, start, start + 1)));
}
auto t = Token(kind, SourceRange(source, start, start + length));
pos = start + length;
return t;
}
std::shared_ptr<Source> source;
std::shared_ptr<SourceView> source;
size_t pos;
size_t nesting; // depth of ( [ { nesting...
std::vector<int> indent_stack; // stack of indentation level of blocks

View File

@ -46,7 +46,7 @@ Decl mergeTypesFromTypeComment(
}
struct ParserImpl {
explicit ParserImpl(const std::shared_ptr<Source>& source)
explicit ParserImpl(const std::shared_ptr<SourceView>& source)
: L(source), shared(sharedParserData()) {}
Ident parseIdent() {
@ -801,7 +801,7 @@ struct ParserImpl {
SharedParserData& shared;
};
Parser::Parser(const std::shared_ptr<Source>& src)
Parser::Parser(const std::shared_ptr<SourceView>& src)
: pImpl(new ParserImpl(src)) {}
Parser::~Parser() = default;

View File

@ -17,7 +17,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
bool is_method);
struct TORCH_API Parser {
explicit Parser(const std::shared_ptr<Source>& src);
explicit Parser(const std::shared_ptr<SourceView>& src);
TreeRef parseFunction(bool is_method);
TreeRef parseClass();
Decl parseTypeComment();

View File

@ -227,7 +227,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
// expression and base type names.
if (resolver_) {
if (auto typePtr =
resolver_->resolveType(expr.range().text().str(), expr.range())) {
resolver_->resolveType(expr.range().text(), expr.range())) {
return typePtr;
}
}

View File

@ -4,140 +4,13 @@
namespace torch {
namespace jit {
// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
StringCordView::StringCordView() {
accumulated_sizes_.push_back(0);
}
StringCordView::StringCordView(
std::vector<c10::string_view> inputs,
std::vector<std::shared_ptr<std::string>> ownerships)
: pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
accumulated_sizes_.push_back(0);
size_t running_sum = 0;
for (auto& s : pieces_) {
if (s.size() > 0) {
running_sum += s.size();
accumulated_sizes_.push_back(running_sum);
}
}
}
size_t StringCordView::find(const std::string& tok, size_t start) const {
if (tok.size() == 0) {
return 0;
}
if ((size() - start) < tok.size()) {
return std::string::npos;
}
Iterator begin = iter_for_pos(start);
Iterator end_iter = end();
size_t offset = start;
for (; begin != end_iter; ++begin, ++offset) {
if (*begin == tok[0]) {
auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
if (mis.second == tok.end()) {
// no mismatch, and second string (tok) is exhausted.
return offset;
}
if (mis.first == end_iter) {
// this str is exhausted but tok is not
return std::string::npos;
}
}
}
return std::string::npos;
}
StringCordView StringCordView::substr(size_t start, size_t size) const {
std::vector<c10::string_view> pieces;
std::vector<std::shared_ptr<std::string>> ownerships;
if (start >= this->size()) {
// out of bounds
return StringCordView();
}
if (start + size >= this->size()) {
size = this->size() - start;
}
Iterator begin = iter_for_pos(start);
Iterator end = iter_for_pos(start + size);
if (begin.line_ == end.line_) {
// same line
pieces.push_back(pieces_[begin.line_].substr(begin.pos_, size));
} else {
pieces.push_back(pieces_[begin.line_].substr(begin.pos_));
size_t last_line = pieces_.size();
if (end != this->end() && end.line_ < last_line) {
// end is within the string
last_line = end.line_;
}
for (size_t i = begin.line_ + 1; i < last_line; i++) {
pieces.push_back(pieces_[i]);
}
if (end != this->end()) {
pieces.push_back(pieces_[end.line_].substr(0, end.pos_));
}
}
// share ownership
std::copy(
owned_strings_.begin(),
owned_strings_.end(),
std::back_inserter(ownerships));
return StringCordView(std::move(pieces), std::move(ownerships));
}
bool StringCordView::operator==(const std::string& rhs) {
if (size() != rhs.size()) {
return false;
}
auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
// both need to exhaust
return res.first == end() && res.second == rhs.end();
}
bool StringCordView::operator==(const StringCordView& rhs) {
if (size() != rhs.size()) {
return false;
}
auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
// both need to exhaust
return res.first == end() && res.second == rhs.end();
}
StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
if (pos == 0) {
return begin();
}
if (pos >= size()) {
return end();
}
auto upper = std::upper_bound(
accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
if (upper == accumulated_sizes_.end()) {
return end();
}
size_t line = upper - accumulated_sizes_.begin() - 1;
assert(accumulated_sizes_[line] <= pos);
assert(accumulated_sizes_[line + 1] > pos);
return Iterator(this, line, pos - accumulated_sizes_[line], size() - pos);
}
size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
return (
std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
}
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
c10::optional<SourceRange> SourceView::findSourceRangeThatGenerated(
const SourceRange& range) {
if (!gen_ranges_) {
return c10::nullopt;
@ -200,7 +73,7 @@ C10_EXPORT void SourceRange::print_with_context(
return;
}
auto str = source_view_->text_str().str();
c10::string_view str = source_view_->text();
if (size() == str.size()) {
// this is just the entire file, not a subset, so print it out.
// primarily used to print out python stack traces
@ -268,7 +141,7 @@ C10_EXPORT void SourceRange::print_with_context(
line_end = start();
while (line_start < range_end) {
// move line_end to end of line
while (line_end < str.size() && str[line_end] != '\n') {
while (str[line_end] != '\n' && line_end < str.size()) {
++line_end;
}
// print line of code

View File

@ -4,172 +4,43 @@
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <unordered_map>
namespace torch {
namespace jit {
class SourceRangeUnpickler;
struct SourceRange;
// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
struct TORCH_API StringCordView {
StringCordView();
StringCordView(
std::vector<c10::string_view> inputs,
std::vector<std::shared_ptr<std::string>> ownerships);
size_t size() const {
return accumulated_sizes_.back();
}
size_t find(const std::string& tok, size_t start) const;
StringCordView substr(size_t start, size_t size) const;
char at(size_t index) const {
return *iter_for_pos(index);
}
char operator[](size_t index) const {
return at(index);
}
std::string str() const {
std::stringstream ss;
for (auto s : pieces_) {
ss << std::string(s);
}
return ss.str();
}
bool operator==(const std::string& rhs);
bool operator==(const StringCordView& rhs);
c10::string_view piece(size_t index) const {
return pieces_[index];
}
struct Iterator {
Iterator(
const StringCordView* str,
size_t start_line,
size_t start_pos,
size_t size)
: line_(start_line), pos_(start_pos), str_(str), size_(size) {}
explicit Iterator(const StringCordView* str)
: Iterator(str, 0, 0, str->size()) {}
Iterator(const Iterator&) = default;
Iterator(Iterator&&) = default;
Iterator& operator=(const Iterator&) = default;
Iterator& operator=(Iterator&&) = default;
Iterator operator++() {
if (size_ == 0) {
return *this;
}
if ((pos_ + 1) < str_->pieces_[line_].size()) {
pos_++;
} else {
line_++;
pos_ = 0;
}
return *this;
}
Iterator operator++(int) {
Iterator prev(*this);
++(*this);
return prev;
}
bool operator==(const Iterator& rhs) const {
if (!has_next() && !rhs.has_next()) {
return true;
}
return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_);
}
bool operator!=(const Iterator& rhs) {
return !((*this) == rhs);
}
bool has_next() const {
return size_ > 0 && (line_ < str_->pieces_.size());
}
char operator*() const {
TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size());
TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size());
return str_->pieces_[line_].at(pos_);
}
private:
size_t line_;
size_t pos_;
const StringCordView* str_;
size_t size_;
friend struct StringCordView;
};
Iterator begin() const {
return Iterator(this, 0, 0, size());
}
Iterator end() const {
return Iterator(this, pieces_.size(), 0, 0);
}
private:
Iterator iter_for_pos(size_t pos) const;
std::vector<c10::string_view> pieces_;
std::vector<size_t> accumulated_sizes_;
std::vector<std::shared_ptr<std::string>> owned_strings_;
};
// Source represents a code segment. It keeps track of:
// SourceView represents a code segment. It keeps track of:
// - text_view : the view into text of the code segment
// - filename (optional) : if present, represents the name of the file from
// which the code segment originated.
// - starting_line_no : represents the line in the original file where the
// code segment started.
struct TORCH_API Source {
// Whether or not Source should copy the string passed in the constructor.
enum CopiesString { COPIES_STRING, DONT_COPY };
explicit Source(
struct SourceView {
explicit SourceView(
c10::string_view text_view,
c10::optional<std::string> filename = c10::nullopt,
size_t starting_line_no = 0,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr,
CopiesString copies_str = COPIES_STRING)
: filename_(std::move(filename)),
starting_line_no_(starting_line_no),
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_view_(text_view),
filename_(c10::nullopt),
starting_line_no_(0),
gen_ranges_(std::move(gen_ranges)) {
if (copies_str == COPIES_STRING) {
std::shared_ptr<std::string> allocated_str =
std::make_shared<std::string>(text_view.data(), text_view.size());
text_view_ = StringCordView({*allocated_str}, {allocated_str});
} else {
text_view_ = StringCordView({text_view}, {});
}
calc_line_start_offsets();
}
Source(
StringCordView str,
c10::optional<std::string> filename = c10::nullopt,
size_t starting_line_no = 0,
SourceView(
c10::string_view text_view,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_view_(str),
: text_view_(text_view),
filename_(std::move(filename)),
starting_line_no_(starting_line_no),
gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}
// Given a line number (within source_), return the byte offset of the
// beginning of that line.
size_t offset_for_line(size_t line) const {
@ -183,9 +54,11 @@ struct TORCH_API Source {
// Calculate the line (within the code segment) on which `offset` resides.
size_t lineno_for_offset(size_t offset) const {
auto iter = std::upper_bound(
line_starting_offsets_.begin(), line_starting_offsets_.end(), offset);
return iter - line_starting_offsets_.begin() - 1;
return std::upper_bound(
line_starting_offsets_.begin(),
line_starting_offsets_.end(),
offset) -
line_starting_offsets_.begin() - 1;
}
// Calculate the line (within the original source file, if present) on which
@ -198,27 +71,11 @@ struct TORCH_API Source {
}
}
StringCordView get_line(size_t lineno) const {
auto start = offset_for_line(lineno);
auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start
: text_view_.size() - start;
return text_view_.substr(start, size);
}
// Note: this makes a copy
StringCordView text_str() const {
const c10::string_view text() const {
return text_view_;
}
char char_at(size_t index) const {
return text_view_.at(index);
}
size_t size() const {
return text_view_.size();
}
c10::optional<std::string>& filename() {
const c10::optional<std::string>& filename() const {
return filename_;
}
@ -229,20 +86,18 @@ struct TORCH_API Source {
c10::optional<SourceRange> findSourceRangeThatGenerated(
const SourceRange& range);
~Source() = default;
protected:
c10::string_view text_view_;
private:
void calc_line_start_offsets() {
line_starting_offsets_.clear();
line_starting_offsets_.push_back(0);
size_t pos = 0;
while ((pos = text_view_.find("\n", pos)) != std::string::npos) {
while ((pos = text().find('\n', pos)) != std::string::npos) {
line_starting_offsets_.push_back(++pos);
}
}
StringCordView text_view_;
c10::optional<std::string> filename_;
// If filename_ is not present, starting_line_no_ is don't care
size_t starting_line_no_;
@ -253,15 +108,67 @@ struct TORCH_API Source {
std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
};
// Source represents a code segment like SourceView, but the former owns a copy
// of source text while the latter doesn't.
struct Source : public SourceView {
explicit Source(
std::string text,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text, gen_ranges), text_(std::move(text)) {
text_view_ = text_;
}
explicit Source(
c10::string_view text_view,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text_view, gen_ranges),
text_(text_view.begin(), text_view.end()) {
text_view_ = text_;
}
explicit Source(
std::string text,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text, filename, starting_line_no, gen_ranges),
text_(std::move(text)) {
text_view_ = text_;
}
explicit Source(
c10::string_view text_view,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text_view, filename, starting_line_no, gen_ranges),
text_(text_view.begin(), text_view.end()) {
text_view_ = text_;
}
// Constructor that deepcopies and owns source text referenced in
// `source_view`.
explicit Source(const SourceView& source_view) : SourceView(source_view) {
text_ = std::string(text_view_.begin(), text_view_.end());
text_view_ = text_;
}
std::string text_;
};
// A SourceRange is a reference to subset of a Source, specified by `start` and
// `end` byte offsets into the source text.
struct TORCH_API SourceRange {
SourceRange(std::shared_ptr<Source> source_view_, size_t start_, size_t end_)
SourceRange(
std::shared_ptr<SourceView> source_view_,
size_t start_,
size_t end_)
: source_view_(std::move(source_view_)), start_(start_), end_(end_) {}
SourceRange() : source_view_(nullptr), start_(0), end_(0) {}
const StringCordView text() const {
return source_view_->text_str().substr(start(), end() - start());
const std::string text() const {
auto text_view = source_view_->text().substr(start(), end() - start());
return std::string(text_view.begin(), text_view.end());
}
size_t size() const {
return end() - start();
@ -276,7 +183,7 @@ struct TORCH_API SourceRange {
bool highlight,
const std::string& funcname) const;
const std::shared_ptr<Source>& source() const {
const std::shared_ptr<SourceView>& source() const {
return source_view_;
}
size_t start() const {
@ -322,7 +229,7 @@ struct TORCH_API SourceRange {
}
protected:
std::shared_ptr<Source> source_view_;
std::shared_ptr<SourceView> source_view_;
private:
size_t start_;
@ -330,16 +237,13 @@ struct TORCH_API SourceRange {
};
// OwnedSourceRange is just like a SourceRange except that it owns a `Source`
// instead of `Source`. Thus OwnedSourceRange owns a copy of source text.
// instead of `SourceView`. Thus OwnedSourceRange owns a copy of source text.
struct OwnedSourceRange : public SourceRange {
explicit OwnedSourceRange(const SourceRange& source_range)
OwnedSourceRange(const SourceRange& source_range)
: SourceRange(source_range) {
const auto& source = source_range.source();
if (source) {
source_view_ = std::make_shared<Source>(
source->text_str().str(),
source->filename(),
source->starting_line_no());
source_view_ = std::make_shared<Source>(*source);
}
}
};
@ -377,14 +281,3 @@ using SourceRangeTagMap =
} // namespace jit
} // namespace torch
namespace std {
template <>
struct iterator_traits<torch::jit::StringCordView::Iterator> {
using value_type = char;
using difference_type = ptrdiff_t;
using pointer = char*;
using reference = char&;
using iterator_category = std::forward_iterator_tag;
};
} // namespace std

View File

@ -21,26 +21,26 @@ namespace jit {
*/
class TORCH_API SourceRef : public CustomClassHolder {
public:
explicit SourceRef(std::shared_ptr<Source> source_view)
explicit SourceRef(std::shared_ptr<SourceView> source_view)
: source_view_(std::move(source_view)) {}
bool operator==(const SourceRef& other) const {
return source_view_ == other.source_view_;
}
bool operator<(const Source& other) const {
bool operator<(const SourceView& other) const {
return source_view_.get() < &other;
}
friend bool operator<(const Source& other, const SourceRef& self) {
friend bool operator<(const SourceView& other, const SourceRef& self) {
return &other < self.source_view_.get();
}
bool operator<(const SourceRef& other) const {
return *this < *other.source_view_.get();
}
const Source* operator->() const {
const SourceView* operator->() const {
return source_view_.get();
}
private:
std::shared_ptr<Source> source_view_;
std::shared_ptr<SourceView> source_view_;
};
} // namespace jit

View File

@ -122,26 +122,17 @@ MobileDebugTable::MobileDebugTable(
at::DataPtr debug_data;
size_t debug_size{0};
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
auto ivalueTuple = jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()),
debug_size,
nullptr,
{},
c10::parseType);
const auto& ivalues = ivalueTuple.toTuple()->elements();
IValue lines;
std::unique_ptr<SourceRangeDeserializer> deserializer;
if (ivalues.size() == 3 && ivalues[0].isString() &&
kFormatWithStringTable == ivalues[0].toStringRef()) {
// new format
deserializer = std::make_unique<SourceRangeDeserializer>(ivalues[1]);
lines = ivalues[2];
} else {
deserializer = std::make_unique<SourceRangeDeserializer>();
lines = ivalueTuple;
}
for (auto& val : lines.toTuple()->elements()) {
auto ivalues =
std::move(*jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()),
debug_size,
nullptr,
{},
c10::parseType)
.toTuple())
.elements();
SourceRangeDeserializer deserializer;
for (auto& val : ivalues) {
auto tup_elems = std::move(*std::move(val).toTuple()).elements();
// For BC we decode only tuples with 3 elements
// assuming it contains
@ -149,7 +140,7 @@ MobileDebugTable::MobileDebugTable(
if (tup_elems.size() == 3) {
int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
auto source_range =
deserializer->deserialize(tup_elems[kSourceRangeIndex]);
deserializer.deserialize(tup_elems[kSourceRangeIndex]);
source_range_map.emplace(debug_handle, std::move(source_range));
}
}

View File

@ -3,7 +3,6 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
namespace torch {
namespace jit {

View File

@ -7,7 +7,7 @@ bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
const auto source_n = n->sourceRange().source();
const auto source_m = m->sourceRange().source();
return (
(source_n->text_str() == source_m->text_str()) &&
(source_n->text() == source_m->text()) &&
(source_n->starting_line_no() == source_m->starting_line_no()));
}

View File

@ -104,8 +104,9 @@ void initTreeViewBindings(PyObject* module) {
return SourceRange(self.source_, start, end);
})
.def_property_readonly("source", [](const SourceRangeFactory& self) {
auto text_view = self.source_->text_str().str();
return text_view;
auto text_view = self.source_->text();
std::string text(text_view.begin(), text_view.end());
return text;
});
py::class_<TreeView>(m, "TreeView")

View File

@ -61,7 +61,7 @@ auto initBindings() {
return static_cast<int64_t>((*self)->starting_line_no());
})
.def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
return (*self)->text_str().str();
return (*self)->text();
});
torch::class_<InstructionStats>("profiling", "InstructionStats")

View File

@ -22,7 +22,7 @@ std::string qualifierToArchivePath(
return export_prefix + path + "." + kExportSuffix;
}
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
caffe2::serialize::PyTorchStreamReader& reader,
const std::string& export_prefix,
const std::string& qualifier) {

View File

@ -12,7 +12,7 @@ class PyTorchStreamReader;
namespace torch {
namespace jit {
struct Source;
struct SourceView;
// Convert a class type's qualifier name to the corresponding path the source
// file it should be written to.
@ -23,7 +23,7 @@ std::string qualifierToArchivePath(
const std::string& qualifier,
const std::string& export_prefix);
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
caffe2::serialize::PyTorchStreamReader& reader,
const std::string& export_prefix,
const std::string& qualifier);

View File

@ -159,7 +159,7 @@ void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
return;
}
loaded_sources_.insert(qualifier);
std::shared_ptr<Source> src = source_loader_(qualifier);
std::shared_ptr<SourceView> src = source_loader_(qualifier);
// The importer, when looking for classes/functions doesn't know if 'foo'
// contains definitions or if it is a prefix of 'foo.bar', we only figure it

View File

@ -20,7 +20,8 @@
namespace torch {
namespace jit {
using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>;
using SourceLoader =
std::function<std::shared_ptr<SourceView>(const std::string&)>;
struct SourceImporterImpl : public Resolver,
std::enable_shared_from_this<SourceImporterImpl> {

View File

@ -1,10 +1,8 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <algorithm>
namespace torch {
namespace jit {
@ -12,80 +10,45 @@ namespace jit {
class SourceRangeSerializer {
public:
// Serialize SourceRange as Tuple[SourceType, int, int]
// where SourceType = Tuple[int, int, int, List[int]],
// The first 2 ints are positions into the vector returned by textSaved
// after all the Ranges are processed. textSaved() returns a vector of str
// where SourceType = Tuple[str, Optional[str], int, List[int]],
// the serialized form of Source
c10::IValue serialize(const SourceRange& sr);
const std::vector<c10::IValue>& texts_saved() {
return texts_;
}
SourceRangeSerializer() {
texts_.emplace_back("");
text_to_idx_[texts_.back().toStringRef()] = 0;
}
private:
// Serialize Source as Tuple[str, Optional[str], int, List[int]]
// This caches serialized sources, since many SourceRanges can
// refer to the same one.
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
c10::IValue serialize_source(const std::shared_ptr<SourceView>& s);
int64_t store_text_and_get_index(const std::string& text_view);
std::vector<c10::IValue> texts_;
std::unordered_map<c10::string_view, int64_t> text_to_idx_;
std::unordered_map<std::shared_ptr<SourceView>, c10::IValue>
serialized_sources;
};
SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
const auto& tup_elems = iv.toTupleRef().elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
std::shared_ptr<SourceView> source_ = deserialize_source(tup_elems[0]);
int64_t start_ = tup_elems[1].toInt();
int64_t end_ = tup_elems[2].toInt();
return SourceRange(source_, start_, end_);
}
std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
std::shared_ptr<SourceView> SourceRangeDeserializer::deserialize_source(
const c10::IValue& iv) {
auto tup = iv.toTuple();
auto it = cached_sources.find(tup);
if (it != cached_sources.end()) {
return it->second;
}
std::shared_ptr<Source> source;
const auto& tup_elems = tup->elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
if (!text_table_.empty()) {
const auto& textIndex = tup_elems[0].toIntList();
int64_t fnameIndex = tup_elems[1].toInt();
int64_t starting_line_no_ = tup_elems[2].toInt();
c10::optional<std::string> filename = c10::nullopt;
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>();
int64_t starting_line_no_ = tup_elems[2].toInt();
filename = *text_table_[fnameIndex];
std::vector<c10::string_view> pieces;
std::vector<std::shared_ptr<std::string>> strs;
for (int64_t i : textIndex) {
pieces.emplace_back(*text_table_[i]);
strs.emplace_back(text_table_[i]);
}
StringCordView str_cord(std::move(pieces), std::move(strs));
source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
} else {
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ =
tup_elems[1].toOptional<std::string>();
int64_t starting_line_no_ = tup_elems[2].toInt();
source = std::make_shared<Source>(
std::move(text_), std::move(filename_), starting_line_no_);
}
auto source = std::make_shared<Source>(
std::move(text_), std::move(filename_), starting_line_no_);
cached_sources[tup] = source;
return source;
}
@ -95,41 +58,17 @@ c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end());
}
int64_t SourceRangeSerializer::store_text_and_get_index(
const std::string& text_view) {
auto text_iter = text_to_idx_.find(text_view);
if (text_iter == text_to_idx_.end()) {
int64_t text_pos = static_cast<int64_t>(texts_.size());
texts_.emplace_back(text_view);
text_to_idx_[texts_.back().toStringView()] = text_pos;
return text_pos;
} else {
return text_iter->second;
}
}
c10::IValue SourceRangeSerializer::serialize_source(
const std::shared_ptr<Source>& s) {
const std::shared_ptr<SourceView>& s) {
if (serialized_sources.count(s)) {
return serialized_sources.at(s);
}
c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
c10::List<int64_t> lines;
if (s == nullptr) {
serialized = c10::ivalue::Tuple::create({lines, 0, 0});
serialized = c10::ivalue::Tuple::create({"", "", 0});
} else {
for (size_t lineno = 0; lineno < s->num_lines(); lineno++) {
std::string line_content = s->get_line(lineno).str();
int64_t text_pos = store_text_and_get_index(line_content);
lines.push_back(text_pos);
}
int64_t fname_pos = 0;
if (s->filename().has_value()) {
fname_pos = store_text_and_get_index(*s->filename());
}
serialized = c10::ivalue::Tuple::create(
{lines, fname_pos, (int64_t)s->starting_line_no()});
{s->text(), s->filename(), (int64_t)s->starting_line_no()});
}
serialized_sources[s] = serialized;
return serialized;
@ -147,19 +86,14 @@ std::vector<char> SourceRangePickler::pickle(
if (it != source_range_tags.end()) {
source_range_tag = it->second;
}
ivalues.emplace_back(c10::ivalue::Tuple::create(
{(int64_t)range.bytes,
srs->serialize(range.range),
static_cast<int64_t>(source_range_tag)}));
}
std::vector<at::Tensor> table;
auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
auto result = jit::pickle(
c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
&table);
auto result = jit::pickle(ivalue, &table);
TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
return result;
}
@ -169,7 +103,7 @@ ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
size_t size)
: data(std::move(data)),
size(size),
deserializer(nullptr),
deserializer(new SourceRangeDeserializer()),
unpickled_records(nullptr) {}
void ConcreteSourceRangeUnpickler::unpickle() {
@ -185,19 +119,10 @@ void ConcreteSourceRangeUnpickler::unpickle() {
{},
c10::parseType)
.toTuple();
const auto& ivalues = ivaluesTuple->elements();
unpickled_records = std::make_shared<SourceRangeRecords>();
IValue lines;
if (ivalues[0].isString() &&
kFormatWithStringTable == ivalues[0].toStringRef()) {
deserializer.reset(new SourceRangeDeserializer(ivalues[1]));
lines = ivalues[2];
} else {
deserializer.reset(new SourceRangeDeserializer());
lines = ivaluesTuple;
}
for (auto& val : lines.toTuple()->elements()) {
for (auto& val : ivalues) {
const auto& tup_elems = val.toTupleRef().elements();
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);

View File

@ -20,7 +20,6 @@ class SourceRangeSerializer;
static constexpr size_t kByteOffsetIndex = 0;
static constexpr size_t kSourceRangeIndex = 1;
static constexpr size_t kSourceRangeTagIndex = 2;
constexpr c10::string_view kFormatWithStringTable = "FORMAT_WITH_STRING_TABLE";
class SourceRangePickler {
public:
@ -36,21 +35,14 @@ class SourceRangePickler {
class SourceRangeDeserializer {
public:
SourceRangeDeserializer() = default;
explicit SourceRangeDeserializer(c10::IValue text_table) {
for (const auto& x : text_table.toTuple()->elements()) {
text_table_.emplace_back(std::make_shared<std::string>(x.toStringRef()));
}
}
SourceRange deserialize(const c10::IValue& iv);
private:
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv);
std::shared_ptr<SourceView> deserialize_source(const c10::IValue& iv);
std::unordered_map<
c10::intrusive_ptr<c10::ivalue::Tuple>,
std::shared_ptr<Source>>
std::shared_ptr<SourceView>>
cached_sources;
std::vector<std::shared_ptr<std::string>> text_table_;
};
class SourceRangeUnpickler {

View File

@ -94,7 +94,7 @@ size_t assertFind(
const SourceRange& search_range,
const std::string& sub,
const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
auto pos = search_range.source()->text_str().find(sub, search_range.start());
auto pos = search_range.source()->text().find(sub, search_range.start());
if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
auto found_range =
SourceRange(search_range.source(), search_range.start(), sub.size());
@ -122,18 +122,19 @@ size_t assertFind(
}
size_t assertFind(
const std::shared_ptr<Source>& source,
const std::shared_ptr<SourceView>& source,
const std::string& sub,
size_t start,
const Check& check) {
return assertFind(SourceRange(source, start, source->size()), sub, check);
return assertFind(
SourceRange(source, start, source->text().size()), sub, check);
}
void assertNotFind(
const SourceRange& search_range,
const std::string& sub,
const Check& check) {
auto pos = search_range.source()->text_str().find(sub, search_range.start());
auto pos = search_range.source()->text().find(sub, search_range.start());
if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
auto found_range =
SourceRange(search_range.source(), pos, sub.size() + pos);
@ -201,7 +202,9 @@ struct FileCheckImpl {
friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
private:
bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
bool parseSingleCheck(
const std::shared_ptr<SourceView>& source,
size_t* start) {
const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
{CHECK, ": "},
{CHECK_NEXT, "-NEXT: "},
@ -214,35 +217,31 @@ struct FileCheckImpl {
for (const auto& check_pair : check_pairs) {
const std::string& check_suffix = check_pair.second;
auto suffix_pos = source->text_str().find(check_suffix, *start);
auto suffix_pos = source->text().find(check_suffix, *start);
if (suffix_pos != *start) {
continue;
}
size_t end_check_string = suffix_pos + check_suffix.size();
CheckType type = check_pair.first;
c10::optional<size_t> count = c10::nullopt;
auto end_line = source->text_str().find("\n", end_check_string);
auto end_line = source->text().find('\n', end_check_string);
bool exactly = false;
if (type == CHECK_COUNT) {
const std::string exact = "EXACTLY-";
if (source->text_str().find(exact, end_check_string) ==
end_check_string) {
if (source->text().find(exact, end_check_string) == end_check_string) {
exactly = true;
end_check_string += exact.size();
}
size_t end =
assertFind(SourceRange(source, end_check_string, end_line), ":");
auto count_view = source->text_str()
.substr(end_check_string, end - end_check_string)
.str();
auto count_view =
source->text().substr(end_check_string, end - end_check_string);
count = c10::stoll(std::string(count_view.begin(), count_view.end()));
end_check_string = end + 2; // add ':' and the space
}
auto check = Check(
type,
source->text_str()
.substr(end_check_string, end_line - end_check_string)
.str(),
source->text().substr(end_check_string, end_line - end_check_string),
count);
addCheck(check);
if (exactly) {
@ -254,30 +253,32 @@ struct FileCheckImpl {
return false;
}
size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
size_t start = source->text_str().find("#", prev_end);
size_t findNextStart(
const std::shared_ptr<SourceView>& source,
size_t prev_end) {
size_t start = source->text().find('#', prev_end);
if (start == std::string::npos) {
return start;
}
start += 1;
static constexpr size_t max_whitespace = 6;
size_t i = 0;
while (start + i < source->size() && i < max_whitespace) {
auto c = source->char_at(start + i);
while (start + i < source->text().size() && i < max_whitespace) {
auto c = source->text().at(start + i);
if (c != ' ' && c != '\t') {
break;
}
i++;
}
static const std::string check = "CHECK";
if (source->text_str().substr(start + i, check.size()) == check) {
if (source->text().substr(start + i, check.size()) == check) {
return start + i + check.size();
} else {
return findNextStart(source, start + i + 1);
}
}
void parseStrings(const std::shared_ptr<Source>& source) {
void parseStrings(const std::shared_ptr<SourceView>& source) {
size_t start = 0;
start = findNextStart(source, 0);
while (start != std::string::npos) {
@ -296,7 +297,7 @@ struct FileCheckImpl {
void doCheckNot(
const std::vector<Check>& nots,
const std::shared_ptr<Source>& source,
const std::shared_ptr<SourceView>& source,
const SourceRange& prev,
const SourceRange& next) {
auto start = prev.end(); // inclusive
@ -313,7 +314,7 @@ struct FileCheckImpl {
// Checks that source token is highlighted, does not advance search range.
void doCheckSourceHighlighted(
const Check& check,
const std::shared_ptr<Source>& source,
const std::shared_ptr<SourceView>& source,
size_t start_offset) {
auto construct_error_and_throw = [&](size_t error_start_pos) {
SourceRange error_range(
@ -329,8 +330,8 @@ struct FileCheckImpl {
size_t search_start_offset = start_offset;
bool found_token_at_least_once = false;
size_t pos = search_start_offset;
while (pos < source->size()) {
pos = source->text_str().find(check.search_str_, search_start_offset);
while (pos < source->text().size()) {
pos = source->text().find(check.search_str_, search_start_offset);
if (pos == std::string::npos) {
break;
}
@ -348,16 +349,17 @@ struct FileCheckImpl {
auto highlight_start_offset =
source->offset_for_line(highlight_lineno) + col;
auto highlight_end_offset = std::min(
highlight_start_offset + check.search_str_.size(), source->size());
highlight_start_offset + check.search_str_.size(),
source->text().size());
if (highlight_end_offset >= source->size()) {
if (highlight_end_offset >= source->text().size()) {
construct_error_and_throw(pos);
}
bool found_highlight = true;
for (const auto posi :
c10::irange(highlight_start_offset, highlight_end_offset)) {
if (source->char_at(posi) != '~') {
if (source->text()[posi] != '~') {
found_highlight = false;
}
}
@ -388,7 +390,7 @@ struct FileCheckImpl {
SourceRange matchDagGroup(
const std::vector<Check>& group,
const std::shared_ptr<Source>& source,
const std::shared_ptr<SourceView>& source,
const SourceRange& prev) {
size_t group_beg = std::string::npos;
size_t group_end = 0;
@ -406,7 +408,7 @@ struct FileCheckImpl {
SourceRange matchGroup(
const std::vector<Check>& group,
const std::shared_ptr<Source>& source,
const std::shared_ptr<SourceView>& source,
const SourceRange& prev) {
AT_ASSERT(group.size() != 0);
CheckType type = group[0].type_;
@ -465,7 +467,7 @@ struct FileCheckImpl {
return SourceRange(source, start_range, end_range);
}
void doChecks(const std::shared_ptr<Source>& source) {
void doChecks(const std::shared_ptr<SourceView>& source) {
SourceRange prev(source, 0, 0);
for (size_t i = 0; i < groups.size(); i++) {
const auto& curr_group = groups[i];
@ -482,7 +484,7 @@ struct FileCheckImpl {
++i; // already checked the group after
} else {
SourceRange end_of_file(
source, source->size() + 1, source->size() + 1);
source, source->text().size() + 1, source->text().size() + 1);
doCheckNot(curr_group, source, prev, end_of_file);
}
}

View File

@ -258,22 +258,7 @@ def get_model_info(
# Parse debug info and add begin/end markers if not present
# to ensure that we cover the entire source code.
debug_info_t = pickle.loads(raw_debug)
text_table = None
if (len(debug_info_t) == 3 and
isinstance(debug_info_t[0], str) and
debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
_, text_table, content = debug_info_t
def parse_new_format(line):
# (0, (('', '', 0), 0, 0))
num, ((text_indexes, fname_idx, offset), start, end), tag = line
text = ''.join(text_table[x] for x in text_indexes) # type: ignore
fname = text_table[fname_idx] # type: ignore
return num, ((text, fname, offset), start, end), tag
debug_info_t = map(parse_new_format, content)
assert isinstance(debug_info_t, tuple)
debug_info = list(debug_info_t)
if not debug_info:
debug_info.append((0, (('', '', 0), 0, 0)))