Load original SourceRanges on import (#22180)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22180
ghimport-source-id: efa46dcb845c099f0a746f523901ab2c2cd3b004

Test Plan: Imported from OSS

Differential Revision: D15981425

Pulled By: jamesr66a

fbshipit-source-id: bef682bd13c1a5be95bdb97e025690c6f2d523d3
This commit is contained in:
James Reed
2019-07-01 21:11:12 -07:00
committed by Facebook Github Bot
parent 2c2a913a4f
commit ffa15d2285
14 changed files with 413 additions and 75 deletions

View File

@ -440,6 +440,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp

View File

@ -24,6 +24,7 @@ import inspect
import io
import math
import os
import pickle
import tempfile
import textwrap
@ -120,6 +121,8 @@ class JitTestCase(TestCase):
self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
main_module = archive.open('archive/code/archive.py')
main_module_code = "".join([line.decode() for line in main_module])
main_module_debug_file = archive.open('archive/debug/archive.pkl')
main_module_debug = pickle.load(main_module_debug_file)
except RuntimeError as e:
if not self._isHookExceptionOk(e):
raise
@ -138,8 +141,11 @@ class JitTestCase(TestCase):
archive2 = zipfile.ZipFile(saved_module_buffer_2)
main_module_2 = archive2.open('archive/code/archive.py')
main_module_2_code = "".join([line.decode() for line in main_module_2])
# Read the debug records from the second archive (archive2). Reading them
# from `archive` again would compare the first archive's records with
# themselves, so the assertEqual below could never fail.
main_module_2_debug_file = archive2.open('archive/debug/archive.pkl')
main_module_2_debug = pickle.load(main_module_2_debug_file)
# Both the printed code and its source-range debug records must match
# between the two archives.
self.assertMultiLineEqual(main_module_code, main_module_2_code)
self.assertEqual(main_module_debug, main_module_2_debug)
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
if isinstance(m, torch._C.Function):

View File

@ -3323,6 +3323,82 @@ def foo(xyz):
fc.run(scripted.graph)
fc.run(str(scripted.graph))
def test_serialized_source_ranges(self):
# Checks that an error raised by a module obtained from
# getExportImportCopy still mentions the original Python file and line.
# NOTE: the expected line is computed as an offset from the first source
# line of FooTest, so do not insert lines inside the class body without
# updating the `lineno + 3` offset below.
class FooTest(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest()
loaded = self.getExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest)
# The call is expected to raise a RuntimeError whose message matches
# 'test_jit.py:<line>', where <line> is three lines after the start of
# FooTest (the `return torch.mm(...)` statement).
with self.assertRaisesRegex(RuntimeError, 'test_jit.py:{}'.format(lineno + 3)):
loaded(torch.rand(3, 4), torch.rand(30, 40))
def test_serialized_source_ranges2(self):
# Same idea as test_serialized_source_ranges, but the error comes from an
# explicit `raise` inside the scripted method and is expected to surface
# as torch._C.JITException.
# NOTE: the expected line is computed as an offset from the first source
# line of FooTest2, so do not insert lines inside the class body without
# updating the `lineno + 3` offset below.
class FooTest2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self):
raise RuntimeError('foo')
_, lineno = inspect.getsourcelines(FooTest2)
# The message must match 'test_jit.py:<line>', where <line> is three lines
# after the start of FooTest2 (the `raise RuntimeError('foo')` statement).
with self.assertRaisesRegex(torch._C.JITException, 'test_jit.py:{}'.format(lineno + 3)):
ft = FooTest2()
loaded = self.getExportImportCopy(ft)
loaded()
def test_serialized_source_ranges_dont_jitter(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, lim):
first = 1
second = 1
i = 1
somenum = 5
dontmutateme = 3
third = 0
while bool(i < lim):
third = first + second
first = second
second = third
j = 0
while j < 10:
somenum = somenum * 2
j = j + 1
i = i + j
i = i + dontmutateme
st = second + third
fs = first + second
return third, st, fs
ft3 = FooTest3()
def debug_records_from_mod(mod):
    # Save `mod` into an in-memory archive and return a tuple of
    # (unpickled contents of 'archive/debug/archive.pkl', the buffer).
    # The module passed in must be the one that is saved: saving the
    # outer `ft3` here would make every call return records for the
    # original module, so the comparisons against reloaded modules
    # would be vacuous.
    buffer = io.BytesIO()
    torch.jit.save(mod, buffer)
    buffer.seek(0)
    archive = zipfile.ZipFile(buffer)
    debug_file = archive.open('archive/debug/archive.pkl')
    return pickle.load(debug_file), buffer
records1, buffer = debug_records_from_mod(ft3)
buffer.seek(0)
loaded = torch.jit.load(buffer)
records2, buffer = debug_records_from_mod(loaded)
buffer.seek(0)
loaded2 = torch.jit.load(buffer)
records3, _ = debug_records_from_mod(loaded2)
self.assertEqual(records1, records2)
self.assertEqual(records2, records3)
def test_tensor_shape(self):
x = torch.empty(34, 56, 78)

View File

@ -123,6 +123,7 @@ libtorch_sources = [
"torch/csrc/jit/script/class_type.cpp",
"torch/csrc/jit/script/parser.cpp",
"torch/csrc/jit/script/jit_exception.cpp",
"torch/csrc/jit/source_range_serialization.cpp",
"torch/csrc/jit/testing/file_check.cpp",
"torch/csrc/jit/import_source.cpp",
"torch/csrc/jit/hooks_for_testing.cpp",

View File

@ -11,7 +11,7 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/source_range_serializer.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
@ -931,21 +931,14 @@ void ScriptModuleSerializer::convertModule(
// Write out debug records
torch::RecordRef* debug_record =
module_def->mutable_torchscript_debug_arena();
Pickler p;
SourceRangeSerializer srs;
p.start();
p.startTuple();
for (const auto& range : source_ranges) {
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
srs.serialize(range.range)};
p.addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
}
p.endTuple();
p.finish();
SourceRangePickler source_range_pickler;
source_range_pickler.pickle(source_ranges);
const auto& range_data = source_range_pickler.get_data();
std::stringstream debug_filename;
debug_filename << "debug/" << module_name.str() << ".pkl";
writer_.writeRecord(
debug_filename.str(), p.stack().data(), p.stack().size());
debug_filename.str(), range_data.data(), range_data.size());
debug_record->set_key(debug_filename.str());
}

View File

@ -9,6 +9,8 @@
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/source_range_serialization_impl.h>
#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
@ -328,6 +330,20 @@ void ScriptModuleDeserializer::convertModule(
module.register_attribute(
attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
}
// If present, load in the table of source ranges from the original
// generating code.
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
if (module_def.has_torchscript_debug_arena()) {
at::DataPtr data;
size_t size;
std::tie(data, size) =
reader_.getRecord(module_def.torchscript_debug_arena().key());
gen_ranges =
std::make_shared<ConcreteSourceRangeUnpickler>(std::move(data), size);
}
if (module_def.has_torchscript_arena()) {
at::DataPtr data;
size_t size;
@ -337,7 +353,8 @@ void ScriptModuleDeserializer::convertModule(
auto src = std::make_shared<Source>(
std::string(static_cast<const char*>(data.get()), size),
module_def.torchscript_arena().key(),
1);
1,
std::move(gen_ranges));
std::function<void(const std::string&)> import_callback =
[this](const std::string& qualifier) { importCallback(qualifier); };

View File

@ -171,10 +171,13 @@ struct PythonPrintPass {
SourceRangeStack source_range_stack_ = {SourceRange("")};
struct WithSourceRange {
explicit WithSourceRange(SourceRangeStack* stack, SourceRange sr)
: stack(stack) {
explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
TORCH_INTERNAL_ASSERT(stack);
stack->push_back(std::move(sr));
if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
stack->push_back(std::move(gen_source.value()));
} else {
stack->push_back(std::move(n->sourceRange()));
}
}
~WithSourceRange() {
@ -190,6 +193,13 @@ struct PythonPrintPass {
TaggedStringStream(TaggedStringStream&& rhs) = default;
TaggedStringStream& operator<<(const std::string& s) {
// This prevents having redundant entries at the same offset,
// which can happen for example in printValueList when begin
// and end are the empty string.
if (s.size() == 0) {
return *this;
}
if (!ranges_.size() || ranges_.back().range != srs_->back()) {
ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
}
@ -233,6 +243,21 @@ struct PythonPrintPass {
return ranges_;
}
// Write out this TaggedStringStream's text and source ranges to
// os and source_ranges_out, respectively. stream_pos gives
// the byte offset into the current stream, so we can accurately
// record source ranges as byte offsets.
void print(
std::ostream& os,
SourceRangeRecords* source_ranges_out,
int64_t stream_pos) {
os << str();
for (const auto& x : ranges()) {
source_ranges_out->push_back(x);
source_ranges_out->back().bytes += stream_pos;
}
}
private:
std::ostringstream oss_;
std::vector<TaggedRange> ranges_;
@ -756,7 +781,7 @@ struct PythonPrintPass {
}
void printNode(Node* node, bool print_const) {
WithSourceRange guard(&source_range_stack_, node->sourceRange());
WithSourceRange guard(&source_range_stack_, node);
// Check for class dependencies. If this node inputs or outputs a class
// type, we need to add it to our table of dependencies.
for (const auto input : node->inputs()) {
@ -1149,8 +1174,7 @@ struct PythonPrintPass {
Graph& graph = *func.graph();
used_names_.clear(); // each graph can reuse local names
WithSourceRange guard(
&source_range_stack_, graph.param_node()->sourceRange());
WithSourceRange guard(&source_range_stack_, graph.param_node());
indent();
body_ << "def " << func.name() << "(";
@ -1250,8 +1274,9 @@ struct PythonPrintPass {
}
void print(std::ostream& out, SourceRangeRecords& source_ranges_out) {
out << getImports() << body_.str();
source_ranges_out = body_.ranges();
out << getImports();
int64_t source_offset = out.tellp();
body_.print(out, &source_ranges_out, source_offset);
}
};

View File

@ -12,16 +12,6 @@ struct Method;
struct Module;
} // namespace script
// A pair of (byte offset, SourceRange) describing a specific segment
// of the output stream
struct TaggedRange {
TaggedRange(size_t bytes, SourceRange range)
: bytes(bytes), range(std::move(range)) {}
size_t bytes;
SourceRange range;
};
using SourceRangeRecords = std::vector<TaggedRange>;
TORCH_API void PythonPrint(
std::ostream& out,
SourceRangeRecords& source_ranges_out,

View File

@ -1,8 +1,17 @@
#include <torch/csrc/jit/source_range.h>
#include <torch/csrc/jit/source_range_serialization.h>
namespace torch {
namespace jit {
// Forwards the lookup for `range` to the attached SourceRangeUnpickler.
// Returns c10::nullopt when this Source was constructed without one.
c10::optional<SourceRange> Source::findSourceRangeThatGenerated(
    const SourceRange& range) {
  if (gen_ranges_) {
    return gen_ranges_->findSourceRangeThatGenerated(range);
  }
  return c10::nullopt;
}
// a range of a shared string 'file_' with
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
const std::string& str = source_->text();
@ -56,6 +65,13 @@ C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
out << str.substr(end_line, end_highlight - end_line);
if (!str.empty() && str.back() != '\n')
out << "\n";
// Retrieve original SourceRange, if present.
if (source_) {
if (auto orig_source_range = findSourceRangeThatGenerated()) {
out << "Compiled from code ";
orig_source_range->highlight(out);
}
}
}
} // namespace jit

View File

@ -8,6 +8,9 @@
namespace torch {
namespace jit {
struct SourceRangeUnpickler;
struct SourceRange;
// Source represents a code segment. It keeps track of:
// - text : the text of the code segment
// - filename (optional) : if present, represents the name of the file from
@ -15,18 +18,25 @@ namespace jit {
// - starting_line_no : represents the line in the original file where the
// code segment started.
struct Source {
explicit Source(std::string text)
: text_(std::move(text)), filename_(c10::nullopt) {
explicit Source(
std::string text,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_(std::move(text)),
filename_(c10::nullopt),
starting_line_no_(0),
gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}
Source(
std::string text,
c10::optional<std::string> filename,
size_t starting_line_no)
size_t starting_line_no,
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
: text_(std::move(text)),
filename_(std::move(filename)),
starting_line_no_(starting_line_no) {
starting_line_no_(starting_line_no),
gen_ranges_(std::move(gen_ranges)) {
calc_line_start_offsets();
}
@ -67,6 +77,9 @@ struct Source {
return starting_line_no_;
}
c10::optional<SourceRange> findSourceRangeThatGenerated(
const SourceRange& range);
private:
void calc_line_start_offsets() {
size_t pos = 0;
@ -82,6 +95,8 @@ struct Source {
// Starting offsets for lines into the source. e.g. line 0 starts at
// line_starting_offsets_[0], etc.
std::vector<size_t> line_starting_offsets_;
std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
};
// A SourceRange is a view into a Source, that points to a subset of the source,
@ -139,6 +154,13 @@ struct CAFFE2_API SourceRange {
bool operator!=(const SourceRange& rhs) const {
return !(*this == rhs);
}
// Asks the underlying Source which original range generated this one.
// Returns c10::nullopt when this range has no Source attached.
c10::optional<SourceRange> findSourceRangeThatGenerated() const {
  if (source_) {
    return source_->findSourceRangeThatGenerated(*this);
  }
  return c10::nullopt;
}
private:
std::shared_ptr<Source> source_;
@ -151,5 +173,15 @@ inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) {
return out;
}
// A pair of (byte offset, SourceRange) describing a specific segment
// of the output stream
struct TaggedRange {
// Stores `bytes` as given and takes ownership of `range` by moving it in.
TaggedRange(size_t bytes, SourceRange range)
: bytes(bytes), range(std::move(range)) {}
// Byte offset into the output stream.
size_t bytes;
// Source range associated with that offset.
SourceRange range;
};
// A list of tagged ranges, e.g. the table filled in by PythonPrint.
using SourceRangeRecords = std::vector<TaggedRange>;
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,148 @@
#include <torch/csrc/jit/source_range_serialization.h>
#include <torch/csrc/jit/source_range_serialization_impl.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/pickler.h>
namespace torch {
namespace jit {
// Converts SourceRange objects into IValue tuples suitable for pickling.
class SourceRangeSerializer {
public:
// Serialize SourceRange as Tuple[SourceType, int, int] (source, start, end)
// where SourceType = Tuple[str, Optional[str], int] is the serialized form
// of Source: (text, filename, starting line number).
c10::IValue serialize(const SourceRange& sr);
private:
// Serialize Source as Tuple[str, Optional[str], int]
// This caches serialized sources, since many SourceRanges can
// refer to the same one.
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
// Cache of already-serialized sources, keyed by the Source pointer.
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
};
// Rebuilds SourceRange objects from the IValue tuples produced by
// SourceRangeSerializer.
class SourceRangeDeserializer {
 public:
  // Converts a serialized range back into a SourceRange.
  // `iv` must hold a 3-tuple (serialized source, start, end); any other
  // arity trips an internal assert.
  SourceRange deserialize(const c10::IValue& iv) {
    // Keep the tuple alive in a local and view its elements by reference
    // instead of copying the element vector for every range.
    const auto range_tuple = iv.toTuple();
    const auto& fields = range_tuple->elements();
    TORCH_INTERNAL_ASSERT(fields.size() == 3);
    std::shared_ptr<Source> source = deserialize_source(fields[0]);
    const int64_t start = fields[1].toInt();
    const int64_t end = fields[2].toInt();
    return SourceRange(std::move(source), start, end);
  }

 private:
  // Converts a serialized source, a 3-tuple (text, optional filename,
  // starting line number), into a Source. Results are cached per tuple
  // object, so ranges that refer to the same serialized tuple share one
  // Source instance.
  std::shared_ptr<Source> deserialize_source(const c10::IValue& iv) {
    auto tup = iv.toTuple();
    // Single lookup instead of count() followed by at().
    const auto cached = cached_sources.find(tup);
    if (cached != cached_sources.end()) {
      return cached->second;
    }

    // This path runs once per distinct source, so a local copy of the
    // element vector is acceptable here.
    auto tup_elems = tup->elements();
    TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
    std::string text = tup_elems[0].toString()->string();
    c10::optional<std::string> filename =
        tup_elems[1].toOptional<std::string>();
    const int64_t starting_line_no = tup_elems[2].toInt();

    auto source = std::make_shared<Source>(
        std::move(text), std::move(filename), starting_line_no);
    cached_sources.emplace(std::move(tup), source);
    return source;
  }

  // Maps each serialized source tuple to the Source built from it.
  std::unordered_map<
      c10::intrusive_ptr<c10::ivalue::Tuple>,
      std::shared_ptr<Source>>
      cached_sources;
};
// Packs `sr` into a 3-tuple: (serialized source, start offset, end offset).
c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
  std::vector<c10::IValue> fields;
  fields.reserve(3);
  fields.emplace_back(serialize_source(sr.source()));
  fields.emplace_back(static_cast<int64_t>(sr.start()));
  fields.emplace_back(static_cast<int64_t>(sr.end()));
  return c10::ivalue::Tuple::create(std::move(fields));
}
// Packs `s` into a 3-tuple: (text, filename, starting line number).
// The result is memoized per Source pointer, so serializing the same Source
// again returns the previously built tuple.
c10::IValue SourceRangeSerializer::serialize_source(
    const std::shared_ptr<Source>& s) {
  const auto hit = serialized_sources.find(s);
  if (hit != serialized_sources.end()) {
    return hit->second;
  }

  std::vector<c10::IValue> fields{
      s->text(), s->filename(), static_cast<int64_t>(s->starting_line_no())};
  auto packed = c10::ivalue::Tuple::create(std::move(fields));
  // `s` is known to be absent at this point, so emplace always inserts.
  serialized_sources.emplace(s, packed);
  return packed;
}
// Creates the default-constructed Pickler and SourceRangeSerializer that
// pickle() drives.
SourceRangePickler::SourceRangePickler()
    : p(std::make_shared<Pickler>()),
      srs(std::make_shared<SourceRangeSerializer>()) {}
void SourceRangePickler::pickle(const SourceRangeRecords& ranges) {
p->start();
p->startTuple();
for (const auto& range : ranges) {
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
srs->serialize(range.range)};
p->addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
}
p->endTuple();
p->finish();
}
// Returns a reference to the pickler's byte stack; the storage is owned by
// the Pickler held in `p`.
const std::vector<char>& SourceRangePickler::get_data() {
  const std::vector<char>& pickled_bytes = p->stack();
  return pickled_bytes;
}
// Takes ownership of the pickled buffer (`data`, `size` bytes). Nothing is
// parsed here: the record table starts out null and is filled in by
// unpickle().
ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
    at::DataPtr&& data,
    size_t size)
    : data(std::move(data)),
      size(size),
      deserializer(std::make_shared<SourceRangeDeserializer>()),
      unpickled_records() {}
void ConcreteSourceRangeUnpickler::unpickle() {
if (unpickled_records) {
return;
}
Unpickler up(data.get(), size, nullptr);
auto ivalues = up.parse_ivalue_list();
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
auto tup_elems = val.toTuple()->elements();
int64_t offset = tup_elems[0].toInt();
auto source_range = deserializer->deserialize(tup_elems[1]);
unpickled_records->emplace_back(offset, std::move(source_range));
}
}
// Returns the range stored in the last record whose byte offset is less
// than or equal to `range.start()`, or c10::nullopt when no record
// qualifies. Triggers unpickle() on first use.
//
// The binary search assumes the records are ordered by ascending byte
// offset -- NOTE(review): this relies on the order in which they were
// pickled; confirm against the writer.
c10::optional<SourceRange> ConcreteSourceRangeUnpickler::
    findSourceRangeThatGenerated(const SourceRange& range) {
  unpickle();

  // Compare the raw offset against each record directly; this avoids
  // building a placeholder TaggedRange (and its empty Source) per lookup.
  const size_t query_offset = range.start();
  const auto first_after = std::upper_bound(
      unpickled_records->begin(),
      unpickled_records->end(),
      query_offset,
      [](size_t offset, const TaggedRange& record) -> bool {
        return offset < record.bytes;
      });

  // upper_bound yields the first record strictly past the query, so the
  // record we want is the one just before it -- if there is one.
  if (first_after == unpickled_records->begin()) {
    return c10::nullopt;
  }
  return (first_after - 1)->range;
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,42 @@
#pragma once
#include <c10/core/Allocator.h>
#include <torch/csrc/jit/source_range.h>
#include <unordered_map>
#include <vector>
namespace c10 {
struct IValue;
}
namespace torch {
namespace jit {
class Pickler;
class SourceRangeSerializer;
class SourceRangeDeserializer;
// Pickles a table of SourceRangeRecords into a byte buffer.
// Pickler and SourceRangeSerializer are only forward-declared in this
// header, so they are held through pointers.
class SourceRangePickler {
public:
SourceRangePickler();
// Pickles `ranges`; the resulting bytes are available through get_data().
void pickle(const SourceRangeRecords& ranges);
// Returns the pickled bytes.
const std::vector<char>& get_data();
private:
std::shared_ptr<Pickler> p;
std::shared_ptr<SourceRangeSerializer> srs;
};
// Abstract interface for looking up the original SourceRange that a given
// range was generated from.
class SourceRangeUnpickler {
 public:
  virtual ~SourceRangeUnpickler() = default;

  // Returns the originating range for `range`, or c10::nullopt when none
  // is found.
  virtual c10::optional<SourceRange> findSourceRangeThatGenerated(
      const SourceRange& range) = 0;
};
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,30 @@
#pragma once
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/source_range_serialization.h>
namespace torch {
namespace jit {
// Do this clownyness with virtual functions because of the split
// between ATen core and torch
//
// Concrete SourceRangeUnpickler backed by a pickled buffer of
// (byte offset, SourceRange) records.
class ConcreteSourceRangeUnpickler : public SourceRangeUnpickler {
public:
// Takes ownership of `data`, a buffer of `size` bytes.
ConcreteSourceRangeUnpickler(at::DataPtr&& data, size_t size);
c10::optional<SourceRange> findSourceRangeThatGenerated(
const SourceRange& range) override;
private:
// Pickled bytes and their length.
at::DataPtr data;
size_t size;
// Decodes `data` into `unpickled_records`.
void unpickle();
std::shared_ptr<SourceRangeDeserializer> deserializer;
// Decoded record table.
std::shared_ptr<SourceRangeRecords> unpickled_records;
};
} // namespace jit
} // namespace torch

View File

@ -1,39 +0,0 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/source_range.h>
namespace torch {
namespace jit {
// Converts SourceRange objects into IValue tuples suitable for pickling.
class SourceRangeSerializer {
public:
// Serialize SourceRange as Tuple[SourceType, int, int] (source, start, end)
// where SourceType = Tuple[str, Optional[str], int] is the serialized form
// of Source: (text, filename, starting line number).
c10::IValue serialize(const SourceRange& sr) {
std::vector<c10::IValue> elements = {
serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end()};
return c10::ivalue::Tuple::create(std::move(elements));
}
private:
// Serialize Source as Tuple[str, Optional[str], int]
// This caches serialized sources, since many SourceRanges can
// refer to the same one.
c10::IValue serialize_source(const std::shared_ptr<Source>& s) {
// Reuse the tuple built earlier for this Source, if any.
if (serialized_sources.count(s)) {
return serialized_sources.at(s);
}
std::vector<c10::IValue> elements{
s->text(), s->filename(), (int64_t)s->starting_line_no()};
auto serialized = c10::ivalue::Tuple::create(std::move(elements));
serialized_sources[s] = serialized;
return serialized;
}
// Cache of already-serialized sources, keyed by the Source pointer.
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
};
} // namespace jit
} // namespace torch