Revert D15275731: Remote SourceLocation

Differential Revision:
D15275731

Original commit changeset: f4da178c3137

fbshipit-source-id: 830b79735eb2dadc4795b5aae407826bf20ef121
This commit is contained in:
Wanchao Liang
2019-05-09 12:38:21 -07:00
committed by Facebook Github Bot
parent eca91de5d2
commit e870b11ae6
26 changed files with 175 additions and 153 deletions

View File

@ -9,6 +9,7 @@ EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
# Add files needed from jit folders
LIST(APPEND ATen_CORE_HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_range.h
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_location.h
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/function_schema_parser.h
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/lexer.h
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/strtod.h
@ -22,7 +23,6 @@ LIST(APPEND ATen_CORE_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/lexer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/strtod.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/schema_type_parser.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_range.cpp
)
# Pass to parent

View File

@ -85,7 +85,7 @@ c10::optional<Value*> tryInsertConstant(
return c10::nullopt;
}
if (loc)
n->setSourceRange(*loc);
n->setSourceLocation(std::make_shared<SourceRange>(*loc));
if (scope)
n->setScope(*scope);
if (result_type) {

View File

@ -38,7 +38,13 @@ namespace onnx = ::ONNX_NAMESPACE;
class ScriptModuleSerializer;
std::string getNodeStackTraceString(const Node* n) {
return n->sourceRange().str();
std::stringstream ss;
if (n->getSourceLocation()) {
n->getSourceLocation()->highlight(ss);
} else {
ss << "<unknown location>";
}
return ss.str();
}
void validateBlock(
@ -252,8 +258,10 @@ void EncoderBase::EncodeBlock(
continue;
}
auto p_n = graph_proto->add_node();
if (!strip_doc_) {
p_n->set_doc_string(node->sourceRange().str());
if (node->getSourceLocation() && !strip_doc_) {
std::stringstream ss;
node->getSourceLocation()->highlight(ss);
p_n->set_doc_string(ss.str());
}
for (auto input : node->inputs()) {
if (input->node()->mustBeNone() && !is_raw_export) {

View File

@ -330,7 +330,7 @@ struct Instruction {
UseList inputs;
ListHandle<int> outputs;
Symbol debug_name; // used in dump to understand the generated code
c10::optional<SourceRange> debug_location; // for error reporting
std::shared_ptr<SourceLocation> debug_location; // for error reporting
};
int relativeJump(int from_inst, int to_inst) {
@ -377,7 +377,7 @@ struct CodeImpl {
void insertNodesFromBlock(Block* block) {
for (auto node : block->nodes()) {
SourceRange source_location = node->sourceRange();
const auto& source_location = node->getSourceLocation();
switch (node->kind()) {
case prim::If: {
// x = if c:
@ -481,7 +481,7 @@ struct CodeImpl {
size_t insertInstruction(Node* n) {
auto inst = insertInstruction(
n->kind(),
n->sourceRange(),
n->getSourceLocation(),
n->inputs(),
moveFlags(n),
n->outputs());
@ -490,7 +490,7 @@ struct CodeImpl {
}
size_t insertInstruction(
Symbol sym,
const SourceRange& debug_location,
std::shared_ptr<SourceLocation> debug_location,
ArrayRef<Value*> inputs,
ArrayRef<uint8_t> move_flags,
ArrayRef<Value*> outputs) {
@ -520,7 +520,7 @@ struct CodeImpl {
}
size_t insertAssign(
const SourceRange& debug_location,
std::shared_ptr<SourceLocation> debug_location,
ArrayRef<Value*> inputs,
ArrayRef<uint8_t> move_flags,
ArrayRef<Value*> outputs) {
@ -713,10 +713,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
} catch (std::exception& e) {
// Error from the current thread
bool is_jit_exception = dynamic_cast<JITException*>(&e);
handleError(
instructions[pc].debug_location->wrapException(
e, "operation failed in interpreter"),
is_jit_exception);
if (instructions[pc].debug_location) {
handleError(
instructions[pc].debug_location->wrapException(
e, "operation failed in interpreter"),
is_jit_exception);
} else {
handleError(e.what(), is_jit_exception);
}
return false;
}
}

View File

@ -200,14 +200,6 @@ void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
out << "]";
}
SourceRange Node::sourceRange() const {
if(source_range_) {
return *source_range_;
}
std::stringstream ss;
return SourceRange(ss.str());
}
static std::ostream& indent(std::ostream& out, size_t level) {
for (size_t i = 0; i < level; ++i) {
out << " ";
@ -232,10 +224,8 @@ std::ostream& Node::print(
if (numAttributes() > 1 && kind() != prim::DifferentiableGraph) {
printAttributes(out, /*ignore_subgraph=*/true);
}
groups->push_back(this);
} else {
out << kind().toQualString();
if (hasAttributes()) {
printAttributes(out);
@ -251,7 +241,6 @@ std::ostream& Node::print(
out << ", ";
out << "scope: " << scName << "\n";
}
for (size_t i = 0; i < blocks().size(); ++i) {
auto b = blocks()[i];
indent(out, level + 1) << "block" << i << "("
@ -262,7 +251,6 @@ std::ostream& Node::print(
}
indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
}
return out;
}
@ -985,7 +973,7 @@ void Node::destroy() {
}
void Node::cloneFrom(Node* s) {
s->source_range_ = s->source_range_;
setSourceLocation(s->getSourceLocation());
if (s->scope_ && !s->scope_->isBlank()) {
scope_ = s->scope_;
}

View File

@ -249,7 +249,7 @@ struct TORCH_API Node {
std::vector<Block*> blocks_;
Graph* graph_;
Block* owning_block_;
c10::optional<SourceRange> source_range_;
std::shared_ptr<SourceLocation> source_location_;
ScopePtr scope_;
// Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
// This field is effective a cache that's populated on attribute lookups and
@ -287,12 +287,13 @@ struct TORCH_API Node {
NodeKind kind() const {
return kind_;
}
Node* setSourceRange(SourceRange r) {
source_range_ = std::move(r);
Node* setSourceLocation(std::shared_ptr<SourceLocation> sl) {
source_location_ = std::move(sl);
return this;
}
SourceRange sourceRange() const;
std::shared_ptr<SourceLocation> getSourceLocation() const {
return source_location_;
}
Graph* owningGraph() {
return graph_;
}

View File

@ -237,7 +237,7 @@ const Operator& getOperatorFor(const Node* node) {
if (op)
return *op;
auto er = script::ErrorReport(node->sourceRange());
auto er = script::ErrorReport(node->getSourceLocation());
er << "Schema not found for node. File a bug report.\n";
er << "Node: " << *node << "\n";
er << "Input types:";

View File

@ -464,7 +464,7 @@ void AliasDb::analyzeImpl(Node* node) {
// We don't have alias info for this node. Either schematize it, or
// add it an analyze* method for it.
if (hasMutableOutputs) {
throw script::ErrorReport(node->sourceRange())
throw script::ErrorReport(node->getSourceLocation())
<< "Alias information not found for node. File a bug report.\n"
<< "Node: " << *node << "\n";
}

View File

@ -44,7 +44,7 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
auto maybe_int = constant_as<int64_t>(idx);
if (!maybe_int) {
if (must_remove_tuples) {
AT_ERROR(n->sourceRange(), "tuple index with non-constant index");
AT_ERROR(n->getSourceLocation(), "tuple index with non-constant index");
}
return;
}

View File

@ -211,7 +211,7 @@ void BlockToONNX(
outputs[i]->setType(old->type());
// Copy over source location and scope information to all nodes
// created by the symbolic
outputs[i]->node()->setSourceRange(node->sourceRange());
outputs[i]->node()->setSourceLocation(node->getSourceLocation());
outputs[i]->node()->setScope(node->scope());
env[old] = outputs[i];
} else {

View File

@ -531,7 +531,7 @@ struct PythonPrintPass {
// this must be a while loop, but check that there isn't _also_ a trip
// count
if (trip_count_is_specified) {
throw script::ErrorReport(stmt.node()->sourceRange())
throw script::ErrorReport(stmt.node()->getSourceLocation())
<< "loop cannot be printed as python "
<< "because it has gone through an optimization "
<< "that combined while and for loops. File a bug.";
@ -678,7 +678,7 @@ struct PythonPrintPass {
switch (node->kind()) {
case prim::Return:
if (enforce_importable_ && node->inputs().size() != 1) {
throw script::ErrorReport(node->sourceRange())
throw script::ErrorReport(node->getSourceLocation())
<< "Exportable methods must have a single return value. "
<< "Normal use of ScriptMethods should enforce this.";
}
@ -733,7 +733,7 @@ struct PythonPrintPass {
} break;
case prim::Function: {
if (enforce_importable_) {
throw script::ErrorReport(node->sourceRange())
throw script::ErrorReport(node->getSourceLocation())
<< "closures are not exportable";
}
assignValuesToTheirUniqueNames(node->outputs());
@ -850,7 +850,7 @@ struct PythonPrintPass {
case prim::PythonOp: {
auto value = static_cast<const PythonOp*>(node);
if (enforce_importable_) {
throw script::ErrorReport(node->sourceRange())
throw script::ErrorReport(node->getSourceLocation())
<< "could not export python function call " << value->name()
<< ". Remove calls to Python functions before export. "
<< "Did you forget add @script or @script_method annotation? "

View File

@ -73,7 +73,11 @@ class ShapePropagator {
} catch (propagation_error& e) {
setUnshapedType(node);
} catch (std::exception& e) {
node->sourceRange().wrapAndRethrowException(e, "operation failed shape propagation");
if (auto sl = node->getSourceLocation()) {
sl->wrapAndRethrowException(e, "operation failed shape propagation");
} else {
throw;
}
}
}
}

View File

@ -439,9 +439,15 @@ void initPythonIRBindings(PyObject* module_) {
return ss.str();
})
.def(
"sourceRange",
[](Node& n) {
return n.sourceRange().str();
"getSourceLocation",
[](Node& n) -> py::object {
std::stringstream ss;
if (auto sl = n.getSourceLocation()) {
sl->highlight(ss);
return py::str(ss.str());
} else {
return py::none();
}
})
.def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
.def("outputsSize", [](Node& n) { return n.outputs().size(); })

View File

@ -101,7 +101,9 @@ Node* preRecordPythonTrace(
}
void pythonRecordSourceLocation(Node* n) {
n->setSourceRange(SourceRange(getPythonInterpreterStackTrace()));
auto sl =
std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
n->setSourceLocation(sl);
}
void pythonWarn(const std::string& reason) {

View File

@ -22,7 +22,7 @@ namespace {
void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
if (!elem_type->isSubtypeOf(NumberType::get()) &&
elem_type != BoolType::get()) {
auto error = script::ErrorReport(node->sourceRange());
auto error = script::ErrorReport(node->getSourceLocation());
error << "Input list to torch.tensor must be of ints, floats, or bools, "
<< "got " << elem_type->str();
// special case empty list torch.tensor([])

View File

@ -953,7 +953,7 @@ struct to_ir {
Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
return graph->create(kind, n_outputs)
->setSourceRange(loc);
->setSourceLocation(std::make_shared<SourceRange>(loc));
}
Value* emitTernaryIf(const TernaryIf& expr) {
@ -2379,7 +2379,7 @@ struct to_ir {
auto fork_node =
method.graph()
->insertNode(method.graph()->create(prim::fork, 1))
->setSourceRange(loc);
->setSourceLocation(std::make_shared<SourceRange>(loc));
auto body_block = fork_node->addBlock();
// Build a template of the graph to be executed

View File

@ -1,7 +1,6 @@
#pragma once
#include <torch/csrc/jit/script/tree.h>
#include <c10/util/Optional.h>
namespace torch {
namespace jit {
@ -11,15 +10,17 @@ struct ErrorReport : public std::exception {
ErrorReport(const ErrorReport& e)
: ss(e.ss.str()), context(e.context), the_message(e.the_message) {}
ErrorReport() : context(c10::nullopt) {}
explicit ErrorReport(SourceRange r)
: context(std::move(r)) {}
ErrorReport() : context(nullptr) {}
explicit ErrorReport(const SourceRange& r)
: context(std::make_shared<SourceRange>(r)) {}
explicit ErrorReport(std::shared_ptr<SourceLocation> loc)
: context(std::move(loc)) {}
explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
const char* what() const noexcept override {
std::stringstream msg;
msg << "\n" << ss.str();
if (context) {
if (context != nullptr) {
msg << ":\n";
context->highlight(msg);
} else {
@ -34,7 +35,7 @@ struct ErrorReport : public std::exception {
friend const ErrorReport& operator<<(const ErrorReport& e, const T& t);
mutable std::stringstream ss;
c10::optional<SourceRange> context;
std::shared_ptr<SourceLocation> context;
mutable std::string the_message;
};

View File

@ -194,7 +194,7 @@ std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
continue;
}
if (e.n->kind() != prim::GetAttr) {
throw ErrorReport(e.n->sourceRange())
throw ErrorReport(e.n->getSourceLocation())
<< "temporary: the only valid use of a module is looking up an attribute";
}
Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));

View File

@ -110,7 +110,7 @@ std::shared_ptr<SugaredValue> PythonValue::call(
auto python_op = static_cast<PythonOp*>(new_node);
python_op->ignore_on_export = true;
}
new_node->setSourceRange(loc);
new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
for (auto& i : matched_schema->inputs)
new_node->addInput(i);

View File

@ -100,7 +100,7 @@ Value* tryConvertToType(
value->type()->isSubtypeOf(TensorType::get())) {
auto n = graph.createImplicitTensorToNum(concrete_type, value);
value = graph.insertNode(n)
->setSourceRange(loc)
->setSourceLocation(std::make_shared<SourceRange>(loc))
->output();
}
if (value->type()->isSubtypeOf(StringType::get()) &&
@ -355,7 +355,7 @@ static Value* emitBuiltinNode(
Graph& graph,
Symbol name) {
auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
->setSourceRange(loc);
->setSourceLocation(std::make_shared<SourceRange>(loc));
for (auto& ret : matched_schema.return_types) {
n->addOutput()->setType(ret);

View File

@ -38,7 +38,7 @@ std::shared_ptr<SugaredValue> PrintValue::call(
lowered_inputs.erase(lowered_inputs.begin());
}
g.insertNode(g.create(prim::Print, lowered_inputs, 0)
->setSourceRange(loc));
->setSourceLocation(std::make_shared<SourceRange>(loc)));
return std::make_shared<NoneValue>();
}

View File

@ -0,0 +1,54 @@
#pragma once
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
namespace torch {
namespace jit {
// SourceLocation represents source code-level debug information for a node.
// It contains information about where a node got generated.
// In the case of tracing this will be a python stack trace.
// In the case of using the scripting frontend this will be backed
// by a SourceRange object
struct SourceLocation {
virtual ~SourceLocation() = default;
virtual void highlight(std::ostream& out) const = 0;
std::string wrapException(
const std::exception& e,
const std::string& additional = "") {
std::stringstream msg;
msg << "\n" << e.what() << ":\n";
if (!additional.empty()) {
msg << additional << ":\n";
}
highlight(msg);
return msg.str();
}
void wrapAndRethrowException(
const std::exception& e,
const std::string& additional = "") {
throw std::runtime_error(wrapException(e, additional));
}
};
inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
sl.highlight(out);
return out;
}
// normally a python stack trace
struct StringSourceLocation : public SourceLocation {
StringSourceLocation(std::string context) : context(std::move(context)) {}
void highlight(std::ostream& out) const override {
out << context;
}
private:
std::string context;
};
} // namespace jit
} // namespace torch

View File

@ -1,56 +0,0 @@
#include <torch/csrc/jit/source_range.h>
namespace torch {
namespace jit {
// a range of a shared string 'file_' with
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
if (size() == file_->size()) {
// this is just the entire file, not a subset, so print it out.
// primarily used to print out python stack traces
out << *file_;
return;
}
const std::string& str = file();
size_t begin_line = start(); // beginning of line to highlight
size_t end_line = start(); // end of line to highlight
while (begin_line > 0 && str[begin_line - 1] != '\n')
--begin_line;
while (end_line < str.size() && str[end_line] != '\n')
++end_line;
AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
AT_ASSERT(end_line == str.size() || str[end_line] == '\n');
size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines
// before the highlight line
for (size_t i = 0; begin_highlight > 0; --begin_highlight) {
if (str[begin_highlight - 1] == '\n')
++i;
if (i >= CONTEXT)
break;
}
AT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n');
size_t end_highlight =
end_line; // end of context, CONTEXT lines after the highlight line
for (size_t i = 0; end_highlight < str.size(); ++end_highlight) {
if (str[end_highlight] == '\n')
++i;
if (i >= CONTEXT)
break;
}
AT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n');
out << str.substr(begin_highlight, end_line - begin_highlight) << "\n";
out << std::string(start() - begin_line, ' ');
size_t len = std::min(size(), end_line - start());
out << std::string(len, '~')
<< (len < size() ? "... <--- HERE" : " <--- HERE");
out << str.substr(end_line, end_highlight - end_line);
if (!str.empty() && str.back() != '\n')
out << "\n";
}
} // namespace jit
} // namespace torch

View File

@ -1,31 +1,67 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/jit/source_location.h>
#include <algorithm>
#include <memory>
#include <iostream>
namespace torch {
namespace jit {
// a range of a shared string 'file_' with functions to help debug by highlight
// that
// range.
struct CAFFE2_API SourceRange {
struct SourceRange : public SourceLocation {
SourceRange(std::shared_ptr<std::string> file_, size_t start_, size_t end_)
: file_(std::move(file_)), start_(start_), end_(end_) {}
explicit SourceRange(std::string string_range)
: file_(std::make_shared<std::string>(std::move(string_range))),
start_(0),
end_(file_->size()) {}
const std::string text() const {
return file().substr(start(), end() - start());
}
size_t size() const {
return end() - start();
}
static const size_t CONTEXT = 10;
void highlight(std::ostream& out) const;
void highlight(std::ostream& out) const override {
const std::string& str = file();
size_t begin_line = start(); // beginning of line to highlight
size_t end_line = start(); // end of line to highlight
while (begin_line > 0 && str[begin_line - 1] != '\n')
--begin_line;
while (end_line < str.size() && str[end_line] != '\n')
++end_line;
AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
AT_ASSERT(end_line == str.size() || str[end_line] == '\n');
size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines
// before the highlight line
for (size_t i = 0; begin_highlight > 0; --begin_highlight) {
if (str[begin_highlight - 1] == '\n')
++i;
if (i >= CONTEXT)
break;
}
AT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n');
size_t end_highlight =
end_line; // end of context, CONTEXT lines after the highlight line
for (size_t i = 0; end_highlight < str.size(); ++end_highlight) {
if (str[end_highlight] == '\n')
++i;
if (i >= CONTEXT)
break;
}
AT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n');
out << str.substr(begin_highlight, end_line - begin_highlight) << "\n";
out << std::string(start() - begin_line, ' ');
size_t len = std::min(size(), end_line - start());
out << std::string(len, '~')
<< (len < size() ? "... <--- HERE" : " <--- HERE");
out << str.substr(end_line, end_highlight - end_line);
if (!str.empty() && str.back() != '\n')
out << "\n";
}
const std::string& file() const {
return *file_;
}
@ -38,27 +74,6 @@ struct CAFFE2_API SourceRange {
size_t end() const {
return end_;
}
std::string str() const {
std::stringstream ss;
highlight(ss);
return ss.str();
}
std::string wrapException(
const std::exception& e,
const std::string& additional = "") {
std::stringstream msg;
msg << "\n" << e.what() << ":\n";
if (!additional.empty()) {
msg << additional << ":\n";
}
highlight(msg);
return msg.str();
}
void wrapAndRethrowException(
const std::exception& e,
const std::string& additional = "") {
throw std::runtime_error(wrapException(e, additional));
}
private:
std::shared_ptr<std::string> file_;
@ -66,10 +81,5 @@ struct CAFFE2_API SourceRange {
size_t end_;
};
inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) {
range.highlight(out);
return out;
}
} // namespace jit
} // namespace torch

View File

@ -1332,7 +1332,7 @@ void loadModule(const script::CompilationUnit& module) {
Node* forward_tuple = pair.forward->outputs().at(0)->node();
if (forward_tuple->kind() != prim::TupleConstruct) {
throw script::ErrorReport(forward_tuple->sourceRange())
throw script::ErrorReport(forward_tuple->getSourceLocation())
<< "gradient must return literal a tuple";
}

View File

@ -519,10 +519,10 @@ def _check_trace(check_inputs, func, executor_options, traced_func, check_tolera
node_diff = difflib.ndiff(str(n_mod).splitlines(True),
str(n_check).splitlines(True))
source_printout = 'Node diff:\n' + indent(''.join(node_diff)) + '\n'
mod_stack = n_mod.sourceRange()
mod_stack = n_mod.getSourceLocation()
if mod_stack:
source_printout += 'Trace source location:\n' + indent(mod_stack) + '\n'
check_stack = n_check.sourceRange()
check_stack = n_check.getSourceLocation()
if check_stack:
source_printout += 'Check source location:\n' + indent(check_stack) + '\n'
graph_diff_errors += source_printout
@ -548,7 +548,7 @@ def _check_trace(check_inputs, func, executor_options, traced_func, check_tolera
if tensor_compare_errors is None:
tensor_compare_errors = ''
tensor_compare_errors += 'Node:\n' + indent(str(n_mod)) + '\n'
compare_stack = n_mod.sourceRange()
compare_stack = n_mod.getSourceLocation()
if compare_stack:
tensor_compare_errors += 'Source Location:\n' + indent(compare_stack) + '\n'
tensor_compare_errors += 'Comparison exception: ' + indent(str(e))