diff --git a/test/cpp/jit/test_lexer.cpp b/test/cpp/jit/test_lexer.cpp index 1a9ddf9d5f39..465adbf6ecb4 100644 --- a/test/cpp/jit/test_lexer.cpp +++ b/test/cpp/jit/test_lexer.cpp @@ -29,7 +29,7 @@ TEST(LexerTest, AllTokens) { TEST(LexerTest, SlightlyOffIsNot) { std::vector suffixes = {"", " ", "**"}; for (const auto& suffix : suffixes) { - std::vector extras = {"n", "no", "no3", "note"}; + std::vector extras = {"n", "no", "no3"}; for (const auto& extra : extras) { std::string s = "is " + extra + suffix; Lexer l(std::make_shared(s)); @@ -45,7 +45,7 @@ TEST(LexerTest, SlightlyOffIsNot) { TEST(LexerTest, SlightlyOffNotIn) { std::vector suffixes = {"", " ", "**"}; for (const auto& suffix : suffixes) { - std::vector extras = {"i", "i3", "inn"}; + std::vector extras = {"i", "i3"}; for (const auto& extra : extras) { std::string s = "not " + extra + suffix; Lexer l(std::make_shared(s)); @@ -57,4 +57,32 @@ TEST(LexerTest, SlightlyOffNotIn) { } } } + +TEST(LexerTest, IsNoteBug) { + // The code string `is note` is lexed as TK_ISNOT followed by a + // TK_IDENT that is an e. This is not how it works in Python, but + // presumably we need to maintain this behavior. + Lexer l(std::make_shared("is note")); + const auto is_not_tok = l.next(); + EXPECT_EQ(is_not_tok.kind, TK_ISNOT); + const auto e_tok = l.next(); + EXPECT_EQ(e_tok.kind, TK_IDENT); + EXPECT_EQ(e_tok.range.text(), "e"); + const auto eof_tok = l.next(); + EXPECT_EQ(eof_tok.kind, TK_EOF); +} + +TEST(LexerTest, NotInpBug) { + // Another manifestation of the above IsNoteBug; `not inp` is lexed + // as TK_NOT_IN followed by a TK_IDENT that is a p. Again, not how + // it works in Python. + Lexer l(std::make_shared("not inp")); + const auto not_in_tok = l.next(); + EXPECT_EQ(not_in_tok.kind, TK_NOTIN); + const auto p_tok = l.next(); + EXPECT_EQ(p_tok.kind, TK_IDENT); + EXPECT_EQ(p_tok.range.text(), "p"); + const auto eof_tok = l.next(); + EXPECT_EQ(eof_tok.kind, TK_EOF); +} } // namespace torch::jit diff --git a/torch/csrc/jit/frontend/lexer.h b/torch/csrc/jit/frontend/lexer.h index 2b835ef68a55..0faf6ff24da4 100644 --- a/torch/csrc/jit/frontend/lexer.h +++ b/torch/csrc/jit/frontend/lexer.h @@ -1,17 +1,13 @@ #pragma once #include #include -#include #include #include #include #include #include -#include -#include #include #include -#include #include #include #include @@ -137,10 +133,51 @@ enum TokenKind { TORCH_API std::string kindToString(int kind); TORCH_API int stringToKind(const std::string& str); +// nested hash tables that indicate char-by-char what is a valid token. +struct TokenTrie; +using TokenTrieRef = std::unique_ptr; +struct TokenTrie { + TokenTrie() = default; + void insert(const char* str, int tok) { + if (*str == '\0') { + AT_ASSERT(kind == 0); + kind = tok; + return; + } + + for (size_t i = 0, e = child_chars.size(); i < e; ++i) { + if (child_chars[i] == *str) { + child_tries[i]->insert(str + 1, tok); + return; + } + } + + child_chars.emplace_back(*str); + child_tries.emplace_back(std::make_unique()); + child_tries.back()->insert(str + 1, tok); + } + int kind{0}; // 0 == invalid token + + std::vector child_chars; + std::vector child_tries; +}; + // stuff that is shared against all TC lexers/parsers and is initialized only // once. struct TORCH_API SharedParserData { - SharedParserData() = default; + SharedParserData() : head(new TokenTrie()) { + for (const char* c = valid_single_char_tokens; *c; c++) { + std::string str(1, *c); + head->insert(str.c_str(), *c); + } + +#define ADD_CASE(tok, _, tokstring) \ + if (*(tokstring) != '\0') { \ + head->insert((tokstring), (tok)); \ + } + TC_FORALL_TOKEN_KINDS(ADD_CASE) +#undef ADD_CASE + } bool match( StringCordView::Iterator pos, @@ -211,213 +248,41 @@ struct TORCH_API SharedParserData { return true; } - if (std::isalpha(*pos) || *pos == '_') { - matchIdentOrKeyword(pos, kind, end); - return true; - } - - // Hand-coded DFA matching for tokens that cannot be confused with - // identifiers. We could use a lexer generator toolkit like Flex - // or re2c instead, but that would add another dependency, and I - // expect this component to change infrequently given that PyTorch - // 2.0 is years old already. Note that the tests in text_lexer.cpp - // should guarantee that we don't forget to update this when we - // update TC_FORALL_TOKEN_KINDS. - const auto next_pos = pos.next_iter(); - switch (*pos) { - case '+': { - if (pos.has_next() && *next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_PLUS_EQ; - return true; - } - goto single_char_token; + // check for either an ident or a token + // ident tracks whether what we have scanned so far could be an identifier + // matched indicates if we have found any match. + bool matched = false; + bool ident = true; + TokenTrie* cur = head.get(); + // for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); + // i++) + for (size_t i = 0; pos.has_next() && (ident || cur != nullptr); + ++pos, ++i) { + ident = ident && validIdent(i, *pos); + if (ident) { + matched = true; + *end = pos.next_iter(); + *kind = TK_IDENT; } - case '-': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_MINUS_EQ; - return true; - } - if (*next_pos == '>') { - *end = next_pos.next_iter(); - *kind = TK_ARROW; - return true; - } + // check for token second, so that e.g. 'max' matches the token TK_MAX + // rather the + // identifier 'max' + if (cur) { + const auto begin_it = cur->child_chars.begin(); + const auto end_it = cur->child_chars.end(); + const auto ch_it = std::find(begin_it, end_it, *pos); + + cur = (ch_it == end_it) ? nullptr + : cur->child_tries[ch_it - begin_it].get(); + + if (cur && cur->kind != 0) { + matched = true; + *end = pos.next_iter(); + *kind = cur->kind; } - goto single_char_token; - case '*': - if (pos.has_next()) { - if (*next_pos == '*') { - if (next_pos.has_next() && *next_pos.next_iter() == '=') { - *end = next_pos.next_iter().next_iter(); - *kind = TK_POW_EQ; - return true; - } - *end = next_pos.next_iter(); - *kind = TK_POW; - return true; - } - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_TIMES_EQ; - return true; - } - } - goto single_char_token; - case '/': - if (pos.has_next()) { - if (*next_pos == '/') { - *end = next_pos.next_iter(); - *kind = TK_FLOOR_DIV; - return true; - } - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_DIV_EQ; - return true; - } - } - goto single_char_token; - case '%': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_MOD_EQ; - return true; - } - } - goto single_char_token; - case '=': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_EQ; - return true; - } - } - goto single_char_token; - case '>': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_GE; - return true; - } - if (*next_pos == '>') { - if (next_pos.has_next() && *next_pos.next_iter() == '=') { - *end = next_pos.next_iter().next_iter(); - *kind = TK_RSHIFT_EQ; - return true; - } - *end = next_pos.next_iter(); - *kind = TK_RSHIFT; - return true; - } - } - goto single_char_token; - case '<': - if (pos.has_next()) { - if (*next_pos == '=') { - if (next_pos.has_next() && *next_pos.next_iter() == '>') { - *end = next_pos.next_iter().next_iter(); - *kind = TK_EQUIVALENT; - return true; - } - *end = next_pos.next_iter(); - *kind = TK_LE; - return true; - } - if (*next_pos == '<') { - if (next_pos.has_next() && *next_pos.next_iter() == '=') { - *end = next_pos.next_iter().next_iter(); - *kind = TK_LSHIFT_EQ; - return true; - } - *end = next_pos.next_iter(); - *kind = TK_LSHIFT; - return true; - } - } - goto single_char_token; - case '.': - if (pos.has_next()) { - if (*next_pos == '.' && next_pos.has_next() && - *next_pos.next_iter() == '.') { - *end = next_pos.next_iter().next_iter(); - *kind = TK_DOTS; - return true; - } - } - goto single_char_token; - case '!': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_NE; - return true; - } - } - goto single_char_token; - case '&': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_BIT_AND_EQ; - return true; - } - } - goto single_char_token; - case '^': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_BIT_XOR_EQ; - return true; - } - } - goto single_char_token; - case '|': - if (pos.has_next()) { - if (*next_pos == '=') { - *end = next_pos.next_iter(); - *kind = TK_BIT_OR_EQ; - return true; - } - } - goto single_char_token; - case '#': - *end = pos + std::strlen("# type:"); - *kind = TK_TYPE_COMMENT; - return true; - case '@': - case '(': - case ')': - case '[': - case ']': - case ':': - case ',': - case '{': - case '}': - case '?': - case '~': - single_char_token: - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - std::strchr(valid_single_char_tokens, *pos) != nullptr, - "Did you forget to add the character `", - *pos, - "` to valid_single_char_tokens?"); - *end = next_pos; - *kind = *pos; - return true; + } } - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - std::strchr(valid_single_char_tokens, *pos) == nullptr, - "Did you forget to add the character `", - *pos, - "` to the above switch statement?"); - return false; + return matched; } bool isUnary(int kind, int* prec); @@ -434,196 +299,8 @@ struct TORCH_API SharedParserData { } private: - void matchIdentOrKeyword( - StringCordView::Iterator pos, - int* kind, - StringCordView::Iterator* end) const { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pos.has_next()); - static constexpr char kIsNot[] = "is not"; - static constexpr char kNotIn[] = "not in"; - constexpr char kMaybeIsNot = 'i'; - constexpr char kMaybeNotIn = 'n'; - constexpr int kIsNotSpaceIndex = 2; - constexpr int kNotInSpaceIndex = 3; - auto start = pos; - char possible_special_token = *pos; - // The longest tokens are 8 chars. - std::array token_chars; - token_chars.fill('\0'); - token_chars[0] = possible_special_token; - ++pos; - size_t i; - auto valid_ident_char = [](const char ch) { - return std::isalpha(ch) || ch == '_' || std::isdigit(ch); - }; - for (i = 1; pos.has_next(); ++pos, ++i) { - auto ch = *pos; - if (possible_special_token == kMaybeIsNot) { - if (ch != kIsNot[i]) { - if (i >= kIsNotSpaceIndex + 1) { - // Kick out to the after-loop flow, which will correctly - // record that we found TK_IS. - break; - } - possible_special_token = '\0'; - } else if (ch == ' ') { - continue; - } - if (possible_special_token && i == sizeof(kIsNot) - 2 && - (!pos.has_next() || !valid_ident_char(*(pos + 1)))) { - *kind = TK_ISNOT; - *end = pos.next_iter(); - return; - } - } else if (possible_special_token == kMaybeNotIn) { - if (ch != kNotIn[i]) { - if (i >= kNotInSpaceIndex + 1) { - // Kick out to the after-loop flow, which will correctly - // record that we found TK_NOT. - break; - } - possible_special_token = '\0'; - } else if (ch == ' ') { - continue; - } - - if (possible_special_token && i == sizeof(kNotIn) - 2 && - (!pos.has_next() || !valid_ident_char(*(pos + 1)))) { - *kind = TK_NOTIN; - *end = pos.next_iter(); - return; - } - } - if (valid_ident_char(ch)) { - if (i < token_chars.size()) { - token_chars[i] = ch; - } - continue; - } - break; - } - - // These two possible_special_token checks have to be after the - // loop and not in the loop because we might see end-of-input - // (e.g., the entire input `not p`). - if (possible_special_token == kMaybeIsNot) { - if (i >= kIsNotSpaceIndex) { - *kind = TK_IS; - *end = start + kIsNotSpaceIndex; - return; - } - } else if (possible_special_token == kMaybeNotIn) { - if (i >= kNotInSpaceIndex) { - *kind = TK_NOT; - *end = start + kNotInSpaceIndex; - return; - } - } - - *end = pos; - *kind = identTokenKind(token_chars, i); - } - - template - static constexpr uint64_t stringToUint64(const char (&str)[N]) { - static_assert(N <= 9); - uint64_t result = 0; - for (auto i : c10::irange(N)) { - if (!str[i]) { - return result; - } -#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ - result |= static_cast(str[i]) << (8 * i); -#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - result |= static_cast(str[i]) << (56 - 8 * i); -#else -#error "Unexpected or undefined value of __BYTE_ORDER__" -#endif - } - return result; - } - - static int identTokenKind( - const std::array& token_chars, - size_t token_length) { - if (token_length > token_chars.size()) { - return TK_IDENT; - } -#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ - static_assert(stringToUint64("and") == 0x646e61); - static_assert(stringToUint64("Ellipsis") == 0x73697370696c6c45); -#else - static_assert(stringToUint64("and") == 0x616e640000000000); - static_assert(stringToUint64("Ellipsis") == 0x456c6c6970736973); -#endif - - std::uint64_t token = 0; - std::memcpy(&token, token_chars.data(), token_chars.size()); - // FWIW, based on checking Godbolt this probably compiles down to - // binary or linear search over the integers representing our - // strings. I tried an alternate version that switched on the - // first character of the token, but it doesn't seem to matter for - // performance. - switch (token) { - case stringToUint64("Ellipsis"): - return TK_ELLIPSIS; - case stringToUint64("False"): - return TK_FALSE; - case stringToUint64("None"): - return TK_NONE; - case stringToUint64("NoneType"): - return TK_NONE_TYPE; - case stringToUint64("True"): - return TK_TRUE; - case stringToUint64("and"): - return TK_AND; - case stringToUint64("as"): - return TK_AS; - case stringToUint64("assert"): - return TK_ASSERT; - case stringToUint64("break"): - return TK_BREAK; - case stringToUint64("class"): - return TK_CLASS_DEF; - case stringToUint64("continue"): - return TK_CONTINUE; - case stringToUint64("def"): - return TK_DEF; - case stringToUint64("del"): - return TK_DELETE; - case stringToUint64("elif"): - return TK_ELIF; - case stringToUint64("else"): - return TK_ELSE; - case stringToUint64("for"): - return TK_FOR; - case stringToUint64("global"): - return TK_GLOBAL; - case stringToUint64("if"): - return TK_IF; - case stringToUint64("import"): - return TK_IMPORT; - case stringToUint64("in"): - return TK_IN; - case stringToUint64("is"): - return TK_IS; - case stringToUint64("not"): - return TK_NOT; - case stringToUint64("or"): - return TK_OR; - case stringToUint64("pass"): - return TK_PASS; - case stringToUint64("raise"): - return TK_RAISE; - case stringToUint64("return"): - return TK_RETURN; - case stringToUint64("while"): - return TK_WHILE; - case stringToUint64("with"): - return TK_WITH; - default: - return TK_IDENT; - } + bool validIdent(size_t i, char n) { + return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); } // 1. skip whitespace @@ -635,7 +312,7 @@ struct TORCH_API SharedParserData { // http://en.cppreference.com/w/cpp/string/byte/strtof // but we want only the number part, otherwise 1+3 will turn into two // adjacent numbers in the lexer - if (first == '-' || first == '+' || std::isalpha(first)) + if (first == '-' || first == '+' || isalpha(first)) return false; const char* startptr = str.data() + start; char* endptr = nullptr; @@ -710,6 +387,8 @@ struct TORCH_API SharedParserData { auto match_string = str.substr(pos, type_string.size()); return match_string == type_string; } + + TokenTrieRef head; }; TORCH_API SharedParserData& sharedParserData(); diff --git a/torch/csrc/jit/frontend/parser_constants.h b/torch/csrc/jit/frontend/parser_constants.h index aaa4ef3a498a..fb5cf0d88e1e 100644 --- a/torch/csrc/jit/frontend/parser_constants.h +++ b/torch/csrc/jit/frontend/parser_constants.h @@ -1,6 +1,6 @@ #pragma once namespace torch::jit { -[[maybe_unused]] static constexpr const char* valid_single_char_tokens = +static constexpr const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|~"; } // namespace torch::jit