mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support left and right shift operators in JIT (#34563)
Summary: With this PR, we can now support left and right shift operators in the JIT engine for <int, int> and <Tensor, int>. Updated tests pass as expected: ``` > python test/test_jit.py ... Ran 2427 tests in 84.861s OK (skipped=139, expected failures=1) ``` Running the following code with Python results in the output below: ``` > cat ~/expressions.py import torch torch.jit.script def fn(a, b): # type: (int, int) return ( a << b, # supported b >> a, # supported a & b, a | b, a ^ b ) print(fn.graph) ``` ``` > python ~/expressions.py graph(%a.1 : int, %b.1 : int): %4 : int = aten::leftshift(%a.1, %b.1) # /home/ince/expressions.py:7:8 %7 : int = aten::rightshift(%b.1, %a.1) # /home/ince/expressions.py:8:8 %10 : int = aten::__and__(%a.1, %b.1) # /home/ince/expressions.py:9:8 %13 : int = aten::__or__(%a.1, %b.1) # /home/ince/expressions.py:10:8 %16 : int = aten::__xor__(%a.1, %b.1) # /home/ince/expressions.py:11:8 %17 : (int, int, int, int, int) = prim::TupleConstruct(%4, %7, %10, %13, %16) return (%17) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/34563 Differential Revision: D20434209 Pulled By: tugrulince fbshipit-source-id: 886386c59755106e17b84778b8e495b80a6269cd
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c34ee4fb6e
commit
c9023e3b12
@ -8374,17 +8374,24 @@ a")
|
||||
.check("Traceback") \
|
||||
.check("in foo").check("in baz").run(str(cm.exception))
|
||||
|
||||
def test_binop_unsupported_error(self):
|
||||
with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"):
|
||||
@torch.jit.script
|
||||
def binop(x, y):
|
||||
# Replace this with another unsupported op when/if it gets supported
|
||||
return x << y
|
||||
def test_operator_precedence(self):
|
||||
def double(x):
|
||||
# type: (int) -> int
|
||||
return 2 * x
|
||||
|
||||
def complicated_arithmetic_operation():
|
||||
# TODO we need to test exponent operator '**' and bitwise not
|
||||
# operator '~' once they are properly supported.
|
||||
list = [0, 1, 2, 3]
|
||||
result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4
|
||||
return result
|
||||
|
||||
self.checkScript(complicated_arithmetic_operation, ())
|
||||
|
||||
def test_bitwise_ops(self):
|
||||
|
||||
def int_test():
|
||||
return 2 & 3, 2 ^ 3, 2 | 3
|
||||
return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3
|
||||
|
||||
self.checkScript(int_test, ())
|
||||
|
||||
@ -8398,10 +8405,15 @@ a")
|
||||
def tensor_test(x, y):
|
||||
return x & y, x ^ y, x | y
|
||||
|
||||
def tensor_with_int_test(x, y):
|
||||
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
|
||||
return x << y, x >> y
|
||||
|
||||
x = torch.tensor(2)
|
||||
y = torch.tensor(3)
|
||||
|
||||
self.checkScript(tensor_test, (x, y))
|
||||
self.checkScript(tensor_with_int_test, (x, 2))
|
||||
|
||||
def not_test(x):
|
||||
return ~x
|
||||
|
@ -2278,6 +2278,10 @@ struct to_ir {
|
||||
return aten::__not__;
|
||||
case TK_FLOOR_DIV:
|
||||
return aten::floordiv;
|
||||
case TK_LSHIFT:
|
||||
return aten::__lshift__;
|
||||
case TK_RSHIFT:
|
||||
return aten::__rshift__;
|
||||
case '&':
|
||||
return aten::__and__;
|
||||
case '|':
|
||||
@ -2329,6 +2333,10 @@ struct to_ir {
|
||||
return "__xor__";
|
||||
case TK_IN:
|
||||
return "__contains__";
|
||||
case TK_LSHIFT:
|
||||
return "__lshift__";
|
||||
case TK_RSHIFT:
|
||||
return "__rshift__";
|
||||
default:
|
||||
throw std::runtime_error("unknown kind " + c10::to_string(kind));
|
||||
}
|
||||
@ -2868,7 +2876,9 @@ struct to_ir {
|
||||
case '%':
|
||||
case '&':
|
||||
case '|':
|
||||
case '^': {
|
||||
case '^':
|
||||
case TK_LSHIFT:
|
||||
case TK_RSHIFT: {
|
||||
const auto& inputs = tree->trees();
|
||||
auto kind = getNodeKind(tree->kind(), inputs.size());
|
||||
auto overload = getOperatorOverload(tree->kind(), inputs.size());
|
||||
|
@ -28,21 +28,23 @@ static const std::unordered_map<int, int> binary_prec = {
|
||||
{'|', 5},
|
||||
{'^', 6},
|
||||
{'&', 7},
|
||||
{'+', 8},
|
||||
{'-', 8},
|
||||
{'*', 9},
|
||||
{'/', 9},
|
||||
{TK_FLOOR_DIV, 9},
|
||||
{'%', 9},
|
||||
{'@', 9},
|
||||
{TK_POW, 10},
|
||||
{TK_LSHIFT, 8},
|
||||
{TK_RSHIFT, 8},
|
||||
{'+', 9},
|
||||
{'-', 9},
|
||||
{'*', 10},
|
||||
{'/', 10},
|
||||
{TK_FLOOR_DIV, 10},
|
||||
{'%', 10},
|
||||
{'@', 10},
|
||||
{TK_POW, 11},
|
||||
};
|
||||
|
||||
static const std::unordered_map<int, int> unary_prec = {
|
||||
{TK_NOT, 3},
|
||||
{'~', 3},
|
||||
{'-', 9},
|
||||
{'*', 9},
|
||||
{'-', 10},
|
||||
{'*', 10},
|
||||
};
|
||||
|
||||
bool SharedParserData::isUnary(int kind, int* prec) {
|
||||
|
@ -72,6 +72,8 @@ namespace jit {
|
||||
_(TK_AND, "and", "and") \
|
||||
_(TK_OR, "or", "or") \
|
||||
_(TK_NOT, "not", "not") \
|
||||
_(TK_LSHIFT, "<<", "<<") \
|
||||
_(TK_RSHIFT, ">>", ">>") \
|
||||
_(TK_CAST, "cast", "") \
|
||||
_(TK_PLUS_EQ, "+=", "+=") \
|
||||
_(TK_MINUS_EQ, "-=", "-=") \
|
||||
|
@ -298,6 +298,8 @@ struct Expr : public TreeView {
|
||||
case TK_DICT_LITERAL:
|
||||
case '@':
|
||||
case TK_POW:
|
||||
case TK_LSHIFT:
|
||||
case TK_RSHIFT:
|
||||
case TK_FLOOR_DIV:
|
||||
case '&':
|
||||
case '^':
|
||||
@ -739,6 +741,8 @@ struct BinOp : public Expr {
|
||||
case '-':
|
||||
case '@':
|
||||
case TK_POW:
|
||||
case TK_LSHIFT:
|
||||
case TK_RSHIFT:
|
||||
case '%':
|
||||
case '&':
|
||||
case '^':
|
||||
|
@ -2679,6 +2679,8 @@ RegisterOperators reg2({
|
||||
DEFINE_INT_OP(aten::__and__, a& b),
|
||||
DEFINE_INT_OP(aten::__or__, a | b),
|
||||
DEFINE_INT_OP(aten::__xor__, a ^ b),
|
||||
DEFINE_INT_OP(aten::__lshift__, a << b),
|
||||
DEFINE_INT_OP(aten::__rshift__, a >> b),
|
||||
|
||||
DEFINE_UNARY_OP(aten::floor, floor(a), int, int),
|
||||
DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int),
|
||||
|
@ -398,6 +398,8 @@ class ExprBuilder(Builder):
|
||||
ast.BitAnd: '&',
|
||||
ast.BitXor: '^',
|
||||
ast.BitOr: '|',
|
||||
ast.LShift: '<<',
|
||||
ast.RShift: '>>',
|
||||
}
|
||||
|
||||
if not PY2:
|
||||
|
Reference in New Issue
Block a user