// Python bindings for the script tree-view types, registered in the
// "_jit_tree_views" submodule of the module passed to initTreeViewBindings().
//
// NOTE(review): this text appears to have lost its original line breaks and
// everything that was written between angle brackets (the #include targets and
// the template argument lists of make_shared, shared_ptr, vector, List, Maybe,
// cast, py::class_ and py::init). As laid out here, each `//` comment also
// swallows the remainder of its physical line. Restore from version control
// before building -- TODO confirm.
//
// SourceRangeFactory
//   Holds the source text in a shared pointer and a table
//   (line_len_prefix_sum_) that starts with 0 and is followed by the index of
//   every '\n' found when searching from index 1 onward.
//   create(line, start_col, end_col) increments both columns, decrements the
//   line number (callers count lines from 1), looks the line up with .at()
//   (so the index is bounds-checked) and returns a SourceRange over the shared
//   source at line_start + start_col .. line_start + end_col.
//   NOTE(review): a '\n' at index 0 is never recorded, because the search
//   starts at index 1 -- confirm callers never pass such input.
//
// wrap_list(fallback_pos, vec)
//   Builds a List from vec. The list's range is fallback_pos when vec is
//   empty, otherwise the range of vec.front().
//
// wrap_maybe(fallback_pos, val)
//   Builds a Maybe. A null val yields Maybe::create(fallback_pos); otherwise
//   the Maybe is created from *val using val->range().
//
// initTreeViewBindings(module)
//   Registers SourceRange (highlight, start, end), SourceRangeFactory
//   (make_range, make_raw_range, source), TreeView (range, __str__), Ident,
//   Param, Attribute, the TrueLiteral / FalseLiteral / NoneLiteral functions,
//   Stmt, Expr, Def, Decl, Assign, AugAssign, Return, Raise, Assert, Pass, If,
//   While, For, ExprStmt, Var, BinOp, UnaryOp, Const, StringLiteral, Apply,
//   Select, TernaryIf, ListLiteral, TupleLiteral, Subscript, SliceExpr and
//   Starred.
//   - Initializers that take no explicit SourceRange reuse the range of one of
//     their operands (e.g. name.range(), lhs.range(), expr.range()).
//   - Optional children are passed as raw pointers and converted through
//     wrap_maybe; child sequences are converted through wrap_list.
//   - UnaryOp rewrites a resolved kind equal to '-' to TK_UNARY_MINUS.
//   NOTE(review): the Select initializer binds a local `r` that is never used.
//   NOTE(review): For takes `range` by value and its target/iterator vectors
//   by non-const reference, unlike the sibling bindings -- presumably
//   unintentional; verify.
#include #include #include #include #include #include namespace py = pybind11; namespace torch { namespace jit { namespace script { struct SourceRangeFactory { SourceRangeFactory(std::string source) : source_(std::make_shared(std::move(source))) { size_t pos = 0; do { line_len_prefix_sum_.push_back(pos); pos++; } while ((pos = source_->find('\n', pos)) != std::string::npos); } SourceRange create(int line, int start_col, int end_col) { // Python has a weird convention where col_offset points to the column *before* // the token starts. start_col++; end_col++; // Also, lines are counted from 1. line--; auto line_start = line_len_prefix_sum_.at(line); return SourceRange(source_, line_start + start_col, line_start + end_col); } std::shared_ptr source_; std::vector line_len_prefix_sum_; }; template List wrap_list(const SourceRange& fallback_pos, std::vector&& vec) { if (vec.empty()) return List::create(fallback_pos, std::move(vec)); return List::create(vec.front().range(), std::move(vec)); } template Maybe wrap_maybe(const SourceRange& fallback_pos, T* val) { return val ? 
Maybe::create(val->range(), *val) : Maybe::create(fallback_pos); } void initTreeViewBindings(PyObject *module) { auto _C = py::handle(module).cast(); auto m = _C.def_submodule("_jit_tree_views"); py::class_(m, "SourceRange") .def("highlight", [](const SourceRange& self) { std::ostringstream stream; self.highlight(stream); return stream.str(); }) .def_property_readonly("start", &SourceRange::start) .def_property_readonly("end", &SourceRange::end); py::class_(m, "SourceRangeFactory") .def(py::init()) .def("make_range", &SourceRangeFactory::create) .def("make_raw_range", [](const SourceRangeFactory& self, size_t start, size_t end) { return SourceRange(self.source_, start, end); }) .def_property_readonly("source", [](const SourceRangeFactory& self) { return *self.source_; }); py::class_(m, "TreeView") .def("range", &TreeView::range) .def("__str__", [](const TreeView& tree) { std::ostringstream stream; stream << tree.get(); return stream.str(); }); py::class_(m, "Ident") .def(py::init(&Ident::create)) .def_property_readonly( "name", [](const Ident& self) { return self.name(); }); py::class_(m, "Param") .def(py::init([](const Expr& type, const Ident& name) { return Param::create(name.range(), name, type, Maybe::create(name.range())); })); py::class_(m, "Attribute") .def(py::init([](const Ident& name, const Expr& value) { return Attribute::create(name.range(), name, value); })); m.def("TrueLiteral", [](const SourceRange& range) { return Expr(Compound::create(TK_TRUE, range, {})); }); m.def("FalseLiteral", [](const SourceRange& range) { return Expr(Compound::create(TK_FALSE, range, {})); }); m.def("NoneLiteral", [](const SourceRange& range) { return Expr(Compound::create(TK_NONE, range, {})); }); py::class_(m, "Stmt"); // NOLINT(bugprone-unused-raii) py::class_(m, "Expr"); // NOLINT(bugprone-unused-raii) py::class_(m, "Def") .def(py::init([](const Ident& name, Decl decl, std::vector body) { const auto& r = name.range(); return Def::create(r, name, decl, wrap_list(r, 
std::move(body))); })); py::class_(m, "Decl") .def(py::init([](const SourceRange& r, std::vector params, Expr *return_type) { return Decl::create(r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type)); })); py::class_(m, "Assign") .def(py::init([](const Expr& lhs, const Expr& rhs) { return Assign::create(lhs.range(), lhs, rhs); })); py::class_(m, "AugAssign") .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) { const auto& r = lhs.range(); auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); return AugAssign::create(r, lhs, kind, rhs); })); py::class_(m, "Return") .def(py::init([](const SourceRange& range, std::vector values) { return Return::create(range, wrap_list(range, std::move(values))); })); py::class_(m, "Raise") .def(py::init([](const SourceRange& range, Expr *expr) { return Raise::create(range, wrap_maybe(range, expr)); })); py::class_(m, "Assert") .def(py::init([](const SourceRange& range, const Expr& test, Expr *msg) { return Assert::create(range, test, wrap_maybe(range, msg)); })); py::class_(m, "Pass") .def(py::init([](const SourceRange& range) { return Pass::create(range); })); py::class_(m, "If") .def(py::init([](const SourceRange& range, const Expr& cond, std::vector true_branch, std::vector false_branch) { return If::create(range, cond, wrap_list(range, std::move(true_branch)), wrap_list(range, std::move(false_branch))); })); py::class_(m, "While") .def(py::init([](const SourceRange& range, const Expr& cond, std::vector body) { return While::create(range, cond, wrap_list(range, std::move(body))); })); py::class_(m, "For").def(py::init([](const SourceRange range, std::vector& targets, std::vector& itrs, std::vector body) { return For::create( range, wrap_list(range, std::move(targets)), wrap_list(range, std::move(itrs)), wrap_list(range, std::move(body))); })); py::class_(m, "ExprStmt") .def(py::init([](const Expr& expr) { return ExprStmt::create(expr.range(), expr); })); py::class_(m, "Var") 
.def(py::init([](const Ident& name) { return Var::create(name.range(), name); })) .def_property_readonly("name", [](const Var& var) { return var.name(); }); py::class_(m, "BinOp") .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) { return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs); })); // NB: we take range here, because unary ops precede their exprs, so we need to include them py::class_(m, "UnaryOp") .def(py::init([](const SourceRange& range, std::string kind, const Expr& expr) { auto resolved_kind = stringToKind(kind); resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind; return UnaryOp::create(range, resolved_kind, expr); })); py::class_(m, "Const") .def(py::init([](const SourceRange& range, std::string value) { return Const::create(range, value); })); py::class_(m, "StringLiteral") .def(py::init([](const SourceRange& range, std::string value) { return StringLiteral::create(range, value); })); py::class_(m, "Apply") .def(py::init([](const Expr& expr, std::vector args, std::vector kwargs) { const auto& r = expr.range(); return Apply::create(expr.range(), expr, wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs))); })); py::class_(m, "Select") .def(py::init([](const Expr& expr, const Ident& field) { const auto& r = expr.range(); return Select::create(expr.range(), expr, field); })); py::class_(m, "TernaryIf") .def(py::init([](const Expr& cond, const Expr& true_expr, const Expr& false_expr) { return TernaryIf::create(cond.range(), cond, true_expr, false_expr); })); py::class_(m, "ListLiteral") .def(py::init([](const SourceRange& range, std::vector args) { return ListLiteral::create(range, wrap_list(range, std::move(args))); })); py::class_(m, "TupleLiteral") .def(py::init([](const SourceRange& range, std::vector args) { return TupleLiteral::create(range, wrap_list(range, std::move(args))); })); py::class_(m, "Subscript") .def(py::init([](const Expr& base, std::vector subscript_exprs) { return 
Subscript::create(base.range(), base, wrap_list(base.range(), std::move(subscript_exprs))); })); py::class_(m, "SliceExpr") .def(py::init([](const SourceRange& range, Expr *lower, Expr *upper) { return SliceExpr::create(range, wrap_maybe(range, lower), wrap_maybe(range, upper)); })); py::class_(m, "Starred") .def(py::init([](const SourceRange& range, Expr expr){ return Starred::create(range, expr); })); } }}} // namespace torch::jit::script