mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Summary: Fix https://github.com/pytorch/pytorch/issues/38336. Add `%=` support in TorchScript. It's now possible to do something like:
	```py
	@torch.jit.script
	def mm(a, b):
	    a %= b
	    return a
	```
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/38983 Differential Revision: D21803523 Pulled By: SplitInfinity fbshipit-source-id: 3437860d06d32e26ca9a5497099148c1f1616c5b
		
			
				
	
	
		
			749 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			749 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #include <torch/csrc/jit/frontend/parser.h>
 | |
| #include <c10/util/Optional.h>
 | |
| #include <torch/csrc/jit/frontend/lexer.h>
 | |
| #include <torch/csrc/jit/frontend/parse_string_literal.h>
 | |
| #include <torch/csrc/jit/frontend/tree.h>
 | |
| #include <torch/csrc/jit/frontend/tree_views.h>
 | |
| 
 | |
| namespace torch {
 | |
| namespace jit {
 | |
| 
 | |
| Decl mergeTypesFromTypeComment(
 | |
|     const Decl& decl,
 | |
|     const Decl& type_annotation_decl,
 | |
|     bool is_method) {
 | |
|   auto expected_num_annotations = decl.params().size();
 | |
|   if (is_method) {
 | |
|     // `self` argument
 | |
|     expected_num_annotations -= 1;
 | |
|   }
 | |
|   if (expected_num_annotations != type_annotation_decl.params().size()) {
 | |
|     throw ErrorReport(decl.range())
 | |
|         << "Number of type annotations ("
 | |
|         << type_annotation_decl.params().size()
 | |
|         << ") did not match the number of "
 | |
|         << (is_method ? "method" : "function") << " parameters ("
 | |
|         << expected_num_annotations << ")";
 | |
|   }
 | |
|   auto old = decl.params();
 | |
|   auto _new = type_annotation_decl.params();
 | |
|   // Merge signature idents and ranges with annotation types
 | |
| 
 | |
|   std::vector<Param> new_params;
 | |
|   size_t i = is_method ? 1 : 0;
 | |
|   size_t j = 0;
 | |
|   if (is_method) {
 | |
|     new_params.push_back(old[0]);
 | |
|   }
 | |
|   for (; i < decl.params().size(); ++i, ++j) {
 | |
|     new_params.emplace_back(old[i].withType(_new[j].type()));
 | |
|   }
 | |
|   return Decl::create(
 | |
|       decl.range(),
 | |
|       List<Param>::create(decl.range(), new_params),
 | |
|       type_annotation_decl.return_type());
 | |
| }
 | |
| 
 | |
| struct ParserImpl {
 | |
|   explicit ParserImpl(const std::shared_ptr<Source>& source)
 | |
|       : L(source), shared(sharedParserData()) {}
 | |
| 
 | |
|   Ident parseIdent() {
 | |
|     auto t = L.expect(TK_IDENT);
 | |
|     // whenever we parse something that has a TreeView type we always
 | |
|     // use its create method so that the accessors and the constructor
 | |
|     // of the Compound tree are in the same place.
 | |
|     return Ident::create(t.range, t.text());
 | |
|   }
 | |
|   TreeRef createApply(const Expr& expr) {
 | |
|     TreeList attributes;
 | |
|     auto range = L.cur().range;
 | |
|     TreeList inputs;
 | |
|     parseArguments(inputs, attributes);
 | |
|     return Apply::create(
 | |
|         range,
 | |
|         expr,
 | |
|         List<Expr>(makeList(range, std::move(inputs))),
 | |
|         List<Attribute>(makeList(range, std::move(attributes))));
 | |
|   }
 | |
| 
 | |
|   static bool followsTuple(int kind) {
 | |
|     switch (kind) {
 | |
|       case TK_PLUS_EQ:
 | |
|       case TK_MINUS_EQ:
 | |
|       case TK_TIMES_EQ:
 | |
|       case TK_DIV_EQ:
 | |
|       case TK_MOD_EQ:
 | |
|       case TK_NEWLINE:
 | |
|       case '=':
 | |
|       case ')':
 | |
|         return true;
 | |
|       default:
 | |
|         return false;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // exp | expr, | expr, expr, ...
 | |
|   Expr parseExpOrExpTuple() {
 | |
|     auto prefix = parseExp();
 | |
|     if (L.cur().kind == ',') {
 | |
|       std::vector<Expr> exprs = {prefix};
 | |
|       while (L.nextIf(',')) {
 | |
|         if (followsTuple(L.cur().kind))
 | |
|           break;
 | |
|         exprs.push_back(parseExp());
 | |
|       }
 | |
|       auto list = List<Expr>::create(prefix.range(), exprs);
 | |
|       prefix = TupleLiteral::create(list.range(), list);
 | |
|     }
 | |
|     return prefix;
 | |
|   }
 | |
|   // things like a 1.0 or a(4) that are not unary/binary expressions
 | |
|   // and have higher precedence than all of them
 | |
|   TreeRef parseBaseExp() {
 | |
|     TreeRef prefix;
 | |
|     switch (L.cur().kind) {
 | |
|       case TK_NUMBER: {
 | |
|         prefix = parseConst();
 | |
|       } break;
 | |
|       case TK_TRUE:
 | |
|       case TK_FALSE:
 | |
|       case TK_NONE: {
 | |
|         auto k = L.cur().kind;
 | |
|         auto r = L.cur().range;
 | |
|         prefix = create_compound(k, r, {});
 | |
|         L.next();
 | |
|       } break;
 | |
|       case '(': {
 | |
|         L.next();
 | |
|         if (L.nextIf(')')) {
 | |
|           /// here we have the empty tuple case
 | |
|           std::vector<Expr> vecExpr;
 | |
|           List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
 | |
|           prefix = TupleLiteral::create(L.cur().range, listExpr);
 | |
|           break;
 | |
|         }
 | |
|         prefix = parseExpOrExpTuple();
 | |
|         L.expect(')');
 | |
|       } break;
 | |
|       case '[': {
 | |
|         auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
 | |
| 
 | |
|         if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
 | |
|           prefix = *list.begin();
 | |
|         } else {
 | |
|           for (auto se : list) {
 | |
|             if (se.kind() == TK_LIST_COMP) {
 | |
|               throw ErrorReport(list.range())
 | |
|                   << " expected a single list comprehension within '[' , ']'";
 | |
|             }
 | |
|           }
 | |
|           prefix = ListLiteral::create(list.range(), List<Expr>(list));
 | |
|         }
 | |
| 
 | |
|       } break;
 | |
|       case '{': {
 | |
|         L.next();
 | |
|         std::vector<Expr> keys;
 | |
|         std::vector<Expr> values;
 | |
|         auto range = L.cur().range;
 | |
|         if (L.cur().kind != '}') {
 | |
|           do {
 | |
|             keys.push_back(parseExp());
 | |
|             L.expect(':');
 | |
|             values.push_back(parseExp());
 | |
|           } while (L.nextIf(','));
 | |
|         }
 | |
|         L.expect('}');
 | |
|         prefix = DictLiteral::create(
 | |
|             range,
 | |
|             List<Expr>::create(range, keys),
 | |
|             List<Expr>::create(range, values));
 | |
|       } break;
 | |
|       case TK_STRINGLITERAL: {
 | |
|         prefix = parseConcatenatedStringLiterals();
 | |
|       } break;
 | |
|       case TK_DOTS: {
 | |
|         prefix = Dots::create(L.cur().range);
 | |
|         L.next();
 | |
|       } break;
 | |
|       default: {
 | |
|         Ident name = parseIdent();
 | |
|         prefix = Var::create(name.range(), name);
 | |
|       } break;
 | |
|     }
 | |
|     while (true) {
 | |
|       if (L.nextIf('.')) {
 | |
|         const auto name = parseIdent();
 | |
|         prefix = Select::create(name.range(), Expr(prefix), Ident(name));
 | |
|       } else if (L.cur().kind == '(') {
 | |
|         prefix = createApply(Expr(prefix));
 | |
|       } else if (L.cur().kind == '[') {
 | |
|         prefix = parseSubscript(prefix);
 | |
|       } else {
 | |
|         break;
 | |
|       }
 | |
|     }
 | |
|     return prefix;
 | |
|   }
 | |
|   c10::optional<TreeRef> maybeParseAssignmentOp() {
 | |
|     auto r = L.cur().range;
 | |
|     switch (L.cur().kind) {
 | |
|       case TK_PLUS_EQ:
 | |
|       case TK_MINUS_EQ:
 | |
|       case TK_TIMES_EQ:
 | |
|       case TK_DIV_EQ:
 | |
|       case TK_MOD_EQ: {
 | |
|         int modifier = L.next().text()[0];
 | |
|         return create_compound(modifier, r, {});
 | |
|       } break;
 | |
|       case '=': {
 | |
|         L.next();
 | |
|         return create_compound('=', r, {}); // no reduction
 | |
|       } break;
 | |
|       default:
 | |
|         return c10::nullopt;
 | |
|     }
 | |
|   }
 | |
|   TreeRef parseTrinary(
 | |
|       TreeRef true_branch,
 | |
|       const SourceRange& range,
 | |
|       int binary_prec) {
 | |
|     auto cond = parseExp();
 | |
|     L.expect(TK_ELSE);
 | |
|     auto false_branch = parseExp(binary_prec);
 | |
|     return create_compound(
 | |
|         TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
 | |
|   }
 | |
|   // parse the longest expression whose binary operators have
 | |
|   // precedence strictly greater than 'precedence'
 | |
|   // precedence == 0 will parse _all_ expressions
 | |
|   // this is the core loop of 'top-down precedence parsing'
 | |
|   Expr parseExp() {
 | |
|     return parseExp(0);
 | |
|   }
 | |
|   Expr parseExp(int precedence) {
 | |
|     TreeRef prefix;
 | |
|     int unary_prec;
 | |
|     if (shared.isUnary(L.cur().kind, &unary_prec)) {
 | |
|       auto kind = L.cur().kind;
 | |
|       auto pos = L.cur().range;
 | |
|       L.next();
 | |
|       auto unary_kind =
 | |
|           kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind;
 | |
|       auto subexp = parseExp(unary_prec);
 | |
|       // fold '-' into constant numbers, so that attributes can accept
 | |
|       // things like -1
 | |
|       if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
 | |
|         prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
 | |
|       } else {
 | |
|         prefix = create_compound(unary_kind, pos, {subexp});
 | |
|       }
 | |
|     } else {
 | |
|       prefix = parseBaseExp();
 | |
|     }
 | |
|     int binary_prec;
 | |
|     while (shared.isBinary(L.cur().kind, &binary_prec)) {
 | |
|       if (binary_prec <= precedence) // not allowed to parse something which is
 | |
|         // not greater than 'precedence'
 | |
|         break;
 | |
| 
 | |
|       int kind = L.cur().kind;
 | |
|       auto pos = L.cur().range;
 | |
|       L.next();
 | |
|       if (shared.isRightAssociative(kind))
 | |
|         binary_prec--;
 | |
| 
 | |
|       if (kind == TK_NOTIN) {
 | |
|         // NB: `not in` is just `not( in )`, so we don't introduce new tree view
 | |
|         // but just make it a nested call in our tree view structure
 | |
|         prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)});
 | |
|         prefix = create_compound(TK_NOT, pos, {prefix});
 | |
|         continue;
 | |
|       }
 | |
| 
 | |
|       // special case for trinary operator
 | |
|       if (kind == TK_IF) {
 | |
|         prefix = parseTrinary(prefix, pos, binary_prec);
 | |
|         continue;
 | |
|       }
 | |
| 
 | |
|       if (kind == TK_FOR) {
 | |
|         // TK_FOR targets should only parse exprs prec greater than 4, which
 | |
|         // only includes subset of Exprs that suppose to be on the LHS according
 | |
|         // to the python grammar
 | |
|         // https://docs.python.org/3/reference/grammar.html
 | |
|         auto target = parseLHSExp();
 | |
|         L.expect(TK_IN);
 | |
|         auto iter = parseExp();
 | |
|         prefix = ListComp::create(pos, Expr(prefix), target, iter);
 | |
|         continue;
 | |
|       }
 | |
| 
 | |
|       prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
 | |
|     }
 | |
|     return Expr(prefix);
 | |
|   }
 | |
| 
 | |
|   void parseSequence(
 | |
|       int begin,
 | |
|       int sep,
 | |
|       int end,
 | |
|       const std::function<void()>& parse) {
 | |
|     if (begin != TK_NOTHING) {
 | |
|       L.expect(begin);
 | |
|     }
 | |
|     while (end != L.cur().kind) {
 | |
|       parse();
 | |
|       if (!L.nextIf(sep)) {
 | |
|         if (end != TK_NOTHING) {
 | |
|           L.expect(end);
 | |
|         }
 | |
|         return;
 | |
|       }
 | |
|     }
 | |
|     L.expect(end);
 | |
|   }
 | |
|   template <typename T>
 | |
|   List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
 | |
|     auto r = L.cur().range;
 | |
|     std::vector<T> elements;
 | |
|     parseSequence(
 | |
|         begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
 | |
|     return List<T>::create(r, elements);
 | |
|   }
 | |
| 
 | |
|   Const parseConst() {
 | |
|     auto range = L.cur().range;
 | |
|     auto t = L.expect(TK_NUMBER);
 | |
|     return Const::create(t.range, t.text());
 | |
|   }
 | |
| 
 | |
|   StringLiteral parseConcatenatedStringLiterals() {
 | |
|     auto range = L.cur().range;
 | |
|     std::string ss;
 | |
|     while (L.cur().kind == TK_STRINGLITERAL) {
 | |
|       auto literal_range = L.cur().range;
 | |
|       ss.append(parseStringLiteral(literal_range, L.next().text()));
 | |
|     }
 | |
|     return StringLiteral::create(range, ss);
 | |
|   }
 | |
| 
 | |
|   Expr parseAttributeValue() {
 | |
|     return parseExp();
 | |
|   }
 | |
| 
 | |
|   void parseArguments(TreeList& inputs, TreeList& attributes) {
 | |
|     parseSequence('(', ',', ')', [&] {
 | |
|       if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
 | |
|         auto ident = parseIdent();
 | |
|         L.expect('=');
 | |
|         auto v = parseAttributeValue();
 | |
|         attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
 | |
|       } else {
 | |
|         inputs.push_back(parseExp());
 | |
|       }
 | |
|     });
 | |
|   }
 | |
| 
 | |
|   // parse LHS acceptable exprs, which only includes subset of Exprs that prec
 | |
|   // is greater than 4 according to the python grammar
 | |
|   Expr parseLHSExp() {
 | |
|     return parseExp(4);
 | |
|   }
 | |
| 
 | |
|   // Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with
 | |
|   // "::"
 | |
|   Expr parseSubscriptExp() {
 | |
|     TreeRef first, second, third;
 | |
|     auto range = L.cur().range;
 | |
|     if (L.cur().kind != ':') {
 | |
|       first = parseExp();
 | |
|     }
 | |
|     if (L.nextIf(':')) {
 | |
|       if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
 | |
|         second = parseExp();
 | |
|       }
 | |
|       if (L.nextIf(':')) {
 | |
|         if (L.cur().kind != ',' && L.cur().kind != ']') {
 | |
|           third = parseExp();
 | |
|         }
 | |
|       }
 | |
|       auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
 | |
|                                : Maybe<Expr>::create(range);
 | |
|       auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
 | |
|                                  : Maybe<Expr>::create(range);
 | |
|       auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
 | |
|                                : Maybe<Expr>::create(range);
 | |
|       return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
 | |
|     } else {
 | |
|       return Expr(first);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   TreeRef parseSubscript(const TreeRef& value) {
 | |
|     const auto range = L.cur().range;
 | |
| 
 | |
|     auto subscript_exprs =
 | |
|         parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
 | |
| 
 | |
|     return Subscript::create(range, Expr(value), subscript_exprs);
 | |
|   }
 | |
| 
 | |
|   Maybe<Expr> maybeParseTypeAnnotation() {
 | |
|     if (L.nextIf(':')) {
 | |
|       // NB: parseExp must not be called inline, since argument evaluation order
 | |
|       // changes when L.cur().range is mutated with respect to the parseExp()
 | |
|       // call.
 | |
|       auto expr = parseExp();
 | |
|       return Maybe<Expr>::create(expr.range(), expr);
 | |
|     } else {
 | |
|       return Maybe<Expr>::create(L.cur().range);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   TreeRef parseFormalParam(bool kwarg_only) {
 | |
|     auto ident = parseIdent();
 | |
|     TreeRef type = maybeParseTypeAnnotation();
 | |
|     TreeRef def;
 | |
|     if (L.nextIf('=')) {
 | |
|       // NB: parseExp must not be called inline, since argument evaluation order
 | |
|       // changes when L.cur().range is mutated with respect to the parseExp()
 | |
|       // call.
 | |
|       auto expr = parseExp();
 | |
|       def = Maybe<Expr>::create(expr.range(), expr);
 | |
|     } else {
 | |
|       def = Maybe<Expr>::create(L.cur().range);
 | |
|     }
 | |
|     return Param::create(
 | |
|         type->range(),
 | |
|         Ident(ident),
 | |
|         Maybe<Expr>(type),
 | |
|         Maybe<Expr>(def),
 | |
|         kwarg_only);
 | |
|   }
 | |
| 
 | |
|   Param parseBareTypeAnnotation() {
 | |
|     auto type = parseExp();
 | |
|     return Param::create(
 | |
|         type.range(),
 | |
|         Ident::create(type.range(), ""),
 | |
|         Maybe<Expr>::create(type.range(), type),
 | |
|         Maybe<Expr>::create(type.range()),
 | |
|         /*kwarg_only=*/false);
 | |
|   }
 | |
| 
 | |
|   Decl parseTypeComment() {
 | |
|     auto range = L.cur().range;
 | |
|     L.expect(TK_TYPE_COMMENT);
 | |
|     auto param_types =
 | |
|         parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
 | |
|     TreeRef return_type;
 | |
|     if (L.nextIf(TK_ARROW)) {
 | |
|       auto return_type_range = L.cur().range;
 | |
|       return_type = Maybe<Expr>::create(return_type_range, parseExp());
 | |
|     } else {
 | |
|       return_type = Maybe<Expr>::create(L.cur().range);
 | |
|     }
 | |
|     return Decl::create(range, param_types, Maybe<Expr>(return_type));
 | |
|   }
 | |
| 
 | |
|   // 'first' has already been parsed since expressions can exist
 | |
|   // alone on a line:
 | |
|   // first[,other,lhs] = rhs
 | |
|   TreeRef parseAssign(const Expr& lhs) {
 | |
|     auto type = maybeParseTypeAnnotation();
 | |
|     auto maybeOp = maybeParseAssignmentOp();
 | |
|     if (maybeOp) {
 | |
|       // There is an assignment operator, parse the RHS and generate the
 | |
|       // assignment.
 | |
|       auto rhs = parseExpOrExpTuple();
 | |
|       if (maybeOp.value()->kind() == '=') {
 | |
|         std::vector<Expr> lhs_list = {lhs};
 | |
|         while (L.nextIf('=')) {
 | |
|           lhs_list.push_back(rhs);
 | |
|           rhs = parseExpOrExpTuple();
 | |
|         }
 | |
|         if (type.present() && lhs_list.size() > 1) {
 | |
|           throw ErrorReport(type.range())
 | |
|               << "Annotated multiple assignment is not supported in python";
 | |
|         }
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Assign::create(
 | |
|             lhs.range(),
 | |
|             List<Expr>::create(lhs_list[0].range(), lhs_list),
 | |
|             Maybe<Expr>::create(rhs.range(), rhs),
 | |
|             type);
 | |
|       } else {
 | |
|         L.expect(TK_NEWLINE);
 | |
|         // this is an augmented assignment
 | |
|         if (lhs.kind() == TK_TUPLE_LITERAL) {
 | |
|           throw ErrorReport(lhs.range())
 | |
|               << " augmented assignment can only have one LHS expression";
 | |
|         }
 | |
|         return AugAssign::create(
 | |
|             lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs));
 | |
|       }
 | |
|     } else {
 | |
|       // There is no assignment operator, so this is of the form `lhs : <type>`
 | |
|       TORCH_INTERNAL_ASSERT(type.present());
 | |
|       L.expect(TK_NEWLINE);
 | |
|       return Assign::create(
 | |
|           lhs.range(),
 | |
|           List<Expr>::create(lhs.range(), {lhs}),
 | |
|           Maybe<Expr>::create(lhs.range()),
 | |
|           type);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   TreeRef parseStmt(bool in_class = false) {
 | |
|     switch (L.cur().kind) {
 | |
|       case TK_IF:
 | |
|         return parseIf();
 | |
|       case TK_WHILE:
 | |
|         return parseWhile();
 | |
|       case TK_FOR:
 | |
|         return parseFor();
 | |
|       case TK_GLOBAL: {
 | |
|         auto range = L.next().range;
 | |
|         auto idents =
 | |
|             parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Global::create(range, idents);
 | |
|       }
 | |
|       case TK_RETURN: {
 | |
|         auto range = L.next().range;
 | |
|         Expr value = L.cur().kind != TK_NEWLINE
 | |
|             ? parseExpOrExpTuple()
 | |
|             : Expr(create_compound(TK_NONE, range, {}));
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Return::create(range, value);
 | |
|       }
 | |
|       case TK_RAISE: {
 | |
|         auto range = L.next().range;
 | |
|         auto expr = parseExp();
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Raise::create(range, expr);
 | |
|       }
 | |
|       case TK_ASSERT: {
 | |
|         auto range = L.next().range;
 | |
|         auto cond = parseExp();
 | |
|         Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
 | |
|         if (L.nextIf(',')) {
 | |
|           auto msg = parseExp();
 | |
|           maybe_first = Maybe<Expr>::create(range, Expr(msg));
 | |
|         }
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Assert::create(range, cond, maybe_first);
 | |
|       }
 | |
|       case TK_BREAK: {
 | |
|         auto range = L.next().range;
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Break::create(range);
 | |
|       }
 | |
|       case TK_CONTINUE: {
 | |
|         auto range = L.next().range;
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Continue::create(range);
 | |
|       }
 | |
|       case TK_PASS: {
 | |
|         auto range = L.next().range;
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Pass::create(range);
 | |
|       }
 | |
|       case TK_DEF: {
 | |
|         return parseFunction(/*is_method=*/in_class);
 | |
|       }
 | |
|       case TK_DELETE: {
 | |
|         L.expect(TK_DELETE);
 | |
|         auto expr = parseExp();
 | |
|         L.expect(TK_NEWLINE);
 | |
|         return Delete::create(expr);
 | |
|       }
 | |
|       default: {
 | |
|         auto lhs = parseExpOrExpTuple();
 | |
|         if (L.cur().kind != TK_NEWLINE) {
 | |
|           return parseAssign(lhs);
 | |
|         } else {
 | |
|           L.expect(TK_NEWLINE);
 | |
|           return ExprStmt::create(lhs.range(), lhs);
 | |
|         }
 | |
|       }
 | |
|     }
 | |
|   }
 | |
|   TreeRef parseIf(bool expect_if = true) {
 | |
|     auto r = L.cur().range;
 | |
|     if (expect_if)
 | |
|       L.expect(TK_IF);
 | |
|     auto cond = parseExp();
 | |
|     L.expect(':');
 | |
|     auto true_branch = parseStatements(/*expect_indent=*/true);
 | |
|     auto false_branch = makeList(L.cur().range, {});
 | |
|     if (L.nextIf(TK_ELSE)) {
 | |
|       L.expect(':');
 | |
|       false_branch = parseStatements(/*expect_indent=*/true);
 | |
|     } else if (L.nextIf(TK_ELIF)) {
 | |
|       // NB: this needs to be a separate statement, since the call to parseIf
 | |
|       // mutates the lexer state, and thus causes a heap-use-after-free in
 | |
|       // compilers which evaluate argument expressions LTR
 | |
|       auto range = L.cur().range;
 | |
|       false_branch = makeList(range, {parseIf(false)});
 | |
|     }
 | |
|     return If::create(
 | |
|         r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
 | |
|   }
 | |
|   TreeRef parseWhile() {
 | |
|     auto r = L.cur().range;
 | |
|     L.expect(TK_WHILE);
 | |
|     auto cond = parseExp();
 | |
|     L.expect(':');
 | |
|     auto body = parseStatements(/*expect_indent=*/true);
 | |
|     return While::create(r, Expr(cond), List<Stmt>(body));
 | |
|   }
 | |
| 
 | |
|   TreeRef parseFor() {
 | |
|     auto r = L.cur().range;
 | |
|     L.expect(TK_FOR);
 | |
|     auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
 | |
|     auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
 | |
|     auto body = parseStatements(/*expect_indent=*/true);
 | |
|     return For::create(r, targets, itrs, body);
 | |
|   }
 | |
| 
 | |
|   TreeRef parseStatements(bool expect_indent, bool in_class = false) {
 | |
|     auto r = L.cur().range;
 | |
|     if (expect_indent) {
 | |
|       L.expect(TK_INDENT);
 | |
|     }
 | |
|     TreeList stmts;
 | |
|     do {
 | |
|       stmts.push_back(parseStmt(in_class));
 | |
|     } while (!L.nextIf(TK_DEDENT));
 | |
|     return create_compound(TK_LIST, r, std::move(stmts));
 | |
|   }
 | |
| 
 | |
|   Maybe<Expr> parseReturnAnnotation() {
 | |
|     if (L.nextIf(TK_ARROW)) {
 | |
|       // Exactly one expression for return type annotation
 | |
|       auto return_type_range = L.cur().range;
 | |
|       return Maybe<Expr>::create(return_type_range, parseExp());
 | |
|     } else {
 | |
|       return Maybe<Expr>::create(L.cur().range);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   List<Param> parseFormalParams() {
 | |
|     auto r = L.cur().range;
 | |
|     std::vector<Param> params;
 | |
|     bool kwarg_only = false;
 | |
|     parseSequence('(', ',', ')', [&] {
 | |
|       if (!kwarg_only && L.nextIf('*')) {
 | |
|         kwarg_only = true;
 | |
|       } else {
 | |
|         params.emplace_back(parseFormalParam(kwarg_only));
 | |
|       }
 | |
|     });
 | |
|     return List<Param>::create(r, params);
 | |
|   }
 | |
|   Decl parseDecl() {
 | |
|     // Parse return type annotation
 | |
|     List<Param> paramlist = parseFormalParams();
 | |
|     TreeRef return_type;
 | |
|     Maybe<Expr> return_annotation = parseReturnAnnotation();
 | |
|     L.expect(':');
 | |
|     return Decl::create(
 | |
|         paramlist.range(), List<Param>(paramlist), return_annotation);
 | |
|   }
 | |
| 
 | |
|   TreeRef parseClass() {
 | |
|     L.expect(TK_CLASS_DEF);
 | |
|     const auto name = parseIdent();
 | |
|     Maybe<Expr> superclass = Maybe<Expr>::create(name.range());
 | |
|     if (L.nextIf('(')) {
 | |
|       // Only support inheriting from NamedTuple right now.
 | |
|       auto id = parseExp();
 | |
|       superclass = Maybe<Expr>::create(id.range(), id);
 | |
|       L.expect(')');
 | |
|     }
 | |
|     L.expect(':');
 | |
|     const auto statements =
 | |
|         parseStatements(/*expect_indent=*/true, /*in_class=*/true);
 | |
|     return ClassDef::create(
 | |
|         name.range(), name, superclass, List<Stmt>(statements));
 | |
|   }
 | |
| 
 | |
|   TreeRef parseFunction(bool is_method) {
 | |
|     L.expect(TK_DEF);
 | |
|     auto name = parseIdent();
 | |
|     auto decl = parseDecl();
 | |
| 
 | |
|     TreeRef stmts_list;
 | |
|     if (L.nextIf(TK_INDENT)) {
 | |
|       // Handle type annotations specified in a type comment as the first line
 | |
|       // of the function.
 | |
|       if (L.cur().kind == TK_TYPE_COMMENT) {
 | |
|         auto type_annotation_decl = Decl(parseTypeComment());
 | |
|         L.expect(TK_NEWLINE);
 | |
|         decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
 | |
|       }
 | |
| 
 | |
|       stmts_list = parseStatements(false);
 | |
|     } else {
 | |
|       // Special case: the Python grammar allows one-line functions with a
 | |
|       // single statement.
 | |
|       if (L.cur().kind == TK_TYPE_COMMENT) {
 | |
|         auto type_annotation_decl = Decl(parseTypeComment());
 | |
|         decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
 | |
|       }
 | |
| 
 | |
|       TreeList stmts;
 | |
|       stmts.push_back(parseStmt(is_method));
 | |
|       stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts));
 | |
|     }
 | |
| 
 | |
|     return Def::create(
 | |
|         name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
 | |
|   }
 | |
|   Lexer& lexer() {
 | |
|     return L;
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   // short helpers to create nodes
 | |
|   TreeRef create_compound(
 | |
|       int kind,
 | |
|       const SourceRange& range,
 | |
|       TreeList&& trees) {
 | |
|     return Compound::create(kind, range, std::move(trees));
 | |
|   }
 | |
|   TreeRef makeList(const SourceRange& range, TreeList&& trees) {
 | |
|     return create_compound(TK_LIST, range, std::move(trees));
 | |
|   }
 | |
|   Lexer L;
 | |
|   SharedParserData& shared;
 | |
| };
 | |
| 
 | |
| Parser::Parser(const std::shared_ptr<Source>& src)
 | |
|     : pImpl(new ParserImpl(src)) {}
 | |
| 
 | |
| Parser::~Parser() = default;
 | |
| 
 | |
| TreeRef Parser::parseFunction(bool is_method) {
 | |
|   return pImpl->parseFunction(is_method);
 | |
| }
 | |
| TreeRef Parser::parseClass() {
 | |
|   return pImpl->parseClass();
 | |
| }
 | |
| Lexer& Parser::lexer() {
 | |
|   return pImpl->lexer();
 | |
| }
 | |
| Decl Parser::parseTypeComment() {
 | |
|   return pImpl->parseTypeComment();
 | |
| }
 | |
| Expr Parser::parseExp() {
 | |
|   return pImpl->parseExp();
 | |
| }
 | |
| 
 | |
| } // namespace jit
 | |
| } // namespace torch
 |