Support left and right shift operators in JIT (#34563)

Summary:
With this PR, we can now support left and right shift operators in the JIT engine for <int, int> and <Tensor, int>.

Updated tests pass as expected:
```
> python test/test_jit.py
...
Ran 2427 tests in 84.861s

OK (skipped=139, expected failures=1)
```

Running the following code with Python results in the output below:
```
> cat ~/expressions.py
import torch

torch.jit.script
def fn(a, b):
    # type: (int, int)
    return (
        a << b,  # supported
        b >> a,  # supported
        a & b,
        a | b,
        a ^ b
    )
print(fn.graph)
```

```
> python ~/expressions.py
graph(%a.1 : int,
      %b.1 : int):
  %4 : int = aten::leftshift(%a.1, %b.1) # /home/ince/expressions.py:7:8
  %7 : int = aten::rightshift(%b.1, %a.1) # /home/ince/expressions.py:8:8
  %10 : int = aten::__and__(%a.1, %b.1) # /home/ince/expressions.py:9:8
  %13 : int = aten::__or__(%a.1, %b.1) # /home/ince/expressions.py:10:8
  %16 : int = aten::__xor__(%a.1, %b.1) # /home/ince/expressions.py:11:8
  %17 : (int, int, int, int, int) = prim::TupleConstruct(%4, %7, %10, %13, %16)
  return (%17)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34563

Differential Revision: D20434209

Pulled By: tugrulince

fbshipit-source-id: 886386c59755106e17b84778b8e495b80a6269cd
This commit is contained in:
Tugrul Ince
2020-03-13 12:49:41 -07:00
committed by Facebook GitHub Bot
parent c34ee4fb6e
commit c9023e3b12
7 changed files with 52 additions and 18 deletions

View File

@ -8374,17 +8374,24 @@ a")
.check("Traceback") \
.check("in foo").check("in baz").run(str(cm.exception))
def test_binop_unsupported_error(self):
with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"):
@torch.jit.script
def binop(x, y):
# Replace this with another unsupported op when/if it gets supported
return x << y
def test_operator_precedence(self):
def double(x):
# type: (int) -> int
return 2 * x
def complicated_arithmetic_operation():
# TODO we need to test exponent operator '**' and bitwise not
# operator '~' once they are properly supported.
list = [0, 1, 2, 3]
result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4
return result
self.checkScript(complicated_arithmetic_operation, ())
def test_bitwise_ops(self):
def int_test():
return 2 & 3, 2 ^ 3, 2 | 3
return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3
self.checkScript(int_test, ())
@ -8398,10 +8405,15 @@ a")
def tensor_test(x, y):
return x & y, x ^ y, x | y
def tensor_with_int_test(x, y):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
return x << y, x >> y
x = torch.tensor(2)
y = torch.tensor(3)
self.checkScript(tensor_test, (x, y))
self.checkScript(tensor_with_int_test, (x, 2))
def not_test(x):
return ~x

View File

@ -2278,6 +2278,10 @@ struct to_ir {
return aten::__not__;
case TK_FLOOR_DIV:
return aten::floordiv;
case TK_LSHIFT:
return aten::__lshift__;
case TK_RSHIFT:
return aten::__rshift__;
case '&':
return aten::__and__;
case '|':
@ -2329,6 +2333,10 @@ struct to_ir {
return "__xor__";
case TK_IN:
return "__contains__";
case TK_LSHIFT:
return "__lshift__";
case TK_RSHIFT:
return "__rshift__";
default:
throw std::runtime_error("unknown kind " + c10::to_string(kind));
}
@ -2868,7 +2876,9 @@ struct to_ir {
case '%':
case '&':
case '|':
case '^': {
case '^':
case TK_LSHIFT:
case TK_RSHIFT: {
const auto& inputs = tree->trees();
auto kind = getNodeKind(tree->kind(), inputs.size());
auto overload = getOperatorOverload(tree->kind(), inputs.size());

View File

@ -28,21 +28,23 @@ static const std::unordered_map<int, int> binary_prec = {
{'|', 5},
{'^', 6},
{'&', 7},
{'+', 8},
{'-', 8},
{'*', 9},
{'/', 9},
{TK_FLOOR_DIV, 9},
{'%', 9},
{'@', 9},
{TK_POW, 10},
{TK_LSHIFT, 8},
{TK_RSHIFT, 8},
{'+', 9},
{'-', 9},
{'*', 10},
{'/', 10},
{TK_FLOOR_DIV, 10},
{'%', 10},
{'@', 10},
{TK_POW, 11},
};
static const std::unordered_map<int, int> unary_prec = {
{TK_NOT, 3},
{'~', 3},
{'-', 9},
{'*', 9},
{'-', 10},
{'*', 10},
};
bool SharedParserData::isUnary(int kind, int* prec) {

View File

@ -72,6 +72,8 @@ namespace jit {
_(TK_AND, "and", "and") \
_(TK_OR, "or", "or") \
_(TK_NOT, "not", "not") \
_(TK_LSHIFT, "<<", "<<") \
_(TK_RSHIFT, ">>", ">>") \
_(TK_CAST, "cast", "") \
_(TK_PLUS_EQ, "+=", "+=") \
_(TK_MINUS_EQ, "-=", "-=") \

View File

@ -298,6 +298,8 @@ struct Expr : public TreeView {
case TK_DICT_LITERAL:
case '@':
case TK_POW:
case TK_LSHIFT:
case TK_RSHIFT:
case TK_FLOOR_DIV:
case '&':
case '^':
@ -739,6 +741,8 @@ struct BinOp : public Expr {
case '-':
case '@':
case TK_POW:
case TK_LSHIFT:
case TK_RSHIFT:
case '%':
case '&':
case '^':

View File

@ -2679,6 +2679,8 @@ RegisterOperators reg2({
DEFINE_INT_OP(aten::__and__, a& b),
DEFINE_INT_OP(aten::__or__, a | b),
DEFINE_INT_OP(aten::__xor__, a ^ b),
DEFINE_INT_OP(aten::__lshift__, a << b),
DEFINE_INT_OP(aten::__rshift__, a >> b),
DEFINE_UNARY_OP(aten::floor, floor(a), int, int),
DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int),

View File

@ -398,6 +398,8 @@ class ExprBuilder(Builder):
ast.BitAnd: '&',
ast.BitXor: '^',
ast.BitOr: '|',
ast.LShift: '<<',
ast.RShift: '>>',
}
if not PY2: