mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: With this PR, we can now support left and right shift operators in the JIT engine for <int, int> and <Tensor, int>. Updated tests pass as expected: ``` > python test/test_jit.py ... Ran 2427 tests in 84.861s OK (skipped=139, expected failures=1) ``` Running the following code with Python results in the output below: ``` > cat ~/expressions.py import torch @torch.jit.script def fn(a, b): # type: (int, int) return ( a << b, # supported b >> a, # supported a & b, a | b, a ^ b ) print(fn.graph) ``` ``` > python ~/expressions.py graph(%a.1 : int, %b.1 : int): %4 : int = aten::leftshift(%a.1, %b.1) # /home/ince/expressions.py:7:8 %7 : int = aten::rightshift(%b.1, %a.1) # /home/ince/expressions.py:8:8 %10 : int = aten::__and__(%a.1, %b.1) # /home/ince/expressions.py:9:8 %13 : int = aten::__or__(%a.1, %b.1) # /home/ince/expressions.py:10:8 %16 : int = aten::__xor__(%a.1, %b.1) # /home/ince/expressions.py:11:8 %17 : (int, int, int, int, int) = prim::TupleConstruct(%4, %7, %10, %13, %16) return (%17) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/34563 Differential Revision: D20434209 Pulled By: tugrulince fbshipit-source-id: 886386c59755106e17b84778b8e495b80a6269cd
107 lines
2.3 KiB
C++
107 lines
2.3 KiB
C++
#include <torch/csrc/jit/frontend/lexer.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <mutex>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
static const std::unordered_map<int, int> binary_prec = {
|
|
{TK_IF, 1},
|
|
{TK_FOR, 1},
|
|
{TK_AND, 2},
|
|
{TK_OR, 2},
|
|
// reserve a level for unary not
|
|
{TK_IN, 4},
|
|
{TK_NOTIN, 4},
|
|
{'<', 4},
|
|
{'>', 4},
|
|
{TK_IS, 4},
|
|
{TK_ISNOT, 4},
|
|
{TK_EQ, 4},
|
|
{TK_LE, 4},
|
|
{TK_GE, 4},
|
|
{TK_NE, 4},
|
|
{'|', 5},
|
|
{'^', 6},
|
|
{'&', 7},
|
|
{TK_LSHIFT, 8},
|
|
{TK_RSHIFT, 8},
|
|
{'+', 9},
|
|
{'-', 9},
|
|
{'*', 10},
|
|
{'/', 10},
|
|
{TK_FLOOR_DIV, 10},
|
|
{'%', 10},
|
|
{'@', 10},
|
|
{TK_POW, 11},
|
|
};
|
|
|
|
static const std::unordered_map<int, int> unary_prec = {
|
|
{TK_NOT, 3},
|
|
{'~', 3},
|
|
{'-', 10},
|
|
{'*', 10},
|
|
};
|
|
|
|
bool SharedParserData::isUnary(int kind, int* prec) {
|
|
auto it = unary_prec.find(kind);
|
|
if (it != unary_prec.end()) {
|
|
*prec = it->second;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
bool SharedParserData::isBinary(int kind, int* prec) {
|
|
auto it = binary_prec.find(kind);
|
|
if (it != binary_prec.end()) {
|
|
*prec = it->second;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Maps a token's spelling to its token kind.
// - Every character in valid_single_char_tokens maps to its own character code.
// - Each entry of TC_FORALL_TOKEN_KINDS with a non-empty string maps that
//   string to its TK_* kind.
//
// Throws std::out_of_range("unknown token in stringToKind") if `str` is not a
// registered spelling.
C10_EXPORT int stringToKind(const std::string& str) {
  // Built once on first call. Function-local static initialization is
  // thread-safe since C++11, so no explicit once_flag is needed. The table is
  // never modified afterwards, hence const.
  static const std::unordered_map<std::string, int> str_to_kind = []() {
    std::unordered_map<std::string, int> table;
    for (char tok : std::string(valid_single_char_tokens))
      table[std::string(1, tok)] = tok;
#define DEFINE_CASE(tok, _, tok_str) \
  if (std::string(tok_str) != "")    \
    table[tok_str] = tok;
    TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
    return table;
  }();

  // Look the token up directly instead of catching the exception from at().
  const auto it = str_to_kind.find(str);
  if (it == str_to_kind.end()) {
    throw std::out_of_range("unknown token in stringToKind");
  }
  return it->second;
}
|
|
|
|
// Returns a printable name for a token kind.
// - Kinds below 256 are single-character tokens; the result is that character.
// - Other kinds use the name from TC_FORALL_TOKEN_KINDS.
// Throws std::runtime_error for a kind that is in neither group.
C10_EXPORT std::string kindToString(int kind) {
  if (kind < 256) {
    return std::string(1, static_cast<char>(kind));
  }
  switch (kind) {
#define DEFINE_CASE(tok, str, _) \
  case tok:                      \
    return str;
    TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      break;
  }
  throw std::runtime_error("Unknown kind: " + c10::guts::to_string(kind));
}
|
|
|
|
// Returns the process-wide SharedParserData instance.
// It is a function-local static, so C++11 guarantees it is constructed exactly
// once, even when the first calls race across threads.
C10_EXPORT SharedParserData& sharedParserData() {
  static SharedParserData instance;
  return instance;
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|