diff --git a/test/test_jit.py b/test/test_jit.py index f34ab406b8db..5fa1debef62c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8374,17 +8374,24 @@ a") .check("Traceback") \ .check("in foo").check("in baz").run(str(cm.exception)) - def test_binop_unsupported_error(self): - with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"): - @torch.jit.script - def binop(x, y): - # Replace this with another unsupported op when/if it gets supported - return x << y + def test_operator_precedence(self): + def double(x): + # type: (int) -> int + return 2 * x + + def complicated_arithmetic_operation(): + # TODO we need to test exponent operator '**' and bitwise not + # operator '~' once they are properly supported. + list = [0, 1, 2, 3] + result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4 + return result + + self.checkScript(complicated_arithmetic_operation, ()) def test_bitwise_ops(self): def int_test(): - return 2 & 3, 2 ^ 3, 2 | 3 + return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3 self.checkScript(int_test, ()) @@ -8398,10 +8405,15 @@ a") def tensor_test(x, y): return x & y, x ^ y, x | y + def tensor_with_int_test(x, y): + # type: (Tensor, int) -> Tuple[Tensor, Tensor] + return x << y, x >> y + x = torch.tensor(2) y = torch.tensor(3) self.checkScript(tensor_test, (x, y)) + self.checkScript(tensor_with_int_test, (x, 2)) def not_test(x): return ~x diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 18ab8466a8e1..2df572825640 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -2278,6 +2278,10 @@ struct to_ir { return aten::__not__; case TK_FLOOR_DIV: return aten::floordiv; + case TK_LSHIFT: + return aten::__lshift__; + case TK_RSHIFT: + return aten::__rshift__; case '&': return aten::__and__; case '|': @@ -2329,6 +2333,10 @@ struct to_ir { return "__xor__"; case TK_IN: return "__contains__"; + case TK_LSHIFT: + return "__lshift__"; + case TK_RSHIFT: + return "__rshift__"; default: throw std::runtime_error("unknown kind " + c10::to_string(kind)); } @@ -2868,7 +2876,9 @@ struct to_ir { case '%': case '&': case '|': - case '^': { + case '^': + case TK_LSHIFT: + case TK_RSHIFT: { const auto& inputs = tree->trees(); auto kind = getNodeKind(tree->kind(), inputs.size()); auto overload = getOperatorOverload(tree->kind(), inputs.size()); diff --git a/torch/csrc/jit/frontend/lexer.cpp b/torch/csrc/jit/frontend/lexer.cpp index 028662ae5076..52de286e8e70 100644 --- a/torch/csrc/jit/frontend/lexer.cpp +++ b/torch/csrc/jit/frontend/lexer.cpp @@ -28,21 +28,23 @@ static const std::unordered_map binary_prec = { {'|', 5}, {'^', 6}, {'&', 7}, - {'+', 8}, - {'-', 8}, - {'*', 9}, - {'/', 9}, - {TK_FLOOR_DIV, 9}, - {'%', 9}, - {'@', 9}, - {TK_POW, 10}, + {TK_LSHIFT, 8}, + {TK_RSHIFT, 8}, + {'+', 9}, + {'-', 9}, + {'*', 10}, + {'/', 10}, + {TK_FLOOR_DIV, 10}, + {'%', 10}, + {'@', 10}, + {TK_POW, 11}, }; static const std::unordered_map unary_prec = { {TK_NOT, 3}, {'~', 3}, - {'-', 9}, - {'*', 9}, + {'-', 10}, + {'*', 10}, }; bool SharedParserData::isUnary(int kind, int* prec) { diff --git a/torch/csrc/jit/frontend/lexer.h b/torch/csrc/jit/frontend/lexer.h index 1457b78049a0..2b62edc3c49c 100644 --- a/torch/csrc/jit/frontend/lexer.h +++ b/torch/csrc/jit/frontend/lexer.h @@ -72,6 +72,8 @@ namespace jit { _(TK_AND, "and", "and") \ _(TK_OR, "or", "or") \ _(TK_NOT, "not", "not") \ + _(TK_LSHIFT, "<<", "<<") \ + _(TK_RSHIFT, ">>", ">>") \ _(TK_CAST, "cast", "") \ _(TK_PLUS_EQ, "+=", "+=") \ _(TK_MINUS_EQ, "-=", "-=") \ diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index 067a77d6a353..db67b7533882 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -298,6 +298,8 @@ struct Expr : public TreeView { case TK_DICT_LITERAL: case '@': case TK_POW: + case TK_LSHIFT: + case TK_RSHIFT: case TK_FLOOR_DIV: case '&': case '^': @@ -739,6 +741,8 @@ struct BinOp : public Expr { case '-': case '@': case TK_POW: + case TK_LSHIFT: + case TK_RSHIFT: case '%': case '&': case '^': diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 89b67b3f5316..50a213c46435 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -2679,6 +2679,8 @@ RegisterOperators reg2({ DEFINE_INT_OP(aten::__and__, a& b), DEFINE_INT_OP(aten::__or__, a | b), DEFINE_INT_OP(aten::__xor__, a ^ b), + DEFINE_INT_OP(aten::__lshift__, a << b), + DEFINE_INT_OP(aten::__rshift__, a >> b), DEFINE_UNARY_OP(aten::floor, floor(a), int, int), DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int), diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 5d07c3e8deb4..184af77f03e7 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -398,6 +398,8 @@ class ExprBuilder(Builder): ast.BitAnd: '&', ast.BitXor: '^', ast.BitOr: '|', + ast.LShift: '<<', + ast.RShift: '>>', } if not PY2: