#include #include #include #include #include #include namespace torch::jit { Decl mergeTypesFromTypeComment( const Decl& decl, const Decl& type_annotation_decl, bool is_method) { auto expected_num_annotations = decl.params().size(); if (is_method) { // `self` argument expected_num_annotations -= 1; } if (expected_num_annotations != type_annotation_decl.params().size()) { throw ErrorReport(decl.range()) << "Number of type annotations (" << type_annotation_decl.params().size() << ") did not match the number of " << (is_method ? "method" : "function") << " parameters (" << expected_num_annotations << ")"; } auto old = decl.params(); auto _new = type_annotation_decl.params(); // Merge signature idents and ranges with annotation types std::vector new_params; size_t i = is_method ? 1 : 0; size_t j = 0; if (is_method) { new_params.push_back(old[0]); } for (; i < decl.params().size(); ++i, ++j) { new_params.emplace_back(old[i].withType(_new[j].type())); } return Decl::create( decl.range(), List::create(decl.range(), new_params), type_annotation_decl.return_type()); } struct ParserImpl { explicit ParserImpl(const std::shared_ptr& source) : L(source), shared(sharedParserData()) {} Ident parseIdent() { auto t = L.expect(TK_IDENT); // whenever we parse something that has a TreeView type we always // use its create method so that the accessors and the constructor // of the Compound tree are in the same place. return Ident::create(t.range, t.text()); } TreeRef createApply(const Expr& expr) { TreeList attributes; auto range = L.cur().range; TreeList inputs; parseArguments(inputs, attributes); return Apply::create( range, expr, List(makeList(range, std::move(inputs))), List(makeList(range, std::move(attributes)))); } static bool followsTuple(int kind) { switch (kind) { case TK_PLUS_EQ: case TK_MINUS_EQ: case TK_TIMES_EQ: case TK_DIV_EQ: case TK_MOD_EQ: case TK_BIT_OR_EQ: case TK_BIT_AND_EQ: case TK_BIT_XOR_EQ: case TK_LSHIFT_EQ: case TK_RSHIFT_EQ: case TK_POW_EQ: case TK_NEWLINE: case '=': case ')': return true; default: return false; } } // exp | expr, | expr, expr, ... Expr parseExpOrExpTuple() { auto prefix = parseExp(); if (L.cur().kind == ',') { std::vector exprs = {prefix}; while (L.nextIf(',')) { if (followsTuple(L.cur().kind)) break; exprs.push_back(parseExp()); } auto list = List::create(prefix.range(), exprs); prefix = TupleLiteral::create(list.range(), list); } return prefix; } // things like a 1.0 or a(4) that are not unary/binary expressions // and have higher precedence than all of them TreeRef parseBaseExp() { TreeRef prefix; switch (L.cur().kind) { case TK_NUMBER: { prefix = parseConst(); } break; case TK_TRUE: case TK_FALSE: case TK_NONE: case TK_NONE_TYPE: { auto k = L.cur().kind; auto r = L.cur().range; prefix = create_compound(k, r, {}); L.next(); } break; case '(': { L.next(); if (L.nextIf(')')) { /// here we have the empty tuple case std::vector vecExpr; List listExpr = List::create(L.cur().range, vecExpr); prefix = TupleLiteral::create(L.cur().range, listExpr); break; } prefix = parseExpOrExpTuple(); L.expect(')'); } break; case '[': { auto list = parseList('[', ',', ']', &ParserImpl::parseExp); if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) { prefix = *list.begin(); } else { for (auto se : list) { if (se.kind() == TK_LIST_COMP) { throw ErrorReport(list.range()) << " expected a single list comprehension within '[' , ']'"; } } prefix = ListLiteral::create(list.range(), List(list)); } } break; case '{': { L.next(); // If we have a dict literal, `keys` and `values` will store the keys // and values used in the object's construction. EDGE CASE: We have a // dict comprehension, so we'll get the first element of the dict // comprehension in `keys` and a list comprehension in `values`. // For example, `{i : chr(i + 65) for i in range(4)}` would give us // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`. // The optimal way of handling this case is to simply splice the new // dict comprehension together from the existing list comprehension. // Splicing prevents breaking changes to our API and does not require // the use of global variables. std::vector keys; std::vector values; auto range = L.cur().range; if (L.cur().kind != '}') { do { keys.push_back(parseExp()); L.expect(':'); values.push_back(parseExp()); } while (L.nextIf(',')); } L.expect('}'); if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) { ListComp lc(*values.begin()); prefix = DictComp::create( range, *keys.begin(), lc.elt(), lc.target(), lc.iter()); } else { prefix = DictLiteral::create( range, List::create(range, keys), List::create(range, values)); } } break; case TK_STRINGLITERAL: { prefix = parseConcatenatedStringLiterals(); } break; case TK_ELLIPSIS: case TK_DOTS: { prefix = Dots::create(L.cur().range); L.next(); } break; default: { Ident name = parseIdent(); prefix = Var::create(name.range(), name); } break; } while (true) { if (L.nextIf('.')) { const auto name = parseIdent(); prefix = Select::create(name.range(), Expr(prefix), Ident(name)); } else if (L.cur().kind == '(') { prefix = createApply(Expr(prefix)); } else if (L.cur().kind == '[') { prefix = parseSubscript(prefix); } else { break; } } return prefix; } std::optional maybeParseAssignmentOp() { auto r = L.cur().range; switch (L.cur().kind) { case TK_PLUS_EQ: case TK_MINUS_EQ: case TK_TIMES_EQ: case TK_DIV_EQ: case TK_BIT_OR_EQ: case TK_BIT_AND_EQ: case TK_BIT_XOR_EQ: case TK_MOD_EQ: { int modifier = L.next().text()[0]; return create_compound(modifier, r, {}); } break; case TK_LSHIFT_EQ: { L.next(); return create_compound(TK_LSHIFT, r, {}); } break; case TK_RSHIFT_EQ: { L.next(); return create_compound(TK_RSHIFT, r, {}); } break; case TK_POW_EQ: { L.next(); return create_compound(TK_POW, r, {}); } break; case '=': { L.next(); return create_compound('=', r, {}); // no reduction } break; default: return std::nullopt; } } TreeRef parseTrinary( TreeRef true_branch, const SourceRange& range, int binary_prec) { auto cond = parseExp(); L.expect(TK_ELSE); auto false_branch = parseExp(binary_prec); return create_compound( TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch}); } // parse the longest expression whose binary operators have // precedence strictly greater than 'precedence' // precedence == 0 will parse _all_ expressions // this is the core loop of 'top-down precedence parsing' Expr parseExp() { return parseExp(0); } Expr parseExp(int precedence) { TreeRef prefix; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int unary_prec; if (shared.isUnary(L.cur().kind, &unary_prec)) { auto kind = L.cur().kind; auto pos = L.cur().range; L.next(); auto unary_kind = kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind; auto subexp = parseExp(unary_prec); // fold '-' into constant numbers, so that attributes can accept // things like -1 if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) { prefix = Const::create(subexp.range(), "-" + Const(subexp).text()); } else { prefix = create_compound(unary_kind, pos, {subexp}); } } else { prefix = parseBaseExp(); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int binary_prec; while (shared.isBinary(L.cur().kind, &binary_prec)) { if (binary_prec <= precedence) // not allowed to parse something which is // not greater than 'precedence' break; int kind = L.cur().kind; auto pos = L.cur().range; L.next(); if (shared.isRightAssociative(kind)) binary_prec--; if (kind == TK_NOTIN) { // NB: `not in` is just `not( in )`, so we don't introduce new tree view // but just make it a nested call in our tree view structure prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)}); prefix = create_compound(TK_NOT, pos, {prefix}); continue; } // special case for trinary operator if (kind == TK_IF) { prefix = parseTrinary(prefix, pos, binary_prec); continue; } if (kind == TK_FOR) { // TK_FOR targets should only parse exprs prec greater than 4, which // only includes subset of Exprs that suppose to be on the LHS according // to the python grammar // https://docs.python.org/3/reference/grammar.html auto target = parseLHSExp(); L.expect(TK_IN); auto iter = parseExp(); prefix = ListComp::create(pos, Expr(prefix), target, iter); continue; } prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)}); } return Expr(prefix); } void parseSequence( int begin, int sep, int end, const std::function& parse) { if (begin != TK_NOTHING) { L.expect(begin); } while (end != L.cur().kind) { parse(); if (!L.nextIf(sep)) { if (end != TK_NOTHING) { L.expect(end); } return; } } L.expect(end); } template List parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) { auto r = L.cur().range; std::vector elements; parseSequence( begin, sep, end, [&] { elements.emplace_back((this->*parse)()); }); return List::create(r, elements); } Const parseConst() { auto range = L.cur().range; auto t = L.expect(TK_NUMBER); return Const::create(t.range, t.text()); } StringLiteral parseConcatenatedStringLiterals() { auto range = L.cur().range; std::string ss; while (L.cur().kind == TK_STRINGLITERAL) { auto literal_range = L.cur().range; ss.append(parseStringLiteral(literal_range, L.next().text())); } return StringLiteral::create(range, ss); } Expr parseAttributeValue() { return parseExp(); } void parseArguments(TreeList& inputs, TreeList& attributes) { parseSequence('(', ',', ')', [&] { if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') { auto ident = parseIdent(); L.expect('='); auto v = parseAttributeValue(); attributes.push_back(Attribute::create(ident.range(), Ident(ident), v)); } else { inputs.push_back(parseExp()); } }); } // parse LHS acceptable exprs, which only includes subset of Exprs that prec // is greater than 4 according to the python grammar Expr parseLHSExp() { return parseExp(4); } // Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with // "::" Expr parseSubscriptExp() { TreeRef first, second, third; auto range = L.cur().range; if (L.cur().kind != ':') { first = parseExp(); } if (L.nextIf(':')) { if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') { second = parseExp(); } if (L.nextIf(':')) { if (L.cur().kind != ',' && L.cur().kind != ']') { third = parseExp(); } } auto maybe_first = first ? Maybe::create(range, Expr(first)) : Maybe::create(range); auto maybe_second = second ? Maybe::create(range, Expr(second)) : Maybe::create(range); auto maybe_third = third ? Maybe::create(range, Expr(third)) : Maybe::create(range); return SliceExpr::create(range, maybe_first, maybe_second, maybe_third); } else { return Expr(first); } } TreeRef parseSubscript(const TreeRef& value) { const auto range = L.cur().range; auto subscript_exprs = parseList('[', ',', ']', &ParserImpl::parseSubscriptExp); const auto whole_range = SourceRange(range.source(), range.start(), L.cur().range.start()); return Subscript::create(whole_range, Expr(value), subscript_exprs); } Maybe maybeParseTypeAnnotation() { if (L.nextIf(':')) { // NB: parseExp must not be called inline, since argument evaluation order // changes when L.cur().range is mutated with respect to the parseExp() // call. auto expr = parseExp(); return Maybe::create(expr.range(), expr); } else { return Maybe::create(L.cur().range); } } TreeRef parseFormalParam(bool kwarg_only) { auto ident = parseIdent(); TreeRef type = maybeParseTypeAnnotation(); TreeRef def; if (L.nextIf('=')) { // NB: parseExp must not be called inline, since argument evaluation order // changes when L.cur().range is mutated with respect to the parseExp() // call. auto expr = parseExp(); def = Maybe::create(expr.range(), expr); } else { def = Maybe::create(L.cur().range); } return Param::create( type->range(), Ident(ident), Maybe(type), Maybe(def), kwarg_only); } Param parseBareTypeAnnotation() { auto type = parseExp(); return Param::create( type.range(), Ident::create(type.range(), ""), Maybe::create(type.range(), type), Maybe::create(type.range()), /*kwarg_only=*/false); } Decl parseTypeComment() { auto range = L.cur().range; L.expect(TK_TYPE_COMMENT); auto param_types = parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation); TreeRef return_type; if (L.nextIf(TK_ARROW)) { auto return_type_range = L.cur().range; return_type = Maybe::create(return_type_range, parseExp()); } else { return_type = Maybe::create(L.cur().range); } return Decl::create(range, param_types, Maybe(return_type)); } // 'first' has already been parsed since expressions can exist // alone on a line: // first[,other,lhs] = rhs TreeRef parseAssign(const Expr& lhs) { auto type = maybeParseTypeAnnotation(); auto maybeOp = maybeParseAssignmentOp(); if (maybeOp) { // There is an assignment operator, parse the RHS and generate the // assignment. auto rhs = parseExpOrExpTuple(); if (maybeOp.value()->kind() == '=') { std::vector lhs_list = {lhs}; while (L.nextIf('=')) { lhs_list.push_back(rhs); rhs = parseExpOrExpTuple(); } if (type.present() && lhs_list.size() > 1) { throw ErrorReport(type.range()) << "Annotated multiple assignment is not supported in python"; } L.expect(TK_NEWLINE); return Assign::create( lhs.range(), List::create(lhs_list[0].range(), lhs_list), Maybe::create(rhs.range(), rhs), type); } else { L.expect(TK_NEWLINE); // this is an augmented assignment if (lhs.kind() == TK_TUPLE_LITERAL) { throw ErrorReport(lhs.range()) << " augmented assignment can only have one LHS expression"; } return AugAssign::create( lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs)); } } else { // There is no assignment operator, so this is of the form `lhs : ` TORCH_INTERNAL_ASSERT(type.present()); L.expect(TK_NEWLINE); return Assign::create( lhs.range(), List::create(lhs.range(), {lhs}), Maybe::create(lhs.range()), type); } } TreeRef parseStmt(bool in_class = false) { switch (L.cur().kind) { case TK_IF: return parseIf(); case TK_WHILE: return parseWhile(); case TK_FOR: return parseFor(); case TK_GLOBAL: { auto range = L.next().range; auto idents = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent); L.expect(TK_NEWLINE); return Global::create(range, idents); } case TK_RETURN: { auto range = L.next().range; Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple() : Expr(create_compound(TK_NONE, range, {})); L.expect(TK_NEWLINE); return Return::create(range, value); } case TK_RAISE: { auto range = L.next().range; auto expr = parseExp(); L.expect(TK_NEWLINE); return Raise::create(range, expr); } case TK_ASSERT: { auto range = L.next().range; auto cond = parseExp(); Maybe maybe_first = Maybe::create(range); if (L.nextIf(',')) { auto msg = parseExp(); maybe_first = Maybe::create(range, Expr(msg)); } L.expect(TK_NEWLINE); return Assert::create(range, cond, maybe_first); } case TK_BREAK: { auto range = L.next().range; L.expect(TK_NEWLINE); return Break::create(range); } case TK_CONTINUE: { auto range = L.next().range; L.expect(TK_NEWLINE); return Continue::create(range); } case TK_PASS: { auto range = L.next().range; L.expect(TK_NEWLINE); return Pass::create(range); } case TK_DEF: { return parseFunction(/*is_method=*/in_class); } case TK_DELETE: { auto range = L.next().range; auto targets = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp); L.expect(TK_NEWLINE); return Delete::create(range, targets); } case TK_WITH: { return parseWith(); } default: { auto lhs = parseExpOrExpTuple(); if (L.cur().kind != TK_NEWLINE) { return parseAssign(lhs); } else { L.expect(TK_NEWLINE); return ExprStmt::create(lhs.range(), lhs); } } } } WithItem parseWithItem() { auto target = parseExp(); if (L.cur().kind == TK_AS) { // If the current token is TK_AS, this with item is of the form // "expression as target". auto token = L.expect(TK_AS); Ident ident = parseIdent(); auto var = Var::create(ident.range(), ident); return WithItem::create( token.range, target, Maybe::create(ident.range(), var)); } else { // If not, this with item is of the form "expression". return WithItem::create( target.range(), target, Maybe::create(target.range())); } } TreeRef parseIf(bool expect_if = true) { auto r = L.cur().range; if (expect_if) L.expect(TK_IF); auto cond = parseExp(); L.expect(':'); auto true_branch = parseStatements(/*expect_indent=*/true); auto false_branch = makeList(L.cur().range, {}); if (L.nextIf(TK_ELSE)) { L.expect(':'); false_branch = parseStatements(/*expect_indent=*/true); } else if (L.nextIf(TK_ELIF)) { // NB: this needs to be a separate statement, since the call to parseIf // mutates the lexer state, and thus causes a heap-use-after-free in // compilers which evaluate argument expressions LTR auto range = L.cur().range; false_branch = makeList(range, {parseIf(false)}); } return If::create( r, Expr(cond), List(true_branch), List(false_branch)); } TreeRef parseWhile() { auto r = L.cur().range; L.expect(TK_WHILE); auto cond = parseExp(); L.expect(':'); auto body = parseStatements(/*expect_indent=*/true); return While::create(r, Expr(cond), List(body)); } TreeRef parseFor() { auto r = L.cur().range; L.expect(TK_FOR); auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp); auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp); auto body = parseStatements(/*expect_indent=*/true); return For::create(r, targets, itrs, body); } TreeRef parseWith() { auto r = L.cur().range; // Parse "with expression [as target][, expression [as target]]*:". L.expect(TK_WITH); auto targets = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseWithItem); // Parse the body. auto body = parseStatements(/*expect_indent=*/true); return With::create(r, targets, body); } TreeRef parseStatements(bool expect_indent, bool in_class = false) { auto r = L.cur().range; if (expect_indent) { L.expect(TK_INDENT); } TreeList stmts; do { stmts.push_back(parseStmt(in_class)); } while (!L.nextIf(TK_DEDENT)); return create_compound(TK_LIST, r, std::move(stmts)); } Maybe parseReturnAnnotation() { if (L.nextIf(TK_ARROW)) { // Exactly one expression for return type annotation auto return_type_range = L.cur().range; return Maybe::create(return_type_range, parseExp()); } else { return Maybe::create(L.cur().range); } } List parseFormalParams() { auto r = L.cur().range; std::vector params; bool kwarg_only = false; parseSequence('(', ',', ')', [&] { if (!kwarg_only && L.nextIf('*')) { kwarg_only = true; } else { params.emplace_back(parseFormalParam(kwarg_only)); } }); return List::create(r, params); } Decl parseDecl() { // Parse return type annotation List paramlist = parseFormalParams(); TreeRef return_type; Maybe return_annotation = parseReturnAnnotation(); L.expect(':'); return Decl::create( paramlist.range(), List(paramlist), return_annotation); } TreeRef parseClass() { L.expect(TK_CLASS_DEF); const auto name = parseIdent(); Maybe superclass = Maybe::create(name.range()); if (L.nextIf('(')) { // Only support inheriting from NamedTuple right now. auto id = parseExp(); superclass = Maybe::create(id.range(), id); L.expect(')'); } L.expect(':'); const auto statements = parseStatements(/*expect_indent=*/true, /*in_class=*/true); return ClassDef::create( name.range(), name, superclass, List(statements)); } TreeRef parseFunction(bool is_method) { L.expect(TK_DEF); auto name = parseIdent(); auto decl = parseDecl(); TreeRef stmts_list; if (L.nextIf(TK_INDENT)) { // Handle type annotations specified in a type comment as the first line // of the function. if (L.cur().kind == TK_TYPE_COMMENT) { auto type_annotation_decl = Decl(parseTypeComment()); L.expect(TK_NEWLINE); decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method); } stmts_list = parseStatements(false); } else { // Special case: the Python grammar allows one-line functions with a // single statement. if (L.cur().kind == TK_TYPE_COMMENT) { auto type_annotation_decl = Decl(parseTypeComment()); decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method); } TreeList stmts; stmts.push_back(parseStmt(is_method)); stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts)); } return Def::create( name.range(), Ident(name), Decl(decl), List(stmts_list)); } Lexer& lexer() { return L; } private: // short helpers to create nodes TreeRef create_compound( int kind, const SourceRange& range, TreeList&& trees) { return Compound::create(kind, range, std::move(trees)); } TreeRef makeList(const SourceRange& range, TreeList&& trees) { return create_compound(TK_LIST, range, std::move(trees)); } Lexer L; SharedParserData& shared; }; Parser::Parser(const std::shared_ptr& src) : pImpl(new ParserImpl(src)) {} Parser::~Parser() = default; TreeRef Parser::parseFunction(bool is_method) { return pImpl->parseFunction(is_method); } TreeRef Parser::parseClass() { return pImpl->parseClass(); } Lexer& Parser::lexer() { return pImpl->lexer(); } Decl Parser::parseTypeComment() { return pImpl->parseTypeComment(); } Expr Parser::parseExp() { return pImpl->parseExp(); } } // namespace torch::jit