Revert "[Relandx2] Rewrite the guts of torch::jit::Lexer to speed it up (#152372)"

This reverts commit 7ce6f632142b65849fa33f325c90a24bace2c130.

Reverted https://github.com/pytorch/pytorch/pull/152372 on behalf of https://github.com/malfet due to Looks like it broke distributed this time around, see f05d3e5019/1 ([comment](https://github.com/pytorch/pytorch/pull/152372#issuecomment-2837426497))
This commit is contained in:
PyTorch MergeBot
2025-04-29 04:37:40 +00:00
parent f05d3e5019
commit 46419c7899
3 changed files with 110 additions and 403 deletions

View File

@ -29,7 +29,7 @@ TEST(LexerTest, AllTokens) {
TEST(LexerTest, SlightlyOffIsNot) {
std::vector<std::string> suffixes = {"", " ", "**"};
for (const auto& suffix : suffixes) {
std::vector<std::string> extras = {"n", "no", "no3", "note"};
std::vector<std::string> extras = {"n", "no", "no3"};
for (const auto& extra : extras) {
std::string s = "is " + extra + suffix;
Lexer l(std::make_shared<Source>(s));
@ -45,7 +45,7 @@ TEST(LexerTest, SlightlyOffIsNot) {
TEST(LexerTest, SlightlyOffNotIn) {
std::vector<std::string> suffixes = {"", " ", "**"};
for (const auto& suffix : suffixes) {
std::vector<std::string> extras = {"i", "i3", "inn"};
std::vector<std::string> extras = {"i", "i3"};
for (const auto& extra : extras) {
std::string s = "not " + extra + suffix;
Lexer l(std::make_shared<Source>(s));
@ -57,4 +57,32 @@ TEST(LexerTest, SlightlyOffNotIn) {
}
}
}
TEST(LexerTest, IsNoteBug) {
  // Historically, the input `is note` lexes as TK_ISNOT followed by a
  // TK_IDENT containing only the trailing `e`. Python would not
  // tokenize it this way, but presumably we need to keep this behavior.
  Lexer lexer(std::make_shared<Source>("is note"));
  const auto first = lexer.next();
  const auto second = lexer.next();
  const auto last = lexer.next();
  EXPECT_EQ(first.kind, TK_ISNOT);
  EXPECT_EQ(second.kind, TK_IDENT);
  EXPECT_EQ(second.range.text(), "e");
  EXPECT_EQ(last.kind, TK_EOF);
}
TEST(LexerTest, NotInpBug) {
  // Companion to IsNoteBug above: `not inp` lexes as TK_NOTIN followed
  // by a TK_IDENT containing only the trailing `p`. Again, this differs
  // from Python's tokenization but must be preserved.
  Lexer lexer(std::make_shared<Source>("not inp"));
  const auto first = lexer.next();
  const auto second = lexer.next();
  const auto last = lexer.next();
  EXPECT_EQ(first.kind, TK_NOTIN);
  EXPECT_EQ(second.kind, TK_IDENT);
  EXPECT_EQ(second.range.text(), "p");
  EXPECT_EQ(last.kind, TK_EOF);
}
} // namespace torch::jit

View File

@ -1,17 +1,13 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/frontend/strtod.h>
#include <algorithm>
#include <array>
#include <cctype>
#include <clocale>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
@ -137,10 +133,51 @@ enum TokenKind {
TORCH_API std::string kindToString(int kind);
TORCH_API int stringToKind(const std::string& str);
// nested hash tables that indicate char-by-char what is a valid token.
struct TokenTrie;
using TokenTrieRef = std::unique_ptr<TokenTrie>;
struct TokenTrie {
TokenTrie() = default;
void insert(const char* str, int tok) {
if (*str == '\0') {
AT_ASSERT(kind == 0);
kind = tok;
return;
}
for (size_t i = 0, e = child_chars.size(); i < e; ++i) {
if (child_chars[i] == *str) {
child_tries[i]->insert(str + 1, tok);
return;
}
}
child_chars.emplace_back(*str);
child_tries.emplace_back(std::make_unique<TokenTrie>());
child_tries.back()->insert(str + 1, tok);
}
int kind{0}; // 0 == invalid token
std::vector<char> child_chars;
std::vector<TokenTrieRef> child_tries;
};
// stuff that is shared against all TC lexers/parsers and is initialized only
// once.
struct TORCH_API SharedParserData {
SharedParserData() = default;
SharedParserData() : head(new TokenTrie()) {
for (const char* c = valid_single_char_tokens; *c; c++) {
std::string str(1, *c);
head->insert(str.c_str(), *c);
}
#define ADD_CASE(tok, _, tokstring) \
if (*(tokstring) != '\0') { \
head->insert((tokstring), (tok)); \
}
TC_FORALL_TOKEN_KINDS(ADD_CASE)
#undef ADD_CASE
}
bool match(
StringCordView::Iterator pos,
@ -211,213 +248,41 @@ struct TORCH_API SharedParserData {
return true;
}
if (std::isalpha(*pos) || *pos == '_') {
matchIdentOrKeyword(pos, kind, end);
return true;
}
// Hand-coded DFA matching for tokens that cannot be confused with
// identifiers. We could use a lexer generator toolkit like Flex
// or re2c instead, but that would add another dependency, and I
// expect this component to change infrequently given that PyTorch
// 2.0 is years old already. Note that the tests in text_lexer.cpp
// should guarantee that we don't forget to update this when we
// update TC_FORALL_TOKEN_KINDS.
const auto next_pos = pos.next_iter();
switch (*pos) {
case '+': {
if (pos.has_next() && *next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_PLUS_EQ;
return true;
}
goto single_char_token;
// check for either an ident or a token
// ident tracks whether what we have scanned so far could be an identifier
// matched indicates if we have found any match.
bool matched = false;
bool ident = true;
TokenTrie* cur = head.get();
// for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr);
// i++)
for (size_t i = 0; pos.has_next() && (ident || cur != nullptr);
++pos, ++i) {
ident = ident && validIdent(i, *pos);
if (ident) {
matched = true;
*end = pos.next_iter();
*kind = TK_IDENT;
}
case '-':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_MINUS_EQ;
return true;
}
if (*next_pos == '>') {
*end = next_pos.next_iter();
*kind = TK_ARROW;
return true;
}
// check for token second, so that e.g. 'max' matches the token TK_MAX
// rather the
// identifier 'max'
if (cur) {
const auto begin_it = cur->child_chars.begin();
const auto end_it = cur->child_chars.end();
const auto ch_it = std::find(begin_it, end_it, *pos);
cur = (ch_it == end_it) ? nullptr
: cur->child_tries[ch_it - begin_it].get();
if (cur && cur->kind != 0) {
matched = true;
*end = pos.next_iter();
*kind = cur->kind;
}
goto single_char_token;
case '*':
if (pos.has_next()) {
if (*next_pos == '*') {
if (next_pos.has_next() && *next_pos.next_iter() == '=') {
*end = next_pos.next_iter().next_iter();
*kind = TK_POW_EQ;
return true;
}
*end = next_pos.next_iter();
*kind = TK_POW;
return true;
}
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_TIMES_EQ;
return true;
}
}
goto single_char_token;
case '/':
if (pos.has_next()) {
if (*next_pos == '/') {
*end = next_pos.next_iter();
*kind = TK_FLOOR_DIV;
return true;
}
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_DIV_EQ;
return true;
}
}
goto single_char_token;
case '%':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_MOD_EQ;
return true;
}
}
goto single_char_token;
case '=':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_EQ;
return true;
}
}
goto single_char_token;
case '>':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_GE;
return true;
}
if (*next_pos == '>') {
if (next_pos.has_next() && *next_pos.next_iter() == '=') {
*end = next_pos.next_iter().next_iter();
*kind = TK_RSHIFT_EQ;
return true;
}
*end = next_pos.next_iter();
*kind = TK_RSHIFT;
return true;
}
}
goto single_char_token;
case '<':
if (pos.has_next()) {
if (*next_pos == '=') {
if (next_pos.has_next() && *next_pos.next_iter() == '>') {
*end = next_pos.next_iter().next_iter();
*kind = TK_EQUIVALENT;
return true;
}
*end = next_pos.next_iter();
*kind = TK_LE;
return true;
}
if (*next_pos == '<') {
if (next_pos.has_next() && *next_pos.next_iter() == '=') {
*end = next_pos.next_iter().next_iter();
*kind = TK_LSHIFT_EQ;
return true;
}
*end = next_pos.next_iter();
*kind = TK_LSHIFT;
return true;
}
}
goto single_char_token;
case '.':
if (pos.has_next()) {
if (*next_pos == '.' && next_pos.has_next() &&
*next_pos.next_iter() == '.') {
*end = next_pos.next_iter().next_iter();
*kind = TK_DOTS;
return true;
}
}
goto single_char_token;
case '!':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_NE;
return true;
}
}
goto single_char_token;
case '&':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_BIT_AND_EQ;
return true;
}
}
goto single_char_token;
case '^':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_BIT_XOR_EQ;
return true;
}
}
goto single_char_token;
case '|':
if (pos.has_next()) {
if (*next_pos == '=') {
*end = next_pos.next_iter();
*kind = TK_BIT_OR_EQ;
return true;
}
}
goto single_char_token;
case '#':
*end = pos + std::strlen("# type:");
*kind = TK_TYPE_COMMENT;
return true;
case '@':
case '(':
case ')':
case '[':
case ']':
case ':':
case ',':
case '{':
case '}':
case '?':
case '~':
single_char_token:
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
std::strchr(valid_single_char_tokens, *pos) != nullptr,
"Did you forget to add the character `",
*pos,
"` to valid_single_char_tokens?");
*end = next_pos;
*kind = *pos;
return true;
}
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
std::strchr(valid_single_char_tokens, *pos) == nullptr,
"Did you forget to add the character `",
*pos,
"` to the above switch statement?");
return false;
return matched;
}
bool isUnary(int kind, int* prec);
@ -434,196 +299,8 @@ struct TORCH_API SharedParserData {
}
private:
// Classifies the identifier-like token starting at `pos`. Three outcomes:
// (1) the compound tokens `is not` -> TK_ISNOT and `not in` -> TK_NOTIN,
// matched character-by-character including the single embedded space;
// (2) the bare keywords `is` / `not` when the compound form falls through;
// (3) any other identifier or keyword, resolved via identTokenKind.
// Writes the token kind to *kind and the one-past-the-end iterator to *end.
void matchIdentOrKeyword(
StringCordView::Iterator pos,
int* kind,
StringCordView::Iterator* end) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pos.has_next());
// Reference spellings of the two compound tokens.
static constexpr char kIsNot[] = "is not";
static constexpr char kNotIn[] = "not in";
// First character of each compound token; used to select which template
// we compare against during the scan.
constexpr char kMaybeIsNot = 'i';
constexpr char kMaybeNotIn = 'n';
// Index of the ' ' separator in each template -- equivalently, the length
// of the bare keyword (`is` is 2 chars, `not` is 3).
constexpr int kIsNotSpaceIndex = 2;
constexpr int kNotInSpaceIndex = 3;
auto start = pos;
// While nonzero, records which compound token is still possible ('i' for
// `is not`, 'n' for `not in`); '\0' means neither can match anymore.
char possible_special_token = *pos;
// The longest tokens are 8 chars.
std::array<char, 8> token_chars;
token_chars.fill('\0');
token_chars[0] = possible_special_token;
++pos;
size_t i;
auto valid_ident_char = [](const char ch) {
return std::isalpha(ch) || ch == '_' || std::isdigit(ch);
};
// Scan forward one char at a time, simultaneously (a) buffering the
// identifier into token_chars and (b) comparing against the compound
// token templates.
for (i = 1; pos.has_next(); ++pos, ++i) {
auto ch = *pos;
if (possible_special_token == kMaybeIsNot) {
if (ch != kIsNot[i]) {
if (i >= kIsNotSpaceIndex + 1) {
// Kick out to the after-loop flow, which will correctly
// record that we found TK_IS.
break;
}
// Diverged before completing `is`; continue as a plain identifier.
possible_special_token = '\0';
} else if (ch == ' ') {
// The separator space is not an identifier char; skip the ident
// bookkeeping below.
continue;
}
// i == sizeof(kIsNot) - 2 means we just matched the final 't'. Only
// accept the compound token if the following char cannot extend an
// identifier, so e.g. `is note` is NOT treated as TK_ISNOT here.
if (possible_special_token && i == sizeof(kIsNot) - 2 &&
(!pos.has_next() || !valid_ident_char(*(pos + 1)))) {
*kind = TK_ISNOT;
*end = pos.next_iter();
return;
}
} else if (possible_special_token == kMaybeNotIn) {
// Mirror of the `is not` logic above, for `not in`.
if (ch != kNotIn[i]) {
if (i >= kNotInSpaceIndex + 1) {
// Kick out to the after-loop flow, which will correctly
// record that we found TK_NOT.
break;
}
possible_special_token = '\0';
} else if (ch == ' ') {
continue;
}
if (possible_special_token && i == sizeof(kNotIn) - 2 &&
(!pos.has_next() || !valid_ident_char(*(pos + 1)))) {
*kind = TK_NOTIN;
*end = pos.next_iter();
return;
}
}
if (valid_ident_char(ch)) {
if (i < token_chars.size()) {
// Chars beyond the first 8 are not buffered; identTokenKind only
// needs the length to conclude TK_IDENT in that case.
token_chars[i] = ch;
}
continue;
}
break;
}
// These two possible_special_token checks have to be after the
// loop and not in the loop because we might see end-of-input
// (e.g., the entire input `not p`).
if (possible_special_token == kMaybeIsNot) {
if (i >= kIsNotSpaceIndex) {
// Scanned at least the full `is`; emit the bare keyword.
*kind = TK_IS;
*end = start + kIsNotSpaceIndex;
return;
}
} else if (possible_special_token == kMaybeNotIn) {
if (i >= kNotInSpaceIndex) {
// Scanned at least the full `not`; emit the bare keyword.
*kind = TK_NOT;
*end = start + kNotInSpaceIndex;
return;
}
}
// Plain identifier or single-word keyword.
*end = pos;
*kind = identTokenKind(token_chars, i);
}
// Packs up to 8 characters of a NUL-terminated character array into a
// single uint64_t, honoring the host byte order so the result can be
// compared against a memcpy'd character buffer. Intended for
// compile-time evaluation.
template <size_t N>
static constexpr uint64_t stringToUint64(const char (&str)[N]) {
  static_assert(N <= 9);
  uint64_t packed = 0;
  for (size_t idx = 0; idx != N; ++idx) {
    const char ch = str[idx];
    if (ch == '\0') {
      // Stop at the terminator; remaining bytes stay zero.
      break;
    }
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    packed |= static_cast<uint64_t>(ch) << (8 * idx);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    packed |= static_cast<uint64_t>(ch) << (56 - 8 * idx);
#else
#error "Unexpected or undefined value of __BYTE_ORDER__"
#endif
  }
  return packed;
}
static int identTokenKind(
const std::array<char, 8>& token_chars,
size_t token_length) {
if (token_length > token_chars.size()) {
return TK_IDENT;
}
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
static_assert(stringToUint64("and") == 0x646e61);
static_assert(stringToUint64("Ellipsis") == 0x73697370696c6c45);
#else
static_assert(stringToUint64("and") == 0x616e640000000000);
static_assert(stringToUint64("Ellipsis") == 0x456c6c6970736973);
#endif
std::uint64_t token = 0;
std::memcpy(&token, token_chars.data(), token_chars.size());
// FWIW, based on checking Godbolt this probably compiles down to
// binary or linear search over the integers representing our
// strings. I tried an alternate version that switched on the
// first character of the token, but it doesn't seem to matter for
// performance.
switch (token) {
case stringToUint64("Ellipsis"):
return TK_ELLIPSIS;
case stringToUint64("False"):
return TK_FALSE;
case stringToUint64("None"):
return TK_NONE;
case stringToUint64("NoneType"):
return TK_NONE_TYPE;
case stringToUint64("True"):
return TK_TRUE;
case stringToUint64("and"):
return TK_AND;
case stringToUint64("as"):
return TK_AS;
case stringToUint64("assert"):
return TK_ASSERT;
case stringToUint64("break"):
return TK_BREAK;
case stringToUint64("class"):
return TK_CLASS_DEF;
case stringToUint64("continue"):
return TK_CONTINUE;
case stringToUint64("def"):
return TK_DEF;
case stringToUint64("del"):
return TK_DELETE;
case stringToUint64("elif"):
return TK_ELIF;
case stringToUint64("else"):
return TK_ELSE;
case stringToUint64("for"):
return TK_FOR;
case stringToUint64("global"):
return TK_GLOBAL;
case stringToUint64("if"):
return TK_IF;
case stringToUint64("import"):
return TK_IMPORT;
case stringToUint64("in"):
return TK_IN;
case stringToUint64("is"):
return TK_IS;
case stringToUint64("not"):
return TK_NOT;
case stringToUint64("or"):
return TK_OR;
case stringToUint64("pass"):
return TK_PASS;
case stringToUint64("raise"):
return TK_RAISE;
case stringToUint64("return"):
return TK_RETURN;
case stringToUint64("while"):
return TK_WHILE;
case stringToUint64("with"):
return TK_WITH;
default:
return TK_IDENT;
}
// Returns true if `n` may appear at position `i` of an identifier:
// letters and underscore anywhere, digits only after the first character.
bool validIdent(size_t i, char n) {
  // Cast to unsigned char before calling <cctype> classifiers: passing a
  // plain char with a negative value (e.g. a UTF-8 continuation byte on
  // platforms where char is signed) is undefined behavior.
  const auto uc = static_cast<unsigned char>(n);
  return isalpha(uc) || n == '_' || (i > 0 && isdigit(uc));
}
// 1. skip whitespace
@ -635,7 +312,7 @@ struct TORCH_API SharedParserData {
// http://en.cppreference.com/w/cpp/string/byte/strtof
// but we want only the number part, otherwise 1+3 will turn into two
// adjacent numbers in the lexer
if (first == '-' || first == '+' || std::isalpha(first))
if (first == '-' || first == '+' || isalpha(first))
return false;
const char* startptr = str.data() + start;
char* endptr = nullptr;
@ -710,6 +387,8 @@ struct TORCH_API SharedParserData {
auto match_string = str.substr(pos, type_string.size());
return match_string == type_string;
}
TokenTrieRef head;
};
TORCH_API SharedParserData& sharedParserData();

View File

@ -1,6 +1,6 @@
#pragma once
namespace torch::jit {
[[maybe_unused]] static constexpr const char* valid_single_char_tokens =
static constexpr const char* valid_single_char_tokens =
"+-*/%@()[]:,={}><.?!&^|~";
} // namespace torch::jit