Reland "Make debug_pkl smaller by only emitting unique traces." (#73368)
Summary:

## Original commit message

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73368

The debug_pkl file inside PyTorch's .pt archive consists of a list of SourceRanges. Each SourceRange points to a Source, which holds a stack trace, a filename, and start/end offsets. These were emitted into debug_pkl as full strings. Since many SourceRanges share the same source, the trace strings can be deduplicated: the newer format saves the set of unique traces in a tuple, and each SourceRange stores only the offset of its trace within that tuple (i.e. manually applied dictionary compression). This makes the file smaller.

On loading, copying every trace back into a Source as an owned string would still blow up runtime memory. To mitigate this, we use SourceView directly instead of Source: it takes a reference to the string owned by the Deserializer and exposes it as a string_view. This is safe because the Deserializer is held by the Unpickler via shared_ptr, and the Unpickler is in turn held via shared_ptr by another Source object, which stays alive for the duration of model construction.

Test Plan:

## Original test plan

Unit tests.

Took the original file (312271638_930.predictor.disagg.local), loaded it with `torch.jit.load`, saved it again with `torch.jit.save`, unzipped both archives, and compared their contents:

```
[qihan@devvm5585.vll0 ~]$ du archive -h
4.0K    archive/xl_model_weights
3.7M    archive/extra
8.0K    archive/code/__torch__/caffe2/torch/fb/model_transform/splitting
8.0K    archive/code/__torch__/caffe2/torch/fb/model_transform
8.0K    archive/code/__torch__/caffe2/torch/fb
8.0K    archive/code/__torch__/caffe2/torch
8.0K    archive/code/__torch__/caffe2
20M     archive/code/__torch__/torch/fx/graph_module
20M     archive/code/__torch__/torch/fx
8.0K    archive/code/__torch__/torch/classes
20M     archive/code/__torch__/torch
20M     archive/code/__torch__
20M     archive/code
2.7M    archive/constants
35M     archive
[qihan@devvm5585.vll0 ~]$ du resaved -h
4.0K    resaved/extra
8.0K    resaved/code/__torch__/caffe2/torch/fb/model_transform/splitting
8.0K    resaved/code/__torch__/caffe2/torch/fb/model_transform
8.0K    resaved/code/__torch__/caffe2/torch/fb
8.0K    resaved/code/__torch__/caffe2/torch
8.0K    resaved/code/__torch__/caffe2
1.3M    resaved/code/__torch__/torch/fx/graph_module
1.3M    resaved/code/__torch__/torch/fx
8.0K    resaved/code/__torch__/torch/classes
1.4M    resaved/code/__torch__/torch
1.4M    resaved/code/__torch__
1.4M    resaved/code
2.7M    resaved/constants
13M     resaved
[qihan@devvm5585.vll0 ~]$
```

## Additional test

`buck test mode/dev-tsan //caffe2/benchmarks/static_runtime:static_runtime_cpptest -- --exact 'caffe2/benchmarks/static_runtime:static_runtime_cpptest - StaticRuntime.to'` passes.

test jest.fbios.startup_cold_start.local.simulator f333356873

Differential Revision: D35196883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74869
Approved by: https://github.com/gmagogsfm
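The dictionary-compression step described above can be pictured with a minimal sketch. The names here are illustrative only, not the serializer's actual API: each unique trace string is stored once in a table, and a SourceRange records just the index of its trace.

```
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical helper showing the dedup idea: unique traces are written once
// (as a tuple in debug_pkl) and every range refers to its trace by position.
struct TraceTable {
  std::vector<std::string> unique_traces;
  std::unordered_map<std::string, int64_t> index;

  int64_t intern(const std::string& trace) {
    auto it = index.find(trace);
    if (it != index.end()) {
      return it->second;  // trace already stored; reuse its offset
    }
    int64_t pos = static_cast<int64_t>(unique_traces.size());
    unique_traces.push_back(trace);
    index.emplace(trace, pos);
    return pos;  // new entry; this offset is what a SourceRange would store
  }
};
```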
test/cpp/jit/source_range_test.cpp — new file, 51 lines
@@ -0,0 +1,51 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/frontend/source_range.h>

using namespace ::testing;
using namespace ::torch::jit;

TEST(SourceRangeTest, test_find) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));

std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

StringCordView view(pieces, strings);

auto x = view.find("rldni", 0);
EXPECT_EQ(x, 8);
}

TEST(SourceRangeTest, test_substr) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));

std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

StringCordView view(pieces, strings);

auto x = view.substr(4, 10).str();
EXPECT_EQ(x, view.str().substr(4, 10));
EXPECT_EQ(view.substr(0, view.size()).str(), view.str());
}

TEST(SourceRangeTest, test_iter) {
std::vector<std::shared_ptr<std::string>> strings;
strings.push_back(std::make_shared<std::string>("hello world"));
strings.push_back(std::make_shared<std::string>("nihaoma"));

std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

StringCordView view(pieces, strings);

auto iter = view.iter_for_pos(5);
EXPECT_EQ(*iter, ' ');
EXPECT_EQ(iter.rest_line(), " world");
EXPECT_EQ(*iter.next_iter(), 'w');
EXPECT_EQ(iter.pos(), 5);

iter = view.iter_for_pos(13);
EXPECT_EQ(iter.pos(), 13);
}
@@ -143,6 +143,38 @@ TEST(BackendTest, TestCompiler) {
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestCompilerWithStringTable) {
setShouldUseFormatWithStringTable(true);
Module m("m");
m.define(R"(
def forward(self, x, h):
return x + h
)");

std::vector<IValue> inputs;
inputs.emplace_back(2.0 * torch::ones({}));
inputs.emplace_back(1.0 * torch::ones({}));
auto ref = m.forward(inputs);

c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// lowered module
auto lm = torch::jit::detail::codegen_backend_module(
"backend_with_compiler_demo", m, compile_spec, any_dict_ty);
auto res = lm.forward(inputs);
AT_ASSERT(res.toTensor().equal(ref.toTensor()));

std::stringstream ss;
lm._save_for_mobile(ss);
auto mlm = _load_for_mobile(ss);
auto mres = mlm.forward(inputs);
setShouldUseFormatWithStringTable(false);
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestComposite) {
c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());

@@ -384,6 +416,56 @@ Traceback of TorchScript (most recent call last):
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestCompilerWithStringTable) {
setShouldUseFormatWithStringTable(true);
Module m("m");
m.define(R"(
def forward(self, x, h):
return x + h
)");

std::vector<IValue> inputs;
inputs.emplace_back(torch::rand({2, 4}));
inputs.emplace_back(torch::rand({13, 9}));

c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// lowered module
auto lm = torch::jit::detail::codegen_backend_module(
"backend_with_compiler_demo", m, compile_spec, any_dict_ty);

std::stringstream ss;
lm._save_for_mobile(ss, ExtraFilesMap(), true);
auto mlm = _load_for_mobile(ss);
std::string error_pattern = R"(
Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
File "<string>", line 3, in <unknown>

def forward(self, x: Tensor, h: Tensor):
return self.__loweredModule__.forward(x, h)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

File "<string>", line 5, in forward
typed_inputs: List[Any] = [x, h, ]
if self.__backend.is_available() :
_0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
assert isinstance(_0, Tensor)
return _0
File "<string>", line 3, in <unknown>

def forward(self, x, h):
return x + h
~~~~~ <--- HERE
)";
setShouldUseFormatWithStringTable(false);
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
Module a("A");
a.define(R"(

@@ -41,7 +41,9 @@ static inline void trim(std::string& s) {
trim(substring_s); \
auto exception_string = std::string(e.what()); \
trim(exception_string); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
<< " Error was: \n" \
<< exception_string; \
}

namespace torch {
@@ -2,6 +2,7 @@

#include <ATen/core/Reduction.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Optional.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
@@ -27,8 +28,13 @@ namespace jit {

namespace {
struct SchemaParser {
SchemaParser(const std::string& str)
: L(std::make_shared<SourceView>(c10::string_view(str))),
explicit SchemaParser(const std::string& str)
: L(std::make_shared<Source>(
c10::string_view(str),
c10::nullopt,
0,
nullptr,
Source::DONT_COPY)),
type_parser(L, /*parse_complete_tensor_types*/ false) {}

either<OperatorName, FunctionSchema> parseDeclaration() {
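The SchemaParser change above uses the new Source constructor's copy policy. A minimal usage sketch (the schema string is hypothetical; the argument list follows the constructor signature shown later in this diff): with Source::DONT_COPY the Source only keeps a view into the caller's buffer, so the caller must keep the string alive for the Source's lifetime, which SchemaParser does by owning `str` while parsing.

```
#include <memory>
#include <string>
#include <torch/csrc/jit/frontend/source_range.h>

void sketch() {
  std::string schema = "foo(Tensor a, Tensor b) -> Tensor";  // hypothetical input
  // DONT_COPY: the Source wraps `schema` in a StringCordView without taking
  // ownership, so `schema` must outlive `src`.
  auto src = std::make_shared<torch::jit::Source>(
      c10::string_view(schema),
      c10::nullopt,
      0,
      nullptr,
      torch::jit::Source::DONT_COPY);
}
```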
@@ -187,39 +187,39 @@ struct TORCH_API SharedParserData {
#undef ADD_CASE
}

// find the longest match of str.substring(pos) against a token, return true
// if successful filling in kind, start,and len
bool match(
c10::string_view str,
size_t pos,
StringCordView::Iterator pos,
bool continuation, // are we inside a scope where newlines don't count
// (e.g. inside parens)
bool whitespace_token, // should we treat whitespace as a token
int* kind,
size_t* start,
size_t* len) {
StringCordView::Iterator* start,
StringCordView::Iterator* end) {
*start = pos;
// skip whitespace
while (pos < str.size() && isblank(str[pos]))
pos++;
while (pos.has_next() && isblank(*pos)) {
++pos;
}

// special handling
if (pos < str.size()) {
if (str[pos] == '#' && !isTypeComment(str, pos)) {
if (pos.has_next()) {
if (*pos == '#' && !isTypeComment(pos)) {
// skip comments
while (pos < str.size() && str[pos] != '\n')
pos++;
while (pos.has_next() && *pos != '\n')
++pos;
// tail call, handle whitespace and more comments
return match(
str, pos, continuation, whitespace_token, kind, start, len);
return match(pos, continuation, whitespace_token, kind, start, end);
}
if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' &&
!whitespace_token) {
return match(str, pos + 2, continuation, false, kind, start, len);
if (*pos == '\\') {
auto newiter = pos;
++newiter;
if (newiter.has_next() && *newiter == '\n' && !whitespace_token) {
++newiter;
return match(newiter, continuation, false, kind, start, end);
}
if (str[pos] == '\n') {
return match(
str, pos + 1, continuation, !continuation, kind, start, len);
}
if (*pos == '\n') {
return match(++pos, continuation, !continuation, kind, start, end);
}
}
// we handle white space before EOF because in the case we have something
@@ -228,26 +228,31 @@ struct TORCH_API SharedParserData {
// else:
// pass
if (whitespace_token) {
*kind = pos == str.size() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
*len = pos - *start;
*kind = !pos.has_next() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
*end = pos;
return true;
}
if (pos == str.size()) {
if (!pos.has_next()) {
*kind = TK_EOF;
*start = pos;
*len = 0;
*end = *start;
return true;
}
// invariant: the next token is not whitespace or newline
*start = pos;
// check for a valid number
if (isNumber(str, pos, len)) {
size_t len;
if (isNumber(pos.rest_line(), 0, &len)) {
*end = *start;
*end += len;
*kind = TK_NUMBER;
return true;
}
// check for string
if (isString(str, pos, len)) {
if (isString(pos.rest_line(), 0, &len)) {
*kind = TK_STRINGLITERAL;
*end = *start;
*end += len;
return true;
}

@@ -257,11 +262,14 @@ struct TORCH_API SharedParserData {
bool matched = false;
bool ident = true;
TokenTrie* cur = head.get();
for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
ident = ident && validIdent(i, str[pos + i]);
// for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr);
// i++)
for (size_t i = 0; pos.has_next() && (ident || cur != nullptr);
++pos, ++i) {
ident = ident && validIdent(i, *pos);
if (ident) {
matched = true;
*len = i + 1;
*end = pos.next_iter();
*kind = TK_IDENT;
}
// check for token second, so that e.g. 'max' matches the token TK_MAX
@@ -270,14 +278,14 @@ struct TORCH_API SharedParserData {
if (cur) {
const auto begin_it = cur->child_chars.begin();
const auto end_it = cur->child_chars.end();
const auto ch_it = std::find(begin_it, end_it, str[pos + i]);
const auto ch_it = std::find(begin_it, end_it, *pos);

cur = (ch_it == end_it) ? nullptr
: cur->child_tries[ch_it - begin_it].get();

if (cur && cur->kind != 0) {
matched = true;
*len = i + 1;
*end = pos.next_iter();
*kind = cur->kind;
}
}
@@ -368,8 +376,19 @@ struct TORCH_API SharedParserData {
bool isblank(int n) {
return isspace(n) && n != '\n';
}

bool isTypeComment(StringCordView::Iterator str_iter) {
c10::string_view rest_line = str_iter.rest_line();
const std::string type_string = "# type:";
if (rest_line.size() < type_string.length()) {
return false;
}
auto match_string = rest_line.substr(0, type_string.size());
return match_string == type_string;
}

// Make an exception ignoring comments for type annotation comments
bool isTypeComment(c10::string_view str, size_t pos) {
bool isTypeComment(StringCordView str, size_t pos) {
const std::string type_string = "# type:";
if (str.size() < pos + type_string.length()) {
return false;
@@ -388,7 +407,7 @@ struct Token {
SourceRange range;
Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
std::string text() {
return range.text();
return std::string(range.token_text());
}
std::string kindString() const {
return kindToString(kind);
@@ -396,7 +415,7 @@ struct Token {
};

struct Lexer {
explicit Lexer(std::shared_ptr<SourceView> source)
explicit Lexer(std::shared_ptr<Source> source)
: source(std::move(source)),
pos(0),
nesting(0),
@@ -514,30 +533,37 @@ struct Lexer {
Token lexRaw(bool whitespace_token = false) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int kind;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t start;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t length;
AT_ASSERT(source);
if (current == nullptr) {
AT_ASSERT(pos == 0);
current = std::make_unique<StringCordView::Iterator>(
source->text_str().begin());
}

StringCordView::Iterator start_iter = *current;
StringCordView::Iterator end_iter = *current;
if (!shared.match(
source->text(),
pos,
*current,
nesting > 0,
whitespace_token,
&kind,
&start,
&length)) {
&start_iter,
&end_iter)) {
expected(
"a valid token",
Token(
(source->text())[start], SourceRange(source, start, start + 1)));
**current,
SourceRange(source, start_iter, start_iter.pos() + 1)));
}
auto t = Token(kind, SourceRange(source, start, start + length));
pos = start + length;

auto t = Token(kind, SourceRange(source, start_iter, end_iter.pos()));
pos = end_iter.pos();
*current = end_iter;
return t;
}

std::shared_ptr<SourceView> source;
std::shared_ptr<Source> source;
std::unique_ptr<StringCordView::Iterator> current;
size_t pos;
size_t nesting; // depth of ( [ { nesting...
std::vector<int> indent_stack; // stack of indentation level of blocks
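The match() rewrite above replaces the old (string_view, size_t pos) bookkeeping with a StringCordView::Iterator. A small standalone sketch of that scanning pattern (not taken from the lexer itself; the input string is illustrative):

```
#include <torch/csrc/jit/frontend/source_range.h>

void scan_sketch() {
  using torch::jit::StringCordView;
  StringCordView view({"print(x)\n"}, {});
  // has_next()/operator*/operator++ replace the old pos < str.size(),
  // str[pos], pos++ pattern; pos() still reports the absolute offset.
  for (auto it = view.begin(); it.has_next(); ++it) {
    if (*it == '\n') {
      break;
    }
  }
}
```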
@@ -46,7 +46,7 @@ Decl mergeTypesFromTypeComment(
}

struct ParserImpl {
explicit ParserImpl(const std::shared_ptr<SourceView>& source)
explicit ParserImpl(const std::shared_ptr<Source>& source)
: L(source), shared(sharedParserData()) {}

Ident parseIdent() {
@@ -801,7 +801,7 @@ struct ParserImpl {
SharedParserData& shared;
};

Parser::Parser(const std::shared_ptr<SourceView>& src)
Parser::Parser(const std::shared_ptr<Source>& src)
: pImpl(new ParserImpl(src)) {}

Parser::~Parser() = default;

@@ -17,7 +17,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
bool is_method);

struct TORCH_API Parser {
explicit Parser(const std::shared_ptr<SourceView>& src);
explicit Parser(const std::shared_ptr<Source>& src);
TreeRef parseFunction(bool is_method);
TreeRef parseClass();
Decl parseTypeComment();

@@ -227,7 +227,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
// expression and base type names.
if (resolver_) {
if (auto typePtr =
resolver_->resolveType(expr.range().text(), expr.range())) {
resolver_->resolveType(expr.range().text().str(), expr.range())) {
return typePtr;
}
}
@@ -4,13 +4,140 @@

namespace torch {
namespace jit {

// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
StringCordView::StringCordView() {
accumulated_sizes_.push_back(0);
}

StringCordView::StringCordView(
std::vector<c10::string_view> inputs,
std::vector<std::shared_ptr<std::string>> ownerships)
: pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
accumulated_sizes_.push_back(0);
size_t running_sum = 0;
for (auto& s : pieces_) {
if (s.size() > 0) {
running_sum += s.size();
accumulated_sizes_.push_back(running_sum);
}
}
}

size_t StringCordView::find(const std::string& tok, size_t start) const {
if (tok.size() == 0) {
return 0;
}

if ((size() - start) < tok.size()) {
return std::string::npos;
}

Iterator begin = iter_for_pos(start);
Iterator end_iter = end();
size_t offset = start;
for (; begin != end_iter; ++begin, ++offset) {
if (*begin == tok[0]) {
auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
if (mis.second == tok.end()) {
// no mismatch, and second string (tok) is exhausted.
return offset;
}
if (mis.first == end_iter) {
// this str is exhausted but tok is not
return std::string::npos;
}
}
}
return std::string::npos;
}

StringCordView StringCordView::substr(size_t start, size_t size) const {
std::vector<c10::string_view> pieces;
std::vector<std::shared_ptr<std::string>> ownerships;
if (start >= this->size()) {
// out of bounds
return StringCordView();
}
if (start + size >= this->size()) {
size = this->size() - start;
}
Iterator begin = iter_for_pos(start);
Iterator end = iter_for_pos(start + size);

if (begin.line_ == end.line_) {
// same line
pieces.push_back(pieces_[begin.line_].substr(begin.pos_, size));
} else {
pieces.push_back(pieces_[begin.line_].substr(begin.pos_));

size_t last_line = pieces_.size();
if (end != this->end() && end.line_ < last_line) {
// end is within the string
last_line = end.line_;
}
for (size_t i = begin.line_ + 1; i < last_line; i++) {
pieces.push_back(pieces_[i]);
}
if (end != this->end()) {
pieces.push_back(pieces_[end.line_].substr(0, end.pos_));
}
}

// share ownership
std::copy(
owned_strings_.begin(),
owned_strings_.end(),
std::back_inserter(ownerships));

return StringCordView(std::move(pieces), std::move(ownerships));
}

bool StringCordView::operator==(const std::string& rhs) const {
if (size() != rhs.size()) {
return false;
}
auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
// both need to exhaust
return res.first == end() && res.second == rhs.end();
}

bool StringCordView::operator==(const StringCordView& rhs) const {
if (size() != rhs.size()) {
return false;
}
auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
// both need to exhaust
return res.first == end() && res.second == rhs.end();
}

StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
if (pos == 0) {
return begin();
}
if (pos >= size()) {
return end();
}
auto upper = std::upper_bound(
accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
if (upper == accumulated_sizes_.end()) {
return end();
}
size_t line = upper - accumulated_sizes_.begin() - 1;
assert(accumulated_sizes_[line] <= pos);
assert(accumulated_sizes_[line + 1] > pos);
return Iterator(this, line, pos - accumulated_sizes_[line], size() - pos);
}

size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
return (
std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
}

c10::optional<SourceRange> SourceView::findSourceRangeThatGenerated(
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
const SourceRange& range) {
if (!gen_ranges_) {
return c10::nullopt;
@@ -73,7 +200,7 @@ void SourceRange::print_with_context(
return;
}

c10::string_view str = source_view_->text();
auto str = source_view_->text_str().str();
if (size() == str.size()) {
// this is just the entire file, not a subset, so print it out.
// primarily used to print out python stack traces
@@ -141,7 +268,7 @@ void SourceRange::print_with_context(
line_end = start();
while (line_start < range_end) {
// move line_end to end of line
while (str[line_end] != '\n' && line_end < str.size()) {
while (line_end < str.size() && str[line_end] != '\n') {
++line_end;
}
// print line of code
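iter_for_pos() above resolves an absolute position with a binary search over accumulated_sizes_. A worked example, using the same two strings as the unit test at the top of this commit:

```
// pieces_            = {"hello world", "nihaoma"}   (sizes 11 and 7)
// accumulated_sizes_ = {0, 11, 18}
//
// iter_for_pos(13):
//   upper_bound({0, 11, 18}, 13) points at 18, so line = 1
//   local offset = 13 - accumulated_sizes_[1] = 2, i.e. 'h' in "nihaoma"
//   result: Iterator(this, /*line=*/1, /*pos=*/2, /*size=*/18 - 13 = 5)
```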
@@ -4,43 +4,217 @@

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <unordered_map>

namespace torch {
namespace jit {

class SourceRangeUnpickler;
struct SourceRange;

// SourceView represents a code segment. It keeps track of:
// A stringlike class backed by a vector of string_view
// the string represented are logically the concatenation of the string_views
// This has advantage of not needing continues memory.
struct TORCH_API StringCordView {
StringCordView();
StringCordView(const StringCordView&) = default;
StringCordView(
std::vector<c10::string_view> inputs,
std::vector<std::shared_ptr<std::string>> ownerships);

StringCordView& operator=(const StringCordView&) = default;

size_t size() const {
return accumulated_sizes_.back();
}

size_t find(const std::string& tok, size_t start) const;
StringCordView substr(size_t start, size_t size) const;

char at(size_t index) const {
return *iter_for_pos(index);
}
char operator[](size_t index) const {
return at(index);
}

std::string str() const {
std::stringstream ss;
for (auto s : pieces_) {
ss << std::string(s);
}
return ss.str();
}

bool operator==(const std::string& rhs) const;

bool operator==(const StringCordView& rhs) const;

c10::string_view piece(size_t index) const {
return pieces_[index];
}

struct Iterator {
Iterator(
const StringCordView* str,
size_t start_line,
size_t start_pos,
size_t size)
: line_(start_line), pos_(start_pos), str_(str), size_(size) {}
explicit Iterator(const StringCordView* str)
: Iterator(str, 0, 0, str->size()) {}

Iterator() : Iterator(nullptr, 0, 0, 0) {}

Iterator(const Iterator&) = default;
Iterator(Iterator&&) = default;
Iterator& operator=(const Iterator&) = default;
Iterator& operator=(Iterator&&) = default;

Iterator operator++() {
if (size_ == 0) {
return *this;
}
if ((pos_ + 1) < str_->pieces_[line_].size()) {
pos_++;
} else {
line_++;
pos_ = 0;
}
return *this;
}

Iterator operator++(int) {
Iterator prev(*this);
++(*this);
return prev;
}

Iterator next_iter() const {
Iterator next(*this);
++next;
return next;
}

Iterator& operator+=(size_t num) {
if (!has_next()) {
return *this;
}
size_t target_pos = pos_ + num;
if (target_pos >= str_->accumulated_sizes_[line_] &&
(line_ + 1) < str_->accumulated_sizes_.size() &&
target_pos < str_->accumulated_sizes_[line_ + 1]) {
pos_ = target_pos;
return *this;
}

size_t target_abs_pos = pos() + num;
*this = str_->iter_for_pos(target_abs_pos);
return *this;
}

bool operator==(const Iterator& rhs) const {
if (!has_next() && !rhs.has_next()) {
return true;
}
return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_);
}
bool operator!=(const Iterator& rhs) {
return !((*this) == rhs);
}
bool has_next() const {
return size_ > 0 && (line_ < str_->pieces_.size());
}

char operator*() const {
TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size());
TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size());
return str_->pieces_[line_].at(pos_);
}

// returns rest of the line of the current iterator
c10::string_view rest_line() const {
if (line_ >= str_->pieces_.size()) {
return "";
}

c10::string_view cur_line = str_->pieces_[line_];
return cur_line.substr(pos_, std::string::npos);
}

size_t pos() const {
if (size_ == 0) {
return 0;
}
return str_->accumulated_sizes_[line_] + pos_;
}

private:
size_t line_;
size_t pos_;
const StringCordView* str_;
size_t size_;
friend struct StringCordView;
};

Iterator begin() const {
return Iterator(this, 0, 0, size());
}
Iterator end() const {
return Iterator(this, pieces_.size(), 0, 0);
}
Iterator iter_for_pos(size_t pos) const;

private:
std::vector<c10::string_view> pieces_;
std::vector<size_t> accumulated_sizes_;
std::vector<std::shared_ptr<std::string>> owned_strings_;
};

// Source represents a code segment. It keeps track of:
// - text_view : the view into text of the code segment
// - filename (optional) : if present, represents the name of the file from
// which the code segment originated.
// - starting_line_no : represents the line in the original file where the
// code segment started.
struct SourceView {
explicit SourceView(
struct TORCH_API Source {
// Whether or not Source should copy the string passed in the constructor.
enum CopiesString { COPIES_STRING, DONT_COPY };

explicit Source(
c10::string_view text_view,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_view_(text_view),
filename_(c10::nullopt),
starting_line_no_(0),
c10::optional<std::string> filename = c10::nullopt,
size_t starting_line_no = 0,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr,
CopiesString copies_str = COPIES_STRING)
: filename_(std::move(filename)),
starting_line_no_(starting_line_no),
gen_ranges_(std::move(gen_ranges)) {
if (copies_str == COPIES_STRING) {
std::shared_ptr<std::string> allocated_str =
std::make_shared<std::string>(text_view.data(), text_view.size());
text_view_ = StringCordView({*allocated_str}, {allocated_str});
} else {
text_view_ = StringCordView({text_view}, {});
}

calc_line_start_offsets();
}

SourceView(
c10::string_view text_view,
c10::optional<std::string> filename,
size_t starting_line_no,
explicit Source(
StringCordView str,
c10::optional<std::string> filename = c10::nullopt,
size_t starting_line_no = 0,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_view_(text_view),
: text_view_(str),
filename_(std::move(filename)),
starting_line_no_(starting_line_no),
gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}

// Given a line number (within source_), return the byte offset of the
// beginning of that line.
size_t offset_for_line(size_t line) const {
@@ -54,11 +228,9 @@ struct SourceView {

// Calculate the line (within the code segment) on which `offset` resides.
size_t lineno_for_offset(size_t offset) const {
return std::upper_bound(
line_starting_offsets_.begin(),
line_starting_offsets_.end(),
offset) -
line_starting_offsets_.begin() - 1;
auto iter = std::upper_bound(
line_starting_offsets_.begin(), line_starting_offsets_.end(), offset);
return iter - line_starting_offsets_.begin() - 1;
}

// Calculate the line (within the original source file, if present) on which
@@ -71,11 +243,26 @@ struct SourceView {
}
}

const c10::string_view text() const {
StringCordView get_line(size_t lineno) const {
auto start = offset_for_line(lineno);
auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start
: text_view_.size() - start;
return text_view_.substr(start, size);
}

const StringCordView& text_str() const {
return text_view_;
}

const c10::optional<std::string>& filename() const {
char char_at(size_t index) const {
return text_view_.at(index);
}

size_t size() const {
return text_view_.size();
}

c10::optional<std::string>& filename() {
return filename_;
}

@@ -86,18 +273,20 @@ struct SourceView {
c10::optional<SourceRange> findSourceRangeThatGenerated(
const SourceRange& range);

protected:
c10::string_view text_view_;
~Source() = default;

private:
void calc_line_start_offsets() {
line_starting_offsets_.clear();
line_starting_offsets_.push_back(0);
size_t pos = 0;
while ((pos = text().find('\n', pos)) != std::string::npos) {
while ((pos = text_view_.find("\n", pos)) != std::string::npos) {
line_starting_offsets_.push_back(++pos);
}
}

StringCordView text_view_;

c10::optional<std::string> filename_;
// If filename_ is not present, starting_line_no_ is don't care
size_t starting_line_no_;
@@ -108,67 +297,34 @@ struct SourceView {
std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
};

// Source represents a code segment like SourceView, but the former owns a copy
// of source text while the latter doesn't.
struct Source : public SourceView {
explicit Source(
std::string text,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text, gen_ranges), text_(std::move(text)) {
text_view_ = text_;
}

explicit Source(
c10::string_view text_view,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text_view, gen_ranges),
text_(text_view.begin(), text_view.end()) {
text_view_ = text_;
}

explicit Source(
std::string text,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text, filename, starting_line_no, gen_ranges),
text_(std::move(text)) {
text_view_ = text_;
}

explicit Source(
c10::string_view text_view,
c10::optional<std::string> filename,
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: SourceView(text_view, filename, starting_line_no, gen_ranges),
text_(text_view.begin(), text_view.end()) {
text_view_ = text_;
}

// Constructor that deepcopies and owns source text referenced in
// `source_view`.
explicit Source(const SourceView& source_view) : SourceView(source_view) {
text_ = std::string(text_view_.begin(), text_view_.end());
text_view_ = text_;
}

std::string text_;
};

// A SourceRange is a reference to subset of a Source, specified by `start` and
// `end` byte offsets into the source text.
struct TORCH_API SourceRange {
SourceRange(
std::shared_ptr<SourceView> source_view_,
size_t start_,
size_t end_)
: source_view_(std::move(source_view_)), start_(start_), end_(end_) {}
SourceRange(std::shared_ptr<Source> source_view_, size_t start_, size_t end_)
: source_view_(std::move(source_view_)), start_(start_), end_(end_) {
if (source_view_) {
start_iter_ = source_view_->text_str().iter_for_pos(start_);
}
}

SourceRange() : source_view_(nullptr), start_(0), end_(0) {}

const std::string text() const {
auto text_view = source_view_->text().substr(start(), end() - start());
return std::string(text_view.begin(), text_view.end());
SourceRange(
std::shared_ptr<Source> source_view_,
StringCordView::Iterator start_iter,
size_t end_)
: source_view_(std::move(source_view_)),
start_(start_iter.pos()),
end_(end_),
start_iter_(start_iter) {}

const c10::string_view token_text() const {
size_t size = end() - start();
return start_iter_.rest_line().substr(0, size);
}

const StringCordView text() const {
return source_view_->text_str().substr(start(), end() - start());
}
size_t size() const {
return end() - start();
@@ -183,7 +339,7 @@ struct TORCH_API SourceRange {
bool highlight,
const std::string& funcname) const;

const std::shared_ptr<SourceView>& source() const {
const std::shared_ptr<Source>& source() const {
return source_view_;
}
size_t start() const {
@@ -229,21 +385,25 @@ struct TORCH_API SourceRange {
}

protected:
std::shared_ptr<SourceView> source_view_;
std::shared_ptr<Source> source_view_;

private:
size_t start_;
size_t end_;
StringCordView::Iterator start_iter_;
};

// OwnedSourceRange is just like a SourceRange except that it owns a `Source`
// instead of `SourceView`. Thus OwnedSourceRange owns a copy of source text.
// instead of `Source`. Thus OwnedSourceRange owns a copy of source text.
struct OwnedSourceRange : public SourceRange {
OwnedSourceRange(const SourceRange& source_range)
explicit OwnedSourceRange(const SourceRange& source_range)
: SourceRange(source_range) {
const auto& source = source_range.source();
if (source) {
source_view_ = std::make_shared<Source>(*source);
source_view_ = std::make_shared<Source>(
source->text_str().str(),
source->filename(),
source->starting_line_no());
}
}
};
@@ -281,3 +441,14 @@ using SourceRangeTagMap =

} // namespace jit
} // namespace torch

namespace std {
template <>
struct iterator_traits<torch::jit::StringCordView::Iterator> {
using value_type = char;
using difference_type = ptrdiff_t;
using pointer = char*;
using reference = char&;
using iterator_category = std::forward_iterator_tag;
};
} // namespace std
@@ -21,26 +21,26 @@ namespace jit {
*/
class TORCH_API SourceRef : public CustomClassHolder {
public:
explicit SourceRef(std::shared_ptr<SourceView> source_view)
explicit SourceRef(std::shared_ptr<Source> source_view)
: source_view_(std::move(source_view)) {}
bool operator==(const SourceRef& other) const {
return source_view_ == other.source_view_;
}
bool operator<(const SourceView& other) const {
bool operator<(const Source& other) const {
return source_view_.get() < &other;
}
friend bool operator<(const SourceView& other, const SourceRef& self) {
friend bool operator<(const Source& other, const SourceRef& self) {
return &other < self.source_view_.get();
}
bool operator<(const SourceRef& other) const {
return *this < *other.source_view_.get();
}
const SourceView* operator->() const {
const Source* operator->() const {
return source_view_.get();
}

private:
std::shared_ptr<SourceView> source_view_;
std::shared_ptr<Source> source_view_;
};

} // namespace jit
@@ -122,17 +122,26 @@ MobileDebugTable::MobileDebugTable(
at::DataPtr debug_data;
size_t debug_size{0};
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
auto ivalues =
std::move(*jit::unpickle(
auto ivalueTuple = jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()),
debug_size,
nullptr,
{},
c10::parseType)
.toTuple())
.elements();
SourceRangeDeserializer deserializer;
for (auto& val : ivalues) {
c10::parseType);
const auto& ivalues = ivalueTuple.toTuple()->elements();
IValue lines;
std::unique_ptr<SourceRangeDeserializer> deserializer;
if (ivalues.size() == 3 && ivalues[0].isString() &&
kFormatWithStringTable == ivalues[0].toStringRef()) {
// new format
deserializer = std::make_unique<SourceRangeDeserializer>(ivalues[1]);
lines = ivalues[2];
} else {
deserializer = std::make_unique<SourceRangeDeserializer>();
lines = ivalueTuple;
}

for (auto& val : lines.toTuple()->elements()) {
auto tup_elems = std::move(*std::move(val).toTuple()).elements();
// For BC we decode only tuples with 3 elements
// assuming it contains
@@ -140,7 +149,7 @@ MobileDebugTable::MobileDebugTable(
if (tup_elems.size() == 3) {
int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
auto source_range =
deserializer.deserialize(tup_elems[kSourceRangeIndex]);
deserializer->deserialize(tup_elems[kSourceRangeIndex]);
source_range_map.emplace(debug_handle, std::move(source_range));
}
}
@@ -3,6 +3,7 @@
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>

namespace torch {
namespace jit {

@@ -7,7 +7,7 @@ bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
const auto source_n = n->sourceRange().source();
const auto source_m = m->sourceRange().source();
return (
(source_n->text() == source_m->text()) &&
(source_n->text_str() == source_m->text_str()) &&
(source_n->starting_line_no() == source_m->starting_line_no()));
}
@@ -104,9 +104,8 @@ void initTreeViewBindings(PyObject* module) {
return SourceRange(self.source_, start, end);
})
.def_property_readonly("source", [](const SourceRangeFactory& self) {
auto text_view = self.source_->text();
std::string text(text_view.begin(), text_view.end());
return text;
auto text_view = self.source_->text_str().str();
return text_view;
});

py::class_<TreeView>(m, "TreeView")

@@ -2214,6 +2214,10 @@ void initJitScriptBindings(PyObject* module) {
m.def(
"_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });

m.def(
"_set_should_use_format_with_string_table",
setShouldUseFormatWithStringTable);

// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
m, "LoggerBase");

@@ -61,7 +61,7 @@ auto initBindings() {
return static_cast<int64_t>((*self)->starting_line_no());
})
.def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
return (*self)->text();
return (*self)->text_str().str();
});

torch::class_<InstructionStats>("profiling", "InstructionStats")
@@ -529,13 +529,13 @@ Module jitModuleFromSourceAndConstants(
SourceImporter importer(
compilation_unit,
&constants,
[&source](const std::string& qualifier) -> std::shared_ptr<SourceView> {
[&source](const std::string& qualifier) -> std::shared_ptr<Source> {
auto source_iter = source.find(qualifier);
if (source_iter == source.end()) {
return nullptr;
}
return std::make_shared<Source>(
source_iter->second, qualifier, 1, nullptr);
source_iter->second, qualifier, 1, nullptr, Source::COPIES_STRING);
},
version);
auto type_resolver = [&](const c10::QualifiedName& qn) {

@@ -22,7 +22,7 @@ std::string qualifierToArchivePath(
return export_prefix + path + "." + kExportSuffix;
}

std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
caffe2::serialize::PyTorchStreamReader& reader,
const std::string& export_prefix,
const std::string& qualifier) {

@@ -12,7 +12,7 @@ class PyTorchStreamReader;
namespace torch {
namespace jit {

struct SourceView;
struct Source;

// Convert a class type's qualifier name to the corresponding path the source
// file it should be written to.

@@ -23,7 +23,7 @@ std::string qualifierToArchivePath(
const std::string& qualifier,
const std::string& export_prefix);

std::shared_ptr<SourceView> findSourceInArchiveFromQualifier(
std::shared_ptr<Source> findSourceInArchiveFromQualifier(
caffe2::serialize::PyTorchStreamReader& reader,
const std::string& export_prefix,
const std::string& qualifier);

@@ -159,7 +159,7 @@ void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) {
return;
}
loaded_sources_.insert(qualifier);
std::shared_ptr<SourceView> src = source_loader_(qualifier);
std::shared_ptr<Source> src = source_loader_(qualifier);

// The importer, when looking for classes/functions doesn't know if 'foo'
// contains definitions or if it is a prefix of 'foo.bar', we only figure it

@@ -20,8 +20,7 @@
namespace torch {
namespace jit {

using SourceLoader =
std::function<std::shared_ptr<SourceView>(const std::string&)>;
using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>;

struct SourceImporterImpl : public Resolver,
std::enable_shared_from_this<SourceImporterImpl> {
@@ -1,54 +1,97 @@
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>

#include <c10/util/Exception.h>
#include <c10/util/Flags.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <algorithm>

namespace torch {
namespace jit {

// "Whether to emit compact debug_pkl when saving a model to .pt file."
// "Compact file is smaller but cannot be loaded by old torch binaries."
// TODO(qihan) remove when all binaries are using string table.
thread_local bool should_use_format_with_string_table_ = false;

class SourceRangeSerializer {
public:
// Serialize SourceRange as Tuple[SourceType, int, int]
// where SourceType = Tuple[str, Optional[str], int, List[int]],
// where SourceType = Tuple[int, int, int, List[int]],
// The first 2 ints are positions into the vector returned by textSaved
// after all the Ranges are processed. textSaved() returns a vector of str
// the serialized form of Source
c10::IValue serialize(const SourceRange& sr);

const std::vector<c10::IValue>& texts_saved() {
return texts_;
}

SourceRangeSerializer() {
texts_.emplace_back("");
text_to_idx_[texts_.back().toStringRef()] = 0;
}

private:
// Serialize Source as Tuple[str, Optional[str], int, List[int]]
// This caches serialized sources, since many SourceRanges can
// refer to the same one.
c10::IValue serialize_source(const std::shared_ptr<SourceView>& s);
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;

std::unordered_map<std::shared_ptr<SourceView>, c10::IValue>
serialized_sources;
int64_t store_text_and_get_index(const std::string& text_view);

std::vector<c10::IValue> texts_;
std::unordered_map<c10::string_view, int64_t> text_to_idx_;
};

SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
const auto& tup_elems = iv.toTupleRef().elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::shared_ptr<SourceView> source_ = deserialize_source(tup_elems[0]);
std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
int64_t start_ = tup_elems[1].toInt();
int64_t end_ = tup_elems[2].toInt();
return SourceRange(source_, start_, end_);
}

std::shared_ptr<SourceView> SourceRangeDeserializer::deserialize_source(
std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
const c10::IValue& iv) {
auto tup = iv.toTuple();
auto it = cached_sources.find(tup);
if (it != cached_sources.end()) {
return it->second;
}

std::shared_ptr<Source> source;
const auto& tup_elems = tup->elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ = tup_elems[1].toOptional<std::string>();
if (!text_table_.empty()) {
const auto& textIndex = tup_elems[0].toIntList();
int64_t fnameIndex = tup_elems[1].toInt();
int64_t starting_line_no_ = tup_elems[2].toInt();
c10::optional<std::string> filename = c10::nullopt;

auto source = std::make_shared<Source>(
filename = *text_table_[fnameIndex];

std::vector<c10::string_view> pieces;
std::vector<std::shared_ptr<std::string>> strs;

for (int64_t i : textIndex) {
pieces.emplace_back(*text_table_[i]);
strs.emplace_back(text_table_[i]);
}

StringCordView str_cord(std::move(pieces), std::move(strs));

source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
} else {
std::string text_ = tup_elems[0].toString()->string();
c10::optional<std::string> filename_ =
tup_elems[1].toOptional<std::string>();
int64_t starting_line_no_ = tup_elems[2].toInt();
source = std::make_shared<Source>(
std::move(text_), std::move(filename_), starting_line_no_);
}
cached_sources[tup] = source;
return source;
}
@@ -58,17 +101,50 @@ c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end());
}

int64_t SourceRangeSerializer::store_text_and_get_index(
const std::string& text_view) {
auto text_iter = text_to_idx_.find(text_view);
if (text_iter == text_to_idx_.end()) {
int64_t text_pos = static_cast<int64_t>(texts_.size());
texts_.emplace_back(text_view);
text_to_idx_[texts_.back().toStringView()] = text_pos;
return text_pos;
} else {
return text_iter->second;
}
}

c10::IValue SourceRangeSerializer::serialize_source(
const std::shared_ptr<SourceView>& s) {
const std::shared_ptr<Source>& s) {
if (serialized_sources.count(s)) {
return serialized_sources.at(s);
}
c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
c10::List<int64_t> lines;
if (should_use_format_with_string_table_) {
if (s == nullptr) {
serialized = c10::ivalue::Tuple::create({lines, 0, 0});
} else {
for (size_t lineno = 0; lineno < s->num_lines(); lineno++) {
std::string line_content = s->get_line(lineno).str();
int64_t text_pos = store_text_and_get_index(line_content);
lines.push_back(text_pos);
}

int64_t fname_pos = 0;
if (s->filename().has_value()) {
fname_pos = store_text_and_get_index(*s->filename());
}
serialized = c10::ivalue::Tuple::create(
{lines, fname_pos, (int64_t)s->starting_line_no()});
}
} else {
if (s == nullptr) {
serialized = c10::ivalue::Tuple::create({"", "", 0});
} else {
serialized = c10::ivalue::Tuple::create(
{s->text(), s->filename(), (int64_t)s->starting_line_no()});
{s->text_str().str(), s->filename(), (int64_t)s->starting_line_no()});
}
}
serialized_sources[s] = serialized;
return serialized;
@@ -86,14 +162,24 @@ std::vector<char> SourceRangePickler::pickle(
if (it != source_range_tags.end()) {
source_range_tag = it->second;
}

ivalues.emplace_back(c10::ivalue::Tuple::create(
{(int64_t)range.bytes,
srs->serialize(range.range),
static_cast<int64_t>(source_range_tag)}));
}

std::vector<at::Tensor> table;
auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
auto result = jit::pickle(ivalue, &table);
std::vector<char> result;
if (should_use_format_with_string_table_) {
result = jit::pickle(
c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
&table);
} else {
result = jit::pickle(ivalue, &table);
}
TORCH_CHECK(table.size() == 0, "Expected 0 tensors to be written");
return result;
}
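Putting the serializer and pickler changes together, the new-format debug_pkl record nests roughly as follows. This is a reading aid derived from the code above, not a formal spec; the old format simply pickles the entry tuple directly, with each Source serialized as (text, filename, starting_line_no).

```
// ( "FORMAT_WITH_STRING_TABLE",                      // kFormatWithStringTable tag
//   ( text_0, text_1, ... ),                         // textTable: unique strings
//   ( ( byte_offset,                                 // one entry per SourceRange
//       ( ( ( line_idx, ... ),                       //   Source: line indices into textTable
//           fname_idx,                               //   filename index (0 = none)
//           starting_line_no ),
//         start, end ),                              //   range offsets within that Source
//       source_range_tag ),
//     ... ) )
```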
@@ -103,7 +189,7 @@ ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
size_t size)
: data(std::move(data)),
size(size),
deserializer(new SourceRangeDeserializer()),
deserializer(nullptr),
unpickled_records(nullptr) {}

void ConcreteSourceRangeUnpickler::unpickle() {
@@ -119,10 +205,19 @@ void ConcreteSourceRangeUnpickler::unpickle() {
{},
c10::parseType)
.toTuple();
const auto& ivalues = ivaluesTuple->elements();

const auto& ivalues = ivaluesTuple->elements();
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
IValue lines;
if (ivalues[0].isString() &&
kFormatWithStringTable == ivalues[0].toStringRef()) {
deserializer.reset(new SourceRangeDeserializer(ivalues[1]));
lines = ivalues[2];
} else {
deserializer.reset(new SourceRangeDeserializer());
lines = ivaluesTuple;
}
for (auto& val : lines.toTuple()->elements()) {
const auto& tup_elems = val.toTupleRef().elements();
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
@@ -152,5 +247,10 @@ c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
return c10::nullopt;
}

TORCH_API void setShouldUseFormatWithStringTable(
bool should_use_format_with_string_table) {
should_use_format_with_string_table_ = should_use_format_with_string_table;
}

} // namespace jit
} // namespace torch
@ -20,6 +20,7 @@ class SourceRangeSerializer;
|
||||
static constexpr size_t kByteOffsetIndex = 0;
|
||||
static constexpr size_t kSourceRangeIndex = 1;
|
||||
static constexpr size_t kSourceRangeTagIndex = 2;
|
||||
constexpr c10::string_view kFormatWithStringTable = "FORMAT_WITH_STRING_TABLE";
|
||||
|
||||
class SourceRangePickler {
|
||||
public:
|
||||
@ -35,14 +36,21 @@ class SourceRangePickler {
|
||||
|
||||
class SourceRangeDeserializer {
|
||||
public:
|
||||
SourceRangeDeserializer() = default;
|
||||
explicit SourceRangeDeserializer(c10::IValue text_table) {
|
||||
for (const auto& x : text_table.toTuple()->elements()) {
|
||||
text_table_.emplace_back(std::make_shared<std::string>(x.toStringRef()));
|
||||
}
|
||||
}
|
||||
SourceRange deserialize(const c10::IValue& iv);
|
||||
|
||||
private:
|
||||
std::shared_ptr<SourceView> deserialize_source(const c10::IValue& iv);
|
||||
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv);
|
||||
std::unordered_map<
|
||||
c10::intrusive_ptr<c10::ivalue::Tuple>,
|
||||
std::shared_ptr<SourceView>>
|
||||
std::shared_ptr<Source>>
|
||||
cached_sources;
|
||||
std::vector<std::shared_ptr<std::string>> text_table_;
|
||||
};
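The cached_sources map above memoizes deserialize_source, so identical serialized tuples resolve to one shared Source. A loose Python analogue of that caching idea (every name here is illustrative, not the real API):

```
# Cache keyed by the (text, filename, starting_line_no) tuple that the
# serializer emits for each Source.
_cached_sources = {}

def deserialize_source(serialized):
    """Return one shared object per distinct serialized tuple."""
    if serialized not in _cached_sources:
        text, filename, starting_line_no = serialized
        _cached_sources[serialized] = {
            "text": text,
            "filename": filename,
            "starting_line_no": starting_line_no,
        }
    return _cached_sources[serialized]

a = deserialize_source(("x = 1\n", "m.py", 0))
b = deserialize_source(("x = 1\n", "m.py", 0))
assert a is b  # duplicate traces share a single object, as in the C++ cache
```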

class SourceRangeUnpickler {
@ -53,5 +61,8 @@ class SourceRangeUnpickler {
virtual ~SourceRangeUnpickler() = default;
};

TORCH_API void setShouldUseFormatWithStringTable(
bool should_use_format_with_string_table);

} // namespace jit
} // namespace torch

@ -94,7 +94,7 @@ size_t assertFind(
const SourceRange& search_range,
const std::string& sub,
const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
auto pos = search_range.source()->text().find(sub, search_range.start());
auto pos = search_range.source()->text_str().find(sub, search_range.start());
if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
auto found_range =
SourceRange(search_range.source(), search_range.start(), sub.size());
@ -122,19 +122,18 @@ size_t assertFind(
}

size_t assertFind(
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const std::string& sub,
size_t start,
const Check& check) {
return assertFind(
SourceRange(source, start, source->text().size()), sub, check);
return assertFind(SourceRange(source, start, source->size()), sub, check);
}

void assertNotFind(
const SourceRange& search_range,
const std::string& sub,
const Check& check) {
auto pos = search_range.source()->text().find(sub, search_range.start());
auto pos = search_range.source()->text_str().find(sub, search_range.start());
if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
auto found_range =
SourceRange(search_range.source(), pos, sub.size() + pos);
@ -202,9 +201,7 @@ struct FileCheckImpl {
friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);

private:
bool parseSingleCheck(
const std::shared_ptr<SourceView>& source,
size_t* start) {
bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
{CHECK, ": "},
{CHECK_NEXT, "-NEXT: "},
@ -217,31 +214,35 @@ struct FileCheckImpl {

for (const auto& check_pair : check_pairs) {
const std::string& check_suffix = check_pair.second;
auto suffix_pos = source->text().find(check_suffix, *start);
auto suffix_pos = source->text_str().find(check_suffix, *start);
if (suffix_pos != *start) {
continue;
}
size_t end_check_string = suffix_pos + check_suffix.size();
CheckType type = check_pair.first;
c10::optional<size_t> count = c10::nullopt;
auto end_line = source->text().find('\n', end_check_string);
auto end_line = source->text_str().find("\n", end_check_string);
bool exactly = false;
if (type == CHECK_COUNT) {
const std::string exact = "EXACTLY-";
if (source->text().find(exact, end_check_string) == end_check_string) {
if (source->text_str().find(exact, end_check_string) ==
end_check_string) {
exactly = true;
end_check_string += exact.size();
}
size_t end =
assertFind(SourceRange(source, end_check_string, end_line), ":");
auto count_view =
source->text().substr(end_check_string, end - end_check_string);
auto count_view = source->text_str()
.substr(end_check_string, end - end_check_string)
.str();
count = c10::stoll(std::string(count_view.begin(), count_view.end()));
end_check_string = end + 2; // add ':' and the space
}
auto check = Check(
type,
source->text().substr(end_check_string, end_line - end_check_string),
source->text_str()
.substr(end_check_string, end_line - end_check_string)
.str(),
count);
addCheck(check);
if (exactly) {
@ -253,32 +254,30 @@ struct FileCheckImpl {
return false;
}

size_t findNextStart(
const std::shared_ptr<SourceView>& source,
size_t prev_end) {
size_t start = source->text().find('#', prev_end);
size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
size_t start = source->text_str().find("#", prev_end);
if (start == std::string::npos) {
return start;
}
start += 1;
static constexpr size_t max_whitespace = 6;
size_t i = 0;
while (start + i < source->text().size() && i < max_whitespace) {
auto c = source->text().at(start + i);
while (start + i < source->size() && i < max_whitespace) {
auto c = source->char_at(start + i);
if (c != ' ' && c != '\t') {
break;
}
i++;
}
static const std::string check = "CHECK";
if (source->text().substr(start + i, check.size()) == check) {
if (source->text_str().substr(start + i, check.size()) == check) {
return start + i + check.size();
} else {
return findNextStart(source, start + i + 1);
}
}

void parseStrings(const std::shared_ptr<SourceView>& source) {
void parseStrings(const std::shared_ptr<Source>& source) {
size_t start = 0;
start = findNextStart(source, 0);
while (start != std::string::npos) {
@ -297,7 +296,7 @@ struct FileCheckImpl {

void doCheckNot(
const std::vector<Check>& nots,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const SourceRange& prev,
const SourceRange& next) {
auto start = prev.end(); // inclusive
@ -314,7 +313,7 @@ struct FileCheckImpl {
// Checks that source token is highlighted, does not advance search range.
void doCheckSourceHighlighted(
const Check& check,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
size_t start_offset) {
auto construct_error_and_throw = [&](size_t error_start_pos) {
SourceRange error_range(
@ -330,8 +329,8 @@ struct FileCheckImpl {
size_t search_start_offset = start_offset;
bool found_token_at_least_once = false;
size_t pos = search_start_offset;
while (pos < source->text().size()) {
pos = source->text().find(check.search_str_, search_start_offset);
while (pos < source->size()) {
pos = source->text_str().find(check.search_str_, search_start_offset);
if (pos == std::string::npos) {
break;
}
@ -349,17 +348,16 @@ struct FileCheckImpl {
auto highlight_start_offset =
source->offset_for_line(highlight_lineno) + col;
auto highlight_end_offset = std::min(
highlight_start_offset + check.search_str_.size(),
source->text().size());
highlight_start_offset + check.search_str_.size(), source->size());

if (highlight_end_offset >= source->text().size()) {
if (highlight_end_offset >= source->size()) {
construct_error_and_throw(pos);
}

bool found_highlight = true;
for (const auto posi :
c10::irange(highlight_start_offset, highlight_end_offset)) {
if (source->text()[posi] != '~') {
if (source->char_at(posi) != '~') {
found_highlight = false;
}
}
@ -390,7 +388,7 @@ struct FileCheckImpl {

SourceRange matchDagGroup(
const std::vector<Check>& group,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const SourceRange& prev) {
size_t group_beg = std::string::npos;
size_t group_end = 0;
@ -408,7 +406,7 @@ struct FileCheckImpl {

SourceRange matchGroup(
const std::vector<Check>& group,
const std::shared_ptr<SourceView>& source,
const std::shared_ptr<Source>& source,
const SourceRange& prev) {
AT_ASSERT(group.size() != 0);
CheckType type = group[0].type_;
@ -467,7 +465,7 @@ struct FileCheckImpl {
return SourceRange(source, start_range, end_range);
}

void doChecks(const std::shared_ptr<SourceView>& source) {
void doChecks(const std::shared_ptr<Source>& source) {
SourceRange prev(source, 0, 0);
for (size_t i = 0; i < groups.size(); i++) {
const auto& curr_group = groups[i];
@ -484,7 +482,7 @@ struct FileCheckImpl {
++i; // already checked the group after
} else {
SourceRange end_of_file(
source, source->text().size() + 1, source->text().size() + 1);
source, source->size() + 1, source->size() + 1);
doCheckNot(curr_group, source, prev, end_of_file);
}
}

@ -258,7 +258,22 @@ def get_model_info(
# Parse debug info and add begin/end markers if not present
# to ensure that we cover the entire source code.
debug_info_t = pickle.loads(raw_debug)
assert isinstance(debug_info_t, tuple)
text_table = None

if (len(debug_info_t) == 3 and
isinstance(debug_info_t[0], str) and
debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
_, text_table, content = debug_info_t

def parse_new_format(line):
# (0, (('', '', 0), 0, 0))
num, ((text_indexes, fname_idx, offset), start, end), tag = line
text = ''.join(text_table[x] for x in text_indexes) # type: ignore[index]
fname = text_table[fname_idx] # type: ignore[index]
return num, ((text, fname, offset), start, end), tag

debug_info_t = map(parse_new_format, content)

debug_info = list(debug_info_t)
if not debug_info:
debug_info.append((0, (('', '', 0), 0, 0)))
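parse_new_format above rebuilds each trace by joining entries of text_table, so every record carries indexes into the table rather than the strings themselves. Going the other way, a table of unique entries plus per-record index tuples can be sketched as follows (purely illustrative helper names, not the serializer's actual code):

```
def build_text_table(traces):
    """Give each distinct string one slot and return (table, index_tuples)."""
    table, slots, indexed = [], {}, []
    for parts in traces:            # each trace is a list of strings
        idxs = []
        for s in parts:
            if s not in slots:
                slots[s] = len(table)
                table.append(s)
            idxs.append(slots[s])
        indexed.append(tuple(idxs))
    return table, indexed

table, indexed = build_text_table([["trace-a", "file.py"], ["trace-a", "other.py"]])
assert table == ["trace-a", "file.py", "other.py"]
assert indexed == [(0, 1), (0, 2)]   # repeated strings reuse the same slot
```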