mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Load original SourceRanges on import (#22180)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22180 ghimport-source-id: efa46dcb845c099f0a746f523901ab2c2cd3b004 Test Plan: Imported from OSS Differential Revision: D15981425 Pulled By: jamesr66a fbshipit-source-id: bef682bd13c1a5be95bdb97e025690c6f2d523d3
This commit is contained in:
committed by
Facebook Github Bot
parent
2c2a913a4f
commit
ffa15d2285
@ -440,6 +440,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
|
||||
|
||||
@ -24,6 +24,7 @@ import inspect
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import textwrap
|
||||
|
||||
@ -120,6 +121,8 @@ class JitTestCase(TestCase):
|
||||
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
|
||||
main_module = archive.open('archive/code/archive.py')
|
||||
main_module_code = "".join([line.decode() for line in main_module])
|
||||
main_module_debug_file = archive.open('archive/debug/archive.pkl')
|
||||
main_module_debug = pickle.load(main_module_debug_file)
|
||||
except RuntimeError as e:
|
||||
if not self._isHookExceptionOk(e):
|
||||
raise
|
||||
@ -138,8 +141,11 @@ class JitTestCase(TestCase):
|
||||
archive2 = zipfile.ZipFile(saved_module_buffer_2)
|
||||
main_module_2 = archive2.open('archive/code/archive.py')
|
||||
main_module_2_code = "".join([line.decode() for line in main_module_2])
|
||||
# Read the debug records from the *second* archive (archive2). Reading them
# from `archive` again would compare the first export with itself and make
# the equality check on the debug records below vacuous.
main_module_2_debug_file = archive2.open('archive/debug/archive.pkl')
main_module_2_debug = pickle.load(main_module_2_debug_file)
|
||||
|
||||
self.assertMultiLineEqual(main_module_code, main_module_2_code)
|
||||
self.assertEqual(main_module_debug, main_module_2_debug)
|
||||
|
||||
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
|
||||
if isinstance(m, torch._C.Function):
|
||||
|
||||
@ -3323,6 +3323,82 @@ def foo(xyz):
|
||||
fc.run(scripted.graph)
|
||||
fc.run(str(scripted.graph))
|
||||
|
||||
def test_serialized_source_ranges(self):
    # An error raised by a re-imported module should still mention the
    # original Python location. The failing `torch.mm` call sits three
    # lines below the `class FooTest` line, hence the `+ 3` further down.
    # NOTE: do not add lines inside FooTest; the offset depends on its layout.
    class FooTest(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, w):
            return torch.mm(x, w.t())

    class_first_line = inspect.getsourcelines(FooTest)[1]
    expected_location = 'test_jit.py:{}'.format(class_first_line + 3)

    reloaded = self.getExportImportCopy(FooTest())

    # Shapes (3, 4) and (30, 40) are incompatible for mm, so the call fails.
    with self.assertRaisesRegex(RuntimeError, expected_location):
        reloaded(torch.rand(3, 4), torch.rand(30, 40))
|
||||
|
||||
def test_serialized_source_ranges2(self):
    # Same idea as the test above, but the error comes from an explicit
    # `raise` inside the scripted method. The `raise` line is three lines
    # below `class FooTest2`, hence the `+ 3` further down.
    # NOTE: do not add lines inside FooTest2; the offset depends on its layout.
    class FooTest2(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self):
            raise RuntimeError('foo')

    class_first_line = inspect.getsourcelines(FooTest2)[1]
    expected_location = 'test_jit.py:{}'.format(class_first_line + 3)

    # Construction, export/import and the call all stay inside the context
    # manager, exactly as before.
    with self.assertRaisesRegex(torch._C.JITException, expected_location):
        module = FooTest2()
        reloaded = self.getExportImportCopy(module)
        reloaded()
|
||||
|
||||
def test_serialized_source_ranges_dont_jitter(self):
    # Saving a module, loading it, and saving it again should yield the same
    # pickled debug (source range) records on every round trip.
    class FooTest3(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, lim):
            first = 1
            second = 1
            i = 1
            somenum = 5
            dontmutateme = 3
            third = 0
            while bool(i < lim):
                third = first + second
                first = second
                second = third
                j = 0
                while j < 10:
                    somenum = somenum * 2
                    j = j + 1
                i = i + j
                i = i + dontmutateme

            st = second + third
            fs = first + second
            return third, st, fs

    ft3 = FooTest3()

    def debug_records_from_mod(mod):
        # Serialize `mod` into an in-memory archive and return the unpickled
        # debug records together with the buffer holding the archive.
        buffer = io.BytesIO()
        # Save the module that was passed in. Previously this always saved
        # `ft3`, so the reloaded modules were never actually re-serialized.
        torch.jit.save(mod, buffer)
        buffer.seek(0)
        archive = zipfile.ZipFile(buffer)
        debug_file = archive.open('archive/debug/archive.pkl')
        return pickle.load(debug_file), buffer

    records1, buffer = debug_records_from_mod(ft3)

    buffer.seek(0)
    loaded = torch.jit.load(buffer)
    records2, buffer = debug_records_from_mod(loaded)

    buffer.seek(0)
    loaded2 = torch.jit.load(buffer)
    records3, _ = debug_records_from_mod(loaded2)

    self.assertEqual(records1, records2)
    self.assertEqual(records2, records3)
|
||||
|
||||
def test_tensor_shape(self):
|
||||
x = torch.empty(34, 56, 78)
|
||||
|
||||
|
||||
@ -123,6 +123,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/script/class_type.cpp",
|
||||
"torch/csrc/jit/script/parser.cpp",
|
||||
"torch/csrc/jit/script/jit_exception.cpp",
|
||||
"torch/csrc/jit/source_range_serialization.cpp",
|
||||
"torch/csrc/jit/testing/file_check.cpp",
|
||||
"torch/csrc/jit/import_source.cpp",
|
||||
"torch/csrc/jit/hooks_for_testing.cpp",
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
#include <torch/csrc/jit/passes/python_print.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/source_range_serializer.h>
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
|
||||
#include <caffe2/core/types.h>
|
||||
#include <caffe2/proto/caffe2_pb.h>
|
||||
@ -931,21 +931,14 @@ void ScriptModuleSerializer::convertModule(
|
||||
// Write out debug records
|
||||
torch::RecordRef* debug_record =
|
||||
module_def->mutable_torchscript_debug_arena();
|
||||
Pickler p;
|
||||
SourceRangeSerializer srs;
|
||||
p.start();
|
||||
p.startTuple();
|
||||
for (const auto& range : source_ranges) {
|
||||
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
|
||||
srs.serialize(range.range)};
|
||||
p.addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
|
||||
}
|
||||
p.endTuple();
|
||||
p.finish();
|
||||
|
||||
SourceRangePickler source_range_pickler;
|
||||
source_range_pickler.pickle(source_ranges);
|
||||
const auto& range_data = source_range_pickler.get_data();
|
||||
std::stringstream debug_filename;
|
||||
debug_filename << "debug/" << module_name.str() << ".pkl";
|
||||
writer_.writeRecord(
|
||||
debug_filename.str(), p.stack().data(), p.stack().size());
|
||||
debug_filename.str(), range_data.data(), range_data.size());
|
||||
debug_record->set_key(debug_filename.str());
|
||||
}
|
||||
|
||||
|
||||
@ -9,6 +9,8 @@
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/script/script_type_parser.h>
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
#include <torch/csrc/jit/source_range_serialization_impl.h>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/types.h"
|
||||
@ -328,6 +330,20 @@ void ScriptModuleDeserializer::convertModule(
|
||||
module.register_attribute(
|
||||
attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
|
||||
}
|
||||
|
||||
// If present, load in the table of source ranges from the original
|
||||
// generating code.
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
|
||||
if (module_def.has_torchscript_debug_arena()) {
|
||||
at::DataPtr data;
|
||||
size_t size;
|
||||
std::tie(data, size) =
|
||||
reader_.getRecord(module_def.torchscript_debug_arena().key());
|
||||
|
||||
gen_ranges =
|
||||
std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
|
||||
}
|
||||
|
||||
if (module_def.has_torchscript_arena()) {
|
||||
at::DataPtr data;
|
||||
size_t size;
|
||||
@ -337,7 +353,8 @@ void ScriptModuleDeserializer::convertModule(
|
||||
auto src = std::make_shared<Source>(
|
||||
std::string(static_cast<const char*>(data.get()), size),
|
||||
module_def.torchscript_arena().key(),
|
||||
1);
|
||||
1,
|
||||
std::move(gen_ranges));
|
||||
|
||||
std::function<void(const std::string&)> import_callback =
|
||||
[this](const std::string& qualifier) { importCallback(qualifier); };
|
||||
|
||||
@ -171,10 +171,13 @@ struct PythonPrintPass {
|
||||
SourceRangeStack source_range_stack_ = {SourceRange("")};
|
||||
|
||||
struct WithSourceRange {
|
||||
explicit WithSourceRange(SourceRangeStack* stack, SourceRange sr)
|
||||
: stack(stack) {
|
||||
explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
|
||||
TORCH_INTERNAL_ASSERT(stack);
|
||||
stack->push_back(std::move(sr));
|
||||
if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
|
||||
stack->push_back(std::move(gen_source.value()));
|
||||
} else {
|
||||
stack->push_back(std::move(n->sourceRange()));
|
||||
}
|
||||
}
|
||||
|
||||
~WithSourceRange() {
|
||||
@ -190,6 +193,13 @@ struct PythonPrintPass {
|
||||
TaggedStringStream(TaggedStringStream&& rhs) = default;
|
||||
|
||||
TaggedStringStream& operator<<(const std::string& s) {
|
||||
// This prevents having redundant entries at the same offset,
|
||||
// which can happen for example in printValueList when begin
|
||||
// and end are the empty string.
|
||||
if (s.size() == 0) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
if (!ranges_.size() || ranges_.back().range != srs_->back()) {
|
||||
ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
|
||||
}
|
||||
@ -233,6 +243,21 @@ struct PythonPrintPass {
|
||||
return ranges_;
|
||||
}
|
||||
|
||||
// Write out this TaggedStringStream's text and source ranges to
|
||||
// os and source_ranges_out, respectively. stream_pos gives
|
||||
// the byte offset into the current stream, so we can accurately
|
||||
// record source ranges as byte offsets.
|
||||
void print(
|
||||
std::ostream& os,
|
||||
SourceRangeRecords* source_ranges_out,
|
||||
int64_t stream_pos) {
|
||||
os << str();
|
||||
for (const auto& x : ranges()) {
|
||||
source_ranges_out->push_back(x);
|
||||
source_ranges_out->back().bytes += stream_pos;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostringstream oss_;
|
||||
std::vector<TaggedRange> ranges_;
|
||||
@ -756,7 +781,7 @@ struct PythonPrintPass {
|
||||
}
|
||||
|
||||
void printNode(Node* node, bool print_const) {
|
||||
WithSourceRange guard(&source_range_stack_, node->sourceRange());
|
||||
WithSourceRange guard(&source_range_stack_, node);
|
||||
// Check for class dependencies. If this node inputs or outputs a class
|
||||
// type, we need to add it to our table of dependencies.
|
||||
for (const auto input : node->inputs()) {
|
||||
@ -1149,8 +1174,7 @@ struct PythonPrintPass {
|
||||
Graph& graph = *func.graph();
|
||||
used_names_.clear(); // each graph can reuse local names
|
||||
|
||||
WithSourceRange guard(
|
||||
&source_range_stack_, graph.param_node()->sourceRange());
|
||||
WithSourceRange guard(&source_range_stack_, graph.param_node());
|
||||
|
||||
indent();
|
||||
body_ << "def " << func.name() << "(";
|
||||
@ -1250,8 +1274,9 @@ struct PythonPrintPass {
|
||||
}
|
||||
|
||||
void print(std::ostream& out, SourceRangeRecords& source_ranges_out) {
|
||||
out << getImports() << body_.str();
|
||||
source_ranges_out = body_.ranges();
|
||||
out << getImports();
|
||||
int64_t source_offset = out.tellp();
|
||||
body_.print(out, &source_ranges_out, source_offset);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -12,16 +12,6 @@ struct Method;
|
||||
struct Module;
|
||||
} // namespace script
|
||||
|
||||
// A pair of (byte offset, SourceRange) describing a specific segment
|
||||
// of the output stream
|
||||
struct TaggedRange {
|
||||
TaggedRange(size_t bytes, SourceRange range)
|
||||
: bytes(bytes), range(std::move(range)) {}
|
||||
size_t bytes;
|
||||
SourceRange range;
|
||||
};
|
||||
using SourceRangeRecords = std::vector<TaggedRange>;
|
||||
|
||||
TORCH_API void PythonPrint(
|
||||
std::ostream& out,
|
||||
SourceRangeRecords& source_ranges_out,
|
||||
|
||||
@ -1,8 +1,17 @@
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Look up the original range that produced `range`, delegating to the
// attached unpickler. Yields nullopt when this Source carries no
// generated-range table.
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
    const SourceRange& range) {
  if (gen_ranges_) {
    return gen_ranges_->findSourceRangeThatGenerated(range);
  }
  return c10::nullopt;
}
|
||||
|
||||
// a range of a shared string 'file_' with
|
||||
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
|
||||
const std::string& str = source_->text();
|
||||
@ -56,6 +65,13 @@ C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
|
||||
out << str.substr(end_line, end_highlight - end_line);
|
||||
if (!str.empty() && str.back() != '\n')
|
||||
out << "\n";
|
||||
// Retrieve original SourceRange, if present.
|
||||
if (source_) {
|
||||
if (auto orig_source_range = findSourceRangeThatGenerated()) {
|
||||
out << "Compiled from code ";
|
||||
orig_source_range->highlight(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
||||
@ -8,6 +8,9 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
struct SourceRangeUnpickler;
|
||||
struct SourceRange;
|
||||
|
||||
// Source represents a code segment. It keeps track of:
|
||||
// - text : the text of the code segment
|
||||
// - filename (optional) : if present, represents the name of the file from
|
||||
@ -15,18 +18,25 @@ namespace jit {
|
||||
// - starting_line_no : represents the line in the original file where the
|
||||
// code segment started.
|
||||
struct Source {
|
||||
explicit Source(std::string text)
|
||||
: text_(std::move(text)), filename_(c10::nullopt) {
|
||||
explicit Source(
|
||||
std::string text,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: text_(std::move(text)),
|
||||
filename_(c10::nullopt),
|
||||
starting_line_no_(0),
|
||||
gen_ranges_(std::move(gen_ranges)) {
|
||||
calc_line_start_offsets();
|
||||
}
|
||||
|
||||
Source(
|
||||
std::string text,
|
||||
c10::optional<std::string> filename,
|
||||
size_t starting_line_no)
|
||||
size_t starting_line_no,
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
|
||||
: text_(std::move(text)),
|
||||
filename_(std::move(filename)),
|
||||
starting_line_no_(starting_line_no) {
|
||||
starting_line_no_(starting_line_no),
|
||||
gen_ranges_(std::move(gen_ranges)) {
|
||||
calc_line_start_offsets();
|
||||
}
|
||||
|
||||
@ -67,6 +77,9 @@ struct Source {
|
||||
return starting_line_no_;
|
||||
}
|
||||
|
||||
c10::optional<SourceRange> findSourceRangeThatGenerated(
|
||||
const SourceRange& range);
|
||||
|
||||
private:
|
||||
void calc_line_start_offsets() {
|
||||
size_t pos = 0;
|
||||
@ -82,6 +95,8 @@ struct Source {
|
||||
// Starting offsets for lines into the source. e.g. line 0 starts at
|
||||
// line_starting_offsets_[0], etc.
|
||||
std::vector<size_t> line_starting_offsets_;
|
||||
|
||||
std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
|
||||
};
|
||||
|
||||
// A SourceRange is a view into a Source, that points to a subset of the source,
|
||||
@ -139,6 +154,13 @@ struct CAFFE2_API SourceRange {
|
||||
bool operator!=(const SourceRange& rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
// Ask the owning Source which original range generated this one.
// Yields nullopt when this range is not attached to any Source.
c10::optional<SourceRange> findSourceRangeThatGenerated() const {
  if (source_) {
    return source_->findSourceRangeThatGenerated(*this);
  }
  return c10::nullopt;
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Source> source_;
|
||||
@ -151,5 +173,15 @@ inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// A pair of (byte offset, SourceRange) describing a specific segment
|
||||
// of the output stream
|
||||
struct TaggedRange {
|
||||
TaggedRange(size_t bytes, SourceRange range)
|
||||
: bytes(bytes), range(std::move(range)) {}
|
||||
size_t bytes;
|
||||
SourceRange range;
|
||||
};
|
||||
using SourceRangeRecords = std::vector<TaggedRange>;
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
148
torch/csrc/jit/source_range_serialization.cpp
Normal file
148
torch/csrc/jit/source_range_serialization.cpp
Normal file
@ -0,0 +1,148 @@
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
#include <torch/csrc/jit/source_range_serialization_impl.h>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
class SourceRangeSerializer {
|
||||
public:
|
||||
// Serialize SourceRange as Tuple[SourceType, int, int]
|
||||
// where SourceType = Tuple[str, Optional[str], int],
|
||||
// the serialized form of Source
|
||||
c10::IValue serialize(const SourceRange& sr);
|
||||
|
||||
private:
|
||||
// Serialize Source as Tuple[str, Optional[str], int]
|
||||
// This caches serialized sources, since many SourceRanges can
|
||||
// refer to the same one.
|
||||
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
|
||||
|
||||
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
|
||||
};
|
||||
|
||||
class SourceRangeDeserializer {
|
||||
public:
|
||||
SourceRange deserialize(const c10::IValue& iv) {
|
||||
auto tup_elems = iv.toTuple()->elements();
|
||||
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
|
||||
std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
|
||||
int64_t start_ = tup_elems[1].toInt();
|
||||
int64_t end_ = tup_elems[2].toInt();
|
||||
return SourceRange(source_, start_, end_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) {
|
||||
auto tup = iv.toTuple();
|
||||
if (cached_sources.count(tup)) {
|
||||
return cached_sources.at(tup);
|
||||
}
|
||||
|
||||
auto tup_elems = tup->elements();
|
||||
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
|
||||
std::string text_ = tup_elems[0].toString()->string();
|
||||
c10::optional<std::string> filename_ =
|
||||
tup_elems[1].toOptional<std::string>();
|
||||
int64_t starting_line_no_ = tup_elems[2].toInt();
|
||||
|
||||
auto source = std::make_shared<Source>(
|
||||
std::move(text_), std::move(filename_), starting_line_no_);
|
||||
cached_sources[tup] = source;
|
||||
return source;
|
||||
}
|
||||
|
||||
std::unordered_map<
|
||||
c10::intrusive_ptr<c10::ivalue::Tuple>,
|
||||
std::shared_ptr<Source>>
|
||||
cached_sources;
|
||||
};
|
||||
|
||||
// Pack a SourceRange as the tuple (serialized source, start, end).
c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
  std::vector<c10::IValue> fields;
  fields.reserve(3);
  fields.emplace_back(serialize_source(sr.source()));
  fields.emplace_back(static_cast<int64_t>(sr.start()));
  fields.emplace_back(static_cast<int64_t>(sr.end()));
  return c10::ivalue::Tuple::create(std::move(fields));
}
|
||||
|
||||
// Pack a Source as the tuple (text, filename, starting line number).
// Each distinct Source pointer is serialized once and the result reused,
// since many SourceRanges can refer to the same Source.
c10::IValue SourceRangeSerializer::serialize_source(
    const std::shared_ptr<Source>& s) {
  auto cached = serialized_sources.find(s);
  if (cached != serialized_sources.end()) {
    return cached->second;
  }

  std::vector<c10::IValue> fields{
      s->text(), s->filename(), static_cast<int64_t>(s->starting_line_no())};
  auto packed = c10::ivalue::Tuple::create(std::move(fields));
  serialized_sources[s] = packed;
  return packed;
}
|
||||
|
||||
SourceRangePickler::SourceRangePickler()
|
||||
: p(new Pickler()), srs(new SourceRangeSerializer()) {}
|
||||
|
||||
void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
|
||||
p->start();
|
||||
p->startTuple();
|
||||
for (const auto& range : ranges) {
|
||||
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
|
||||
srs->serialize(range.range)};
|
||||
p->addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
|
||||
}
|
||||
p->endTuple();
|
||||
p->finish();
|
||||
}
|
||||
|
||||
const std::vector<char>& SourceRangePickler::get_data() {
|
||||
return p->stack();
|
||||
}
|
||||
|
||||
// Takes ownership of the pickled bytes. Nothing is decoded here;
// `unpickled_records` stays null until unpickle() runs.
ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
    at::DataPtr&& data,
    size_t size)
    : data(std::move(data)),
      size(size),
      deserializer(std::make_shared<SourceRangeDeserializer>()),
      unpickled_records(nullptr) {}
|
||||
|
||||
void ConcreteSourceRangeUnpickler::unpickle() {
|
||||
if (unpickled_records) {
|
||||
return;
|
||||
}
|
||||
|
||||
Unpickler up(data.get(), size, nullptr);
|
||||
auto ivalues = up.parse_ivalue_list();
|
||||
|
||||
unpickled_records = std::make_shared<SourceRangeRecords>();
|
||||
for (auto& val : ivalues) {
|
||||
auto tup_elems = val.toTuple()->elements();
|
||||
int64_t offset = tup_elems[0].toInt();
|
||||
auto source_range = deserializer->deserialize(tup_elems[1]);
|
||||
unpickled_records->emplace_back(offset, std::move(source_range));
|
||||
}
|
||||
}
|
||||
|
||||
// Return the original SourceRange recorded for the position where `range`
// starts: the last record whose byte offset is <= range.start(). Yields
// nullopt when every record starts after that position (or none exist).
// The binary search relies on the records being ordered by offset.
c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
    findSourceRangeThatGenerated(const SourceRange& range) {
  unpickle();

  // Search on the raw offset with a (value, element) comparator rather
  // than building a throwaway TaggedRange/SourceRange for every lookup.
  const size_t query_offset = range.start();
  auto entry = std::upper_bound(
      unpickled_records->begin(),
      unpickled_records->end(),
      query_offset,
      [](size_t offset, const TaggedRange& record) -> bool {
        return offset < record.bytes;
      });

  // NB: must decrement iterator since upper_bound finds the element
  // *greater than* the query.
  if (entry != unpickled_records->begin()) {
    return (entry - 1)->range;
  }

  return c10::nullopt;
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
42
torch/csrc/jit/source_range_serialization.h
Normal file
42
torch/csrc/jit/source_range_serialization.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace c10 {
|
||||
struct IValue;
|
||||
}
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
class Pickler;
|
||||
class SourceRangeSerializer;
|
||||
class SourceRangeDeserializer;
|
||||
|
||||
class SourceRangePickler {
|
||||
public:
|
||||
SourceRangePickler();
|
||||
|
||||
void pickle(const SourceRangeRecords& ranges);
|
||||
|
||||
const std::vector<char>& get_data();
|
||||
|
||||
private:
|
||||
std::shared_ptr<Pickler> p;
|
||||
std::shared_ptr<SourceRangeSerializer> srs;
|
||||
};
|
||||
|
||||
class SourceRangeUnpickler {
|
||||
public:
|
||||
virtual c10::optional<SourceRange> findSourceRangeThatGenerated(
|
||||
const SourceRange& range) = 0;
|
||||
|
||||
virtual ~SourceRangeUnpickler() {}
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
30
torch/csrc/jit/source_range_serialization_impl.h
Normal file
30
torch/csrc/jit/source_range_serialization_impl.h
Normal file
@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/pickler.h>
|
||||
#include <torch/csrc/jit/source_range_serialization.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// Do this clowniness with virtual functions because of the split
|
||||
// between ATen core and torch
|
||||
|
||||
class ConcreteSourceRangeUnpickler : public SourceRangeUnpickler {
|
||||
public:
|
||||
ConcreteSourceRangeUnpickler(at::DataPtr&& data, size_t size);
|
||||
|
||||
c10::optional<SourceRange> findSourceRangeThatGenerated(
|
||||
const SourceRange& range) override;
|
||||
|
||||
private:
|
||||
at::DataPtr data;
|
||||
size_t size;
|
||||
|
||||
void unpickle();
|
||||
|
||||
std::shared_ptr<SourceRangeDeserializer> deserializer;
|
||||
std::shared_ptr<SourceRangeRecords> unpickled_records;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
@ -1,39 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
class SourceRangeSerializer {
|
||||
public:
|
||||
// Serialize SourceRange as Tuple[SourceType, int, int]
|
||||
// where SourceType = Tuple[str, Optional[str], int, List[int]],
|
||||
// the serialized form of Source
|
||||
c10::IValue serialize(const SourceRange& sr) {
|
||||
std::vector<c10::IValue> elements = {
|
||||
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
|
||||
return c10::ivalue::Tuple::create(std::move(elements));
|
||||
}
|
||||
|
||||
private:
|
||||
// Serialize Source as Tuple[str, Optional[str], int, List[int]]
|
||||
// This caches serialized sources, since many SourceRanges can
|
||||
// refer to the same one.
|
||||
c10::IValue serialize_source(const std::shared_ptr<Source>& s) {
|
||||
if (serialized_sources.count(s)) {
|
||||
return serialized_sources.at(s);
|
||||
}
|
||||
std::vector<c10::IValue> elements{
|
||||
s->text(), s->filename(), (int64_t)s->starting_line_no()};
|
||||
auto serialized = c10::ivalue::Tuple::create(std::move(elements));
|
||||
serialized_sources[s] = serialized;
|
||||
return serialized;
|
||||
}
|
||||
|
||||
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user