mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D15275731: Remote SourceLocation
Differential Revision: D15275731 Original commit changeset: f4da178c3137 fbshipit-source-id: 830b79735eb2dadc4795b5aae407826bf20ef121
This commit is contained in:
committed by
Facebook Github Bot
parent
eca91de5d2
commit
e870b11ae6
@ -9,6 +9,7 @@ EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
|
||||
# Add files needed from jit folders
|
||||
LIST(APPEND ATen_CORE_HEADERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_range.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_location.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/function_schema_parser.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/lexer.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/strtod.h
|
||||
@ -22,7 +23,6 @@ LIST(APPEND ATen_CORE_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/lexer.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/strtod.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/schema_type_parser.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_range.cpp
|
||||
)
|
||||
|
||||
# Pass to parent
|
||||
|
@ -85,7 +85,7 @@ c10::optional<Value*> tryInsertConstant(
|
||||
return c10::nullopt;
|
||||
}
|
||||
if (loc)
|
||||
n->setSourceRange(*loc);
|
||||
n->setSourceLocation(std::make_shared<SourceRange>(*loc));
|
||||
if (scope)
|
||||
n->setScope(*scope);
|
||||
if (result_type) {
|
||||
|
@ -38,7 +38,13 @@ namespace onnx = ::ONNX_NAMESPACE;
|
||||
class ScriptModuleSerializer;
|
||||
|
||||
std::string getNodeStackTraceString(const Node* n) {
|
||||
return n->sourceRange().str();
|
||||
std::stringstream ss;
|
||||
if (n->getSourceLocation()) {
|
||||
n->getSourceLocation()->highlight(ss);
|
||||
} else {
|
||||
ss << "<unknown location>";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void validateBlock(
|
||||
@ -252,8 +258,10 @@ void EncoderBase::EncodeBlock(
|
||||
continue;
|
||||
}
|
||||
auto p_n = graph_proto->add_node();
|
||||
if (!strip_doc_) {
|
||||
p_n->set_doc_string(node->sourceRange().str());
|
||||
if (node->getSourceLocation() && !strip_doc_) {
|
||||
std::stringstream ss;
|
||||
node->getSourceLocation()->highlight(ss);
|
||||
p_n->set_doc_string(ss.str());
|
||||
}
|
||||
for (auto input : node->inputs()) {
|
||||
if (input->node()->mustBeNone() && !is_raw_export) {
|
||||
|
@ -330,7 +330,7 @@ struct Instruction {
|
||||
UseList inputs;
|
||||
ListHandle<int> outputs;
|
||||
Symbol debug_name; // used in dump to understand the generated code
|
||||
c10::optional<SourceRange> debug_location; // for error reporting
|
||||
std::shared_ptr<SourceLocation> debug_location; // for error reporting
|
||||
};
|
||||
|
||||
int relativeJump(int from_inst, int to_inst) {
|
||||
@ -377,7 +377,7 @@ struct CodeImpl {
|
||||
|
||||
void insertNodesFromBlock(Block* block) {
|
||||
for (auto node : block->nodes()) {
|
||||
SourceRange source_location = node->sourceRange();
|
||||
const auto& source_location = node->getSourceLocation();
|
||||
switch (node->kind()) {
|
||||
case prim::If: {
|
||||
// x = if c:
|
||||
@ -481,7 +481,7 @@ struct CodeImpl {
|
||||
size_t insertInstruction(Node* n) {
|
||||
auto inst = insertInstruction(
|
||||
n->kind(),
|
||||
n->sourceRange(),
|
||||
n->getSourceLocation(),
|
||||
n->inputs(),
|
||||
moveFlags(n),
|
||||
n->outputs());
|
||||
@ -490,7 +490,7 @@ struct CodeImpl {
|
||||
}
|
||||
size_t insertInstruction(
|
||||
Symbol sym,
|
||||
const SourceRange& debug_location,
|
||||
std::shared_ptr<SourceLocation> debug_location,
|
||||
ArrayRef<Value*> inputs,
|
||||
ArrayRef<uint8_t> move_flags,
|
||||
ArrayRef<Value*> outputs) {
|
||||
@ -520,7 +520,7 @@ struct CodeImpl {
|
||||
}
|
||||
|
||||
size_t insertAssign(
|
||||
const SourceRange& debug_location,
|
||||
std::shared_ptr<SourceLocation> debug_location,
|
||||
ArrayRef<Value*> inputs,
|
||||
ArrayRef<uint8_t> move_flags,
|
||||
ArrayRef<Value*> outputs) {
|
||||
@ -713,10 +713,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
} catch (std::exception& e) {
|
||||
// Error from the current thread
|
||||
bool is_jit_exception = dynamic_cast<JITException*>(&e);
|
||||
handleError(
|
||||
instructions[pc].debug_location->wrapException(
|
||||
e, "operation failed in interpreter"),
|
||||
is_jit_exception);
|
||||
if (instructions[pc].debug_location) {
|
||||
handleError(
|
||||
instructions[pc].debug_location->wrapException(
|
||||
e, "operation failed in interpreter"),
|
||||
is_jit_exception);
|
||||
} else {
|
||||
handleError(e.what(), is_jit_exception);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -200,14 +200,6 @@ void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
|
||||
out << "]";
|
||||
}
|
||||
|
||||
SourceRange Node::sourceRange() const {
|
||||
if(source_range_) {
|
||||
return *source_range_;
|
||||
}
|
||||
std::stringstream ss;
|
||||
return SourceRange(ss.str());
|
||||
}
|
||||
|
||||
static std::ostream& indent(std::ostream& out, size_t level) {
|
||||
for (size_t i = 0; i < level; ++i) {
|
||||
out << " ";
|
||||
@ -232,10 +224,8 @@ std::ostream& Node::print(
|
||||
if (numAttributes() > 1 && kind() != prim::DifferentiableGraph) {
|
||||
printAttributes(out, /*ignore_subgraph=*/true);
|
||||
}
|
||||
|
||||
groups->push_back(this);
|
||||
} else {
|
||||
|
||||
out << kind().toQualString();
|
||||
if (hasAttributes()) {
|
||||
printAttributes(out);
|
||||
@ -251,7 +241,6 @@ std::ostream& Node::print(
|
||||
out << ", ";
|
||||
out << "scope: " << scName << "\n";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < blocks().size(); ++i) {
|
||||
auto b = blocks()[i];
|
||||
indent(out, level + 1) << "block" << i << "("
|
||||
@ -262,7 +251,6 @@ std::ostream& Node::print(
|
||||
}
|
||||
indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -985,7 +973,7 @@ void Node::destroy() {
|
||||
}
|
||||
|
||||
void Node::cloneFrom(Node* s) {
|
||||
s->source_range_ = s->source_range_;
|
||||
setSourceLocation(s->getSourceLocation());
|
||||
if (s->scope_ && !s->scope_->isBlank()) {
|
||||
scope_ = s->scope_;
|
||||
}
|
||||
|
@ -249,7 +249,7 @@ struct TORCH_API Node {
|
||||
std::vector<Block*> blocks_;
|
||||
Graph* graph_;
|
||||
Block* owning_block_;
|
||||
c10::optional<SourceRange> source_range_;
|
||||
std::shared_ptr<SourceLocation> source_location_;
|
||||
ScopePtr scope_;
|
||||
// Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
|
||||
// This field is effective a cache that's populated on attribute lookups and
|
||||
@ -287,12 +287,13 @@ struct TORCH_API Node {
|
||||
NodeKind kind() const {
|
||||
return kind_;
|
||||
}
|
||||
Node* setSourceRange(SourceRange r) {
|
||||
source_range_ = std::move(r);
|
||||
Node* setSourceLocation(std::shared_ptr<SourceLocation> sl) {
|
||||
source_location_ = std::move(sl);
|
||||
return this;
|
||||
}
|
||||
SourceRange sourceRange() const;
|
||||
|
||||
std::shared_ptr<SourceLocation> getSourceLocation() const {
|
||||
return source_location_;
|
||||
}
|
||||
Graph* owningGraph() {
|
||||
return graph_;
|
||||
}
|
||||
|
@ -237,7 +237,7 @@ const Operator& getOperatorFor(const Node* node) {
|
||||
if (op)
|
||||
return *op;
|
||||
|
||||
auto er = script::ErrorReport(node->sourceRange());
|
||||
auto er = script::ErrorReport(node->getSourceLocation());
|
||||
er << "Schema not found for node. File a bug report.\n";
|
||||
er << "Node: " << *node << "\n";
|
||||
er << "Input types:";
|
||||
|
@ -464,7 +464,7 @@ void AliasDb::analyzeImpl(Node* node) {
|
||||
// We don't have alias info for this node. Either schematize it, or
|
||||
// add it an analyze* method for it.
|
||||
if (hasMutableOutputs) {
|
||||
throw script::ErrorReport(node->sourceRange())
|
||||
throw script::ErrorReport(node->getSourceLocation())
|
||||
<< "Alias information not found for node. File a bug report.\n"
|
||||
<< "Node: " << *node << "\n";
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
|
||||
auto maybe_int = constant_as<int64_t>(idx);
|
||||
if (!maybe_int) {
|
||||
if (must_remove_tuples) {
|
||||
AT_ERROR(n->sourceRange(), "tuple index with non-constant index");
|
||||
AT_ERROR(n->getSourceLocation(), "tuple index with non-constant index");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -211,7 +211,7 @@ void BlockToONNX(
|
||||
outputs[i]->setType(old->type());
|
||||
// Copy over source location and scope information to all nodes
|
||||
// created by the symbolic
|
||||
outputs[i]->node()->setSourceRange(node->sourceRange());
|
||||
outputs[i]->node()->setSourceLocation(node->getSourceLocation());
|
||||
outputs[i]->node()->setScope(node->scope());
|
||||
env[old] = outputs[i];
|
||||
} else {
|
||||
|
@ -531,7 +531,7 @@ struct PythonPrintPass {
|
||||
// this must be a while loop, but check that there isn't _also_ a trip
|
||||
// count
|
||||
if (trip_count_is_specified) {
|
||||
throw script::ErrorReport(stmt.node()->sourceRange())
|
||||
throw script::ErrorReport(stmt.node()->getSourceLocation())
|
||||
<< "loop cannot be printed as python "
|
||||
<< "because it has gone through an optimization "
|
||||
<< "that combined while and for loops. File a bug.";
|
||||
@ -678,7 +678,7 @@ struct PythonPrintPass {
|
||||
switch (node->kind()) {
|
||||
case prim::Return:
|
||||
if (enforce_importable_ && node->inputs().size() != 1) {
|
||||
throw script::ErrorReport(node->sourceRange())
|
||||
throw script::ErrorReport(node->getSourceLocation())
|
||||
<< "Exportable methods must have a single return value. "
|
||||
<< "Normal use of ScriptMethods should enforce this.";
|
||||
}
|
||||
@ -733,7 +733,7 @@ struct PythonPrintPass {
|
||||
} break;
|
||||
case prim::Function: {
|
||||
if (enforce_importable_) {
|
||||
throw script::ErrorReport(node->sourceRange())
|
||||
throw script::ErrorReport(node->getSourceLocation())
|
||||
<< "closures are not exportable";
|
||||
}
|
||||
assignValuesToTheirUniqueNames(node->outputs());
|
||||
@ -850,7 +850,7 @@ struct PythonPrintPass {
|
||||
case prim::PythonOp: {
|
||||
auto value = static_cast<const PythonOp*>(node);
|
||||
if (enforce_importable_) {
|
||||
throw script::ErrorReport(node->sourceRange())
|
||||
throw script::ErrorReport(node->getSourceLocation())
|
||||
<< "could not export python function call " << value->name()
|
||||
<< ". Remove calls to Python functions before export. "
|
||||
<< "Did you forget add @script or @script_method annotation? "
|
||||
|
@ -73,7 +73,11 @@ class ShapePropagator {
|
||||
} catch (propagation_error& e) {
|
||||
setUnshapedType(node);
|
||||
} catch (std::exception& e) {
|
||||
node->sourceRange().wrapAndRethrowException(e, "operation failed shape propagation");
|
||||
if (auto sl = node->getSourceLocation()) {
|
||||
sl->wrapAndRethrowException(e, "operation failed shape propagation");
|
||||
} else {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -439,9 +439,15 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
return ss.str();
|
||||
})
|
||||
.def(
|
||||
"sourceRange",
|
||||
[](Node& n) {
|
||||
return n.sourceRange().str();
|
||||
"getSourceLocation",
|
||||
[](Node& n) -> py::object {
|
||||
std::stringstream ss;
|
||||
if (auto sl = n.getSourceLocation()) {
|
||||
sl->highlight(ss);
|
||||
return py::str(ss.str());
|
||||
} else {
|
||||
return py::none();
|
||||
}
|
||||
})
|
||||
.def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
|
||||
.def("outputsSize", [](Node& n) { return n.outputs().size(); })
|
||||
|
@ -101,7 +101,9 @@ Node* preRecordPythonTrace(
|
||||
}
|
||||
|
||||
void pythonRecordSourceLocation(Node* n) {
|
||||
n->setSourceRange(SourceRange(getPythonInterpreterStackTrace()));
|
||||
auto sl =
|
||||
std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
|
||||
n->setSourceLocation(sl);
|
||||
}
|
||||
|
||||
void pythonWarn(const std::string& reason) {
|
||||
|
@ -22,7 +22,7 @@ namespace {
|
||||
void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
|
||||
if (!elem_type->isSubtypeOf(NumberType::get()) &&
|
||||
elem_type != BoolType::get()) {
|
||||
auto error = script::ErrorReport(node->sourceRange());
|
||||
auto error = script::ErrorReport(node->getSourceLocation());
|
||||
error << "Input list to torch.tensor must be of ints, floats, or bools, "
|
||||
<< "got " << elem_type->str();
|
||||
// special case empty list torch.tensor([])
|
||||
|
@ -953,7 +953,7 @@ struct to_ir {
|
||||
|
||||
Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
|
||||
return graph->create(kind, n_outputs)
|
||||
->setSourceRange(loc);
|
||||
->setSourceLocation(std::make_shared<SourceRange>(loc));
|
||||
}
|
||||
|
||||
Value* emitTernaryIf(const TernaryIf& expr) {
|
||||
@ -2379,7 +2379,7 @@ struct to_ir {
|
||||
auto fork_node =
|
||||
method.graph()
|
||||
->insertNode(method.graph()->create(prim::fork, 1))
|
||||
->setSourceRange(loc);
|
||||
->setSourceLocation(std::make_shared<SourceRange>(loc));
|
||||
auto body_block = fork_node->addBlock();
|
||||
|
||||
// Build a template of the graph to be executed
|
||||
|
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/script/tree.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -11,15 +10,17 @@ struct ErrorReport : public std::exception {
|
||||
ErrorReport(const ErrorReport& e)
|
||||
: ss(e.ss.str()), context(e.context), the_message(e.the_message) {}
|
||||
|
||||
ErrorReport() : context(c10::nullopt) {}
|
||||
explicit ErrorReport(SourceRange r)
|
||||
: context(std::move(r)) {}
|
||||
ErrorReport() : context(nullptr) {}
|
||||
explicit ErrorReport(const SourceRange& r)
|
||||
: context(std::make_shared<SourceRange>(r)) {}
|
||||
explicit ErrorReport(std::shared_ptr<SourceLocation> loc)
|
||||
: context(std::move(loc)) {}
|
||||
explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
|
||||
explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
|
||||
const char* what() const noexcept override {
|
||||
std::stringstream msg;
|
||||
msg << "\n" << ss.str();
|
||||
if (context) {
|
||||
if (context != nullptr) {
|
||||
msg << ":\n";
|
||||
context->highlight(msg);
|
||||
} else {
|
||||
@ -34,7 +35,7 @@ struct ErrorReport : public std::exception {
|
||||
friend const ErrorReport& operator<<(const ErrorReport& e, const T& t);
|
||||
|
||||
mutable std::stringstream ss;
|
||||
c10::optional<SourceRange> context;
|
||||
std::shared_ptr<SourceLocation> context;
|
||||
mutable std::string the_message;
|
||||
};
|
||||
|
||||
|
@ -194,7 +194,7 @@ std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
|
||||
continue;
|
||||
}
|
||||
if (e.n->kind() != prim::GetAttr) {
|
||||
throw ErrorReport(e.n->sourceRange())
|
||||
throw ErrorReport(e.n->getSourceLocation())
|
||||
<< "temporary: the only valid use of a module is looking up an attribute";
|
||||
}
|
||||
Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));
|
||||
|
@ -110,7 +110,7 @@ std::shared_ptr<SugaredValue> PythonValue::call(
|
||||
auto python_op = static_cast<PythonOp*>(new_node);
|
||||
python_op->ignore_on_export = true;
|
||||
}
|
||||
new_node->setSourceRange(loc);
|
||||
new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
|
||||
for (auto& i : matched_schema->inputs)
|
||||
new_node->addInput(i);
|
||||
|
||||
|
@ -100,7 +100,7 @@ Value* tryConvertToType(
|
||||
value->type()->isSubtypeOf(TensorType::get())) {
|
||||
auto n = graph.createImplicitTensorToNum(concrete_type, value);
|
||||
value = graph.insertNode(n)
|
||||
->setSourceRange(loc)
|
||||
->setSourceLocation(std::make_shared<SourceRange>(loc))
|
||||
->output();
|
||||
}
|
||||
if (value->type()->isSubtypeOf(StringType::get()) &&
|
||||
@ -355,7 +355,7 @@ static Value* emitBuiltinNode(
|
||||
Graph& graph,
|
||||
Symbol name) {
|
||||
auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
|
||||
->setSourceRange(loc);
|
||||
->setSourceLocation(std::make_shared<SourceRange>(loc));
|
||||
|
||||
for (auto& ret : matched_schema.return_types) {
|
||||
n->addOutput()->setType(ret);
|
||||
|
@ -38,7 +38,7 @@ std::shared_ptr<SugaredValue> PrintValue::call(
|
||||
lowered_inputs.erase(lowered_inputs.begin());
|
||||
}
|
||||
g.insertNode(g.create(prim::Print, lowered_inputs, 0)
|
||||
->setSourceRange(loc));
|
||||
->setSourceLocation(std::make_shared<SourceRange>(loc)));
|
||||
return std::make_shared<NoneValue>();
|
||||
}
|
||||
|
||||
|
54
torch/csrc/jit/source_location.h
Normal file
54
torch/csrc/jit/source_location.h
Normal file
@ -0,0 +1,54 @@
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
// SourceLocation represents source code-level debug information for a node.
|
||||
// It contains information about where a node got generated.
|
||||
// In the case of tracing this will be a python stack trace.
|
||||
// In the case of using the scripting frontend this will be backed
|
||||
// by a SourceRange object
|
||||
struct SourceLocation {
|
||||
virtual ~SourceLocation() = default;
|
||||
virtual void highlight(std::ostream& out) const = 0;
|
||||
|
||||
std::string wrapException(
|
||||
const std::exception& e,
|
||||
const std::string& additional = "") {
|
||||
std::stringstream msg;
|
||||
msg << "\n" << e.what() << ":\n";
|
||||
if (!additional.empty()) {
|
||||
msg << additional << ":\n";
|
||||
}
|
||||
highlight(msg);
|
||||
return msg.str();
|
||||
}
|
||||
void wrapAndRethrowException(
|
||||
const std::exception& e,
|
||||
const std::string& additional = "") {
|
||||
throw std::runtime_error(wrapException(e, additional));
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
|
||||
sl.highlight(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// normally a python stack trace
|
||||
struct StringSourceLocation : public SourceLocation {
|
||||
StringSourceLocation(std::string context) : context(std::move(context)) {}
|
||||
void highlight(std::ostream& out) const override {
|
||||
out << context;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string context;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -1,56 +0,0 @@
|
||||
#include <torch/csrc/jit/source_range.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// a range of a shared string 'file_' with
|
||||
C10_EXPORT void SourceRange::highlight(std::ostream& out) const {
|
||||
if (size() == file_->size()) {
|
||||
// this is just the entire file, not a subset, so print it out.
|
||||
// primarily used to print out python stack traces
|
||||
out << *file_;
|
||||
return;
|
||||
}
|
||||
|
||||
const std::string& str = file();
|
||||
size_t begin_line = start(); // beginning of line to highlight
|
||||
size_t end_line = start(); // end of line to highlight
|
||||
while (begin_line > 0 && str[begin_line - 1] != '\n')
|
||||
--begin_line;
|
||||
while (end_line < str.size() && str[end_line] != '\n')
|
||||
++end_line;
|
||||
AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
|
||||
AT_ASSERT(end_line == str.size() || str[end_line] == '\n');
|
||||
|
||||
size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines
|
||||
// before the highlight line
|
||||
for (size_t i = 0; begin_highlight > 0; --begin_highlight) {
|
||||
if (str[begin_highlight - 1] == '\n')
|
||||
++i;
|
||||
if (i >= CONTEXT)
|
||||
break;
|
||||
}
|
||||
AT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n');
|
||||
|
||||
size_t end_highlight =
|
||||
end_line; // end of context, CONTEXT lines after the highlight line
|
||||
for (size_t i = 0; end_highlight < str.size(); ++end_highlight) {
|
||||
if (str[end_highlight] == '\n')
|
||||
++i;
|
||||
if (i >= CONTEXT)
|
||||
break;
|
||||
}
|
||||
AT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n');
|
||||
|
||||
out << str.substr(begin_highlight, end_line - begin_highlight) << "\n";
|
||||
out << std::string(start() - begin_line, ' ');
|
||||
size_t len = std::min(size(), end_line - start());
|
||||
out << std::string(len, '~')
|
||||
<< (len < size() ? "... <--- HERE" : " <--- HERE");
|
||||
out << str.substr(end_line, end_highlight - end_line);
|
||||
if (!str.empty() && str.back() != '\n')
|
||||
out << "\n";
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -1,31 +1,67 @@
|
||||
#pragma once
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/jit/source_location.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// a range of a shared string 'file_' with functions to help debug by highlight
|
||||
// that
|
||||
// range.
|
||||
struct CAFFE2_API SourceRange {
|
||||
struct SourceRange : public SourceLocation {
|
||||
SourceRange(std::shared_ptr<std::string> file_, size_t start_, size_t end_)
|
||||
: file_(std::move(file_)), start_(start_), end_(end_) {}
|
||||
explicit SourceRange(std::string string_range)
|
||||
: file_(std::make_shared<std::string>(std::move(string_range))),
|
||||
start_(0),
|
||||
end_(file_->size()) {}
|
||||
|
||||
const std::string text() const {
|
||||
return file().substr(start(), end() - start());
|
||||
}
|
||||
size_t size() const {
|
||||
return end() - start();
|
||||
}
|
||||
|
||||
static const size_t CONTEXT = 10;
|
||||
void highlight(std::ostream& out) const;
|
||||
void highlight(std::ostream& out) const override {
|
||||
const std::string& str = file();
|
||||
size_t begin_line = start(); // beginning of line to highlight
|
||||
size_t end_line = start(); // end of line to highlight
|
||||
while (begin_line > 0 && str[begin_line - 1] != '\n')
|
||||
--begin_line;
|
||||
while (end_line < str.size() && str[end_line] != '\n')
|
||||
++end_line;
|
||||
AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
|
||||
AT_ASSERT(end_line == str.size() || str[end_line] == '\n');
|
||||
|
||||
size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines
|
||||
// before the highlight line
|
||||
for (size_t i = 0; begin_highlight > 0; --begin_highlight) {
|
||||
if (str[begin_highlight - 1] == '\n')
|
||||
++i;
|
||||
if (i >= CONTEXT)
|
||||
break;
|
||||
}
|
||||
AT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n');
|
||||
|
||||
size_t end_highlight =
|
||||
end_line; // end of context, CONTEXT lines after the highlight line
|
||||
for (size_t i = 0; end_highlight < str.size(); ++end_highlight) {
|
||||
if (str[end_highlight] == '\n')
|
||||
++i;
|
||||
if (i >= CONTEXT)
|
||||
break;
|
||||
}
|
||||
AT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n');
|
||||
|
||||
out << str.substr(begin_highlight, end_line - begin_highlight) << "\n";
|
||||
out << std::string(start() - begin_line, ' ');
|
||||
size_t len = std::min(size(), end_line - start());
|
||||
out << std::string(len, '~')
|
||||
<< (len < size() ? "... <--- HERE" : " <--- HERE");
|
||||
out << str.substr(end_line, end_highlight - end_line);
|
||||
if (!str.empty() && str.back() != '\n')
|
||||
out << "\n";
|
||||
}
|
||||
const std::string& file() const {
|
||||
return *file_;
|
||||
}
|
||||
@ -38,27 +74,6 @@ struct CAFFE2_API SourceRange {
|
||||
size_t end() const {
|
||||
return end_;
|
||||
}
|
||||
std::string str() const {
|
||||
std::stringstream ss;
|
||||
highlight(ss);
|
||||
return ss.str();
|
||||
}
|
||||
std::string wrapException(
|
||||
const std::exception& e,
|
||||
const std::string& additional = "") {
|
||||
std::stringstream msg;
|
||||
msg << "\n" << e.what() << ":\n";
|
||||
if (!additional.empty()) {
|
||||
msg << additional << ":\n";
|
||||
}
|
||||
highlight(msg);
|
||||
return msg.str();
|
||||
}
|
||||
void wrapAndRethrowException(
|
||||
const std::exception& e,
|
||||
const std::string& additional = "") {
|
||||
throw std::runtime_error(wrapException(e, additional));
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::string> file_;
|
||||
@ -66,10 +81,5 @@ struct CAFFE2_API SourceRange {
|
||||
size_t end_;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) {
|
||||
range.highlight(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -1332,7 +1332,7 @@ void loadModule(const script::CompilationUnit& module) {
|
||||
Node* forward_tuple = pair.forward->outputs().at(0)->node();
|
||||
|
||||
if (forward_tuple->kind() != prim::TupleConstruct) {
|
||||
throw script::ErrorReport(forward_tuple->sourceRange())
|
||||
throw script::ErrorReport(forward_tuple->getSourceLocation())
|
||||
<< "gradient must return literal a tuple";
|
||||
}
|
||||
|
||||
|
@ -519,10 +519,10 @@ def _check_trace(check_inputs, func, executor_options, traced_func, check_tolera
|
||||
node_diff = difflib.ndiff(str(n_mod).splitlines(True),
|
||||
str(n_check).splitlines(True))
|
||||
source_printout = 'Node diff:\n' + indent(''.join(node_diff)) + '\n'
|
||||
mod_stack = n_mod.sourceRange()
|
||||
mod_stack = n_mod.getSourceLocation()
|
||||
if mod_stack:
|
||||
source_printout += 'Trace source location:\n' + indent(mod_stack) + '\n'
|
||||
check_stack = n_check.sourceRange()
|
||||
check_stack = n_check.getSourceLocation()
|
||||
if check_stack:
|
||||
source_printout += 'Check source location:\n' + indent(check_stack) + '\n'
|
||||
graph_diff_errors += source_printout
|
||||
@ -548,7 +548,7 @@ def _check_trace(check_inputs, func, executor_options, traced_func, check_tolera
|
||||
if tensor_compare_errors is None:
|
||||
tensor_compare_errors = ''
|
||||
tensor_compare_errors += 'Node:\n' + indent(str(n_mod)) + '\n'
|
||||
compare_stack = n_mod.sourceRange()
|
||||
compare_stack = n_mod.getSourceLocation()
|
||||
if compare_stack:
|
||||
tensor_compare_errors += 'Source Location:\n' + indent(compare_stack) + '\n'
|
||||
tensor_compare_errors += 'Comparison exception: ' + indent(str(e))
|
||||
|
Reference in New Issue
Block a user