mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
401 lines
14 KiB
Python
401 lines
14 KiB
Python
import torch
|
|
import sys
|
|
import ast
|
|
import inspect
|
|
import string
|
|
from textwrap import dedent
|
|
from functools import partial
|
|
from collections import namedtuple
|
|
from torch._C._jit_tree_views import *
|
|
|
|
# True when running under a Python 2 interpreter; selects which ast node
# classes exist and how function arguments are represented.
PY2 = sys.version_info[0] == 2
# Prefix of the names generated for returned values; user variables are not
# allowed to start with it.
_reserved_prefix = '__jit'
# Characters accepted while scanning an identifier in the raw source text.
# The underscore is part of the set because it is a legal identifier character;
# without it the scan would stop in the middle of names such as `foo_bar`.
_identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_')

# Human readable (plural) description of every statement type that is known to
# be unsupported. Used to build the "... aren't supported" error messages.
pretty_node_names = {
    ast.For: "for loops",
    ast.Delete: "del statements",
    ast.ClassDef: "class definitions",
    ast.With: "with statements",
    ast.Raise: "raise statements",
    ast.Assert: "assertions",
    ast.Import: "import statements",
    ast.ImportFrom: "import statements",
    ast.Global: "global variables",
    ast.Break: "break statements",
    ast.Continue: "continue statements",
}

# Keyword that starts each unsupported statement. Only its length is used, to
# size the source range highlighted in the error.
node_start_tokens = {
    ast.For: "for",
    ast.Delete: "del",
    ast.ClassDef: "class",
    ast.With: "with",
    ast.Raise: "raise",
    ast.Assert: "assert",
    ast.Import: "import",
    ast.ImportFrom: "from",
    ast.Global: "global",
    ast.Break: "break",
    ast.Continue: "continue",
}

# The remaining node classes only exist on one major Python version, so they
# have to be registered conditionally.
if PY2:
    pretty_node_names.update({
        ast.Print: "print statements",
        ast.TryExcept: "try blocks",
        ast.TryFinally: "try blocks",
        ast.Exec: "exec statements",
    })

    node_start_tokens.update({
        ast.Print: "print",
        ast.TryExcept: "try",
        ast.TryFinally: "try",
        ast.Exec: "exec",
    })
else:
    pretty_node_names.update({
        ast.AsyncFor: "async for loops",
        ast.AsyncWith: "async with statements",
        ast.Try: "try blocks",
        ast.Nonlocal: "nonlocal variables",
    })

    node_start_tokens.update({
        ast.AsyncFor: "async for",
        ast.AsyncWith: "async with",
        ast.Try: "try",
        ast.Nonlocal: "nonlocal",
    })

if sys.version_info >= (3, 6):
    pretty_node_names.update({
        ast.AnnAssign: "annotated assignments",
    })
    # NB: no specific token for AnnAssign
|
|
|
|
|
|
class FrontendError(Exception):
    """Base class of the errors reported while translating Python source.

    Attributes:
        source_range: object describing where the problem is located, or None
            when no location is available. It must provide ``highlight()``.
        msg: the plain error message.
    """

    def __init__(self, source_range, msg):
        self.msg = msg
        self.source_range = source_range

    def __str__(self):
        """Return the message, followed by the highlighted source if known."""
        if self.source_range is None:
            return self.msg
        return self.msg + '\n' + self.source_range.highlight()
|
|
|
|
|
|
class NotSupportedError(FrontendError):
    """Raised for Python constructs that the frontend can't translate."""
|
|
|
|
|
|
class UnsupportedNodeError(NotSupportedError):
    """Error raised for an AST node type that has no builder method.

    The message names the offending construct (using ``pretty_node_names`` when
    the node type is listed there, the class name otherwise) and the source
    range covers the keyword that starts the statement, when one is known.

    Args:
        ctx: object providing ``make_range(lineno, start_col, end_col)``.
        offending_node: the ``ast`` node that can't be translated; it must
            carry ``lineno`` and ``col_offset``.
    """

    def __init__(self, ctx, offending_node):
        node_type = type(offending_node)
        # If we don't have a specific token, we default to length of 1
        range_len = len(node_start_tokens.get(node_type, ' '))
        source_range = ctx.make_range(offending_node.lineno,
                                      offending_node.col_offset,
                                      offending_node.col_offset + range_len)
        feature_name = pretty_node_names.get(node_type, node_type.__name__)
        msg = "{} aren't supported".format(feature_name)
        super(UnsupportedNodeError, self).__init__(source_range, msg)
|
|
|
|
|
|
class FrontendTypeError(FrontendError):
    """Raised when an expression is used in a way its kind doesn't allow.

    In this file that means calling something that isn't a method reference,
    or using a method reference as a value.
    """
|
|
|
|
|
|
def get_jit_ast(fn):
    """Build the JIT definition tree for the Python function ``fn``.

    The source of ``fn`` is fetched with ``inspect``, dedented and parsed. The
    parsed module has to consist of exactly one function definition.

    Raises:
        RuntimeError: if the source isn't a single top-level function.
    """
    source = dedent(inspect.getsource(fn))
    top_level = ast.parse(source).body
    is_single_def = len(top_level) == 1 and isinstance(top_level[0], ast.FunctionDef)
    if not is_single_def:
        raise RuntimeError("expected a single top-level function")
    return build_def(SourceRangeFactory(source), top_level[0])
|
|
|
|
|
|
class Builder(object):
    """Dispatcher that routes an AST node to a ``build_<NodeClassName>`` method.

    Subclasses provide the ``build_*`` methods. Calling the builder with a node
    whose class has no matching method raises ``UnsupportedNodeError``.
    """

    def __call__(self, ctx, node):
        handler = getattr(self, 'build_' + node.__class__.__name__, None)
        if handler is not None:
            return handler(ctx, node)
        raise UnsupportedNodeError(ctx, node)
|
|
|
|
|
|
class CountReturns(ast.NodeVisitor):
    """AST visitor that counts the ``return`` statements it encounters."""

    def __init__(self):
        self.num_returns = 0

    def visit_Return(self, ret):
        # Children of the return statement are deliberately not visited.
        self.num_returns += 1

    @staticmethod
    def get_count(py_def):
        """Return the number of ``return`` statements found under ``py_def``."""
        visitor = CountReturns()
        visitor.visit(py_def)
        return visitor.num_returns
|
|
|
|
|
|
_ret_err_msg = ("JIT-ed functions can only have a single return, "
                "and it has to be the last statement in the body")


def build_def(ctx, py_def):
    """Translate a Python ``ast.FunctionDef`` into a ``Def`` tree node.

    The function may contain at most one ``return`` and it has to be the last
    statement of the body. Every returned value becomes an output parameter
    named ``<_reserved_prefix>_<index>`` plus an assignment to that name
    appended to the translated body. A returned tuple yields one output per
    element; a bare ``return`` yields no outputs, exactly like a function
    without a return statement.

    Args:
        ctx: source range factory (provides ``make_range``).
        py_def: the ``ast.FunctionDef`` to translate.

    Raises:
        ValueError: if there is more than one ``return`` or the single
            ``return`` isn't the last statement.
    """
    returns = []
    ret_body = []
    body = py_def.body
    num_returns = CountReturns.get_count(py_def)
    # TODO: change TorchScript AST to have a Return statement
    if num_returns > 1:
        raise ValueError(_ret_err_msg)
    if num_returns == 1:
        ret_stmt, body = body[-1], body[:-1]
        if not isinstance(ret_stmt, ast.Return):
            raise ValueError(_ret_err_msg)
        ret_expr = ret_stmt.value
        if ret_expr is None:
            # Bare `return`: there is no value expression to translate.
            ret_vals = []
        elif isinstance(ret_expr, ast.Tuple):
            ret_vals = ret_expr.elts
        else:
            ret_vals = [ret_expr]
        for i, val in enumerate(ret_vals):
            val_expr = build_expr(ctx, val)
            val_name = _reserved_prefix + '_' + str(i)
            r = val_expr.range()
            returns.append(Param(TensorType(r), Ident(r, val_name)))
            ret_body.append(Assign([Ident(r, val_name)], '=', val_expr))
    # The range of the definition covers only the `def` keyword.
    r = ctx.make_range(py_def.lineno, py_def.col_offset,
                       py_def.col_offset + len("def"))
    return Def(Ident(r, py_def.name),
               build_param_list(ctx, py_def.args),
               returns,
               [build_stmt(ctx, stmt) for stmt in body] + ret_body)
|
|
|
|
|
|
_vararg_kwarg_err = ("Compiled functions can't take variable number of arguments, "
                     "have default values for arguments, nor keyword-only arguments")


def build_param_list(ctx, py_args):
    """Translate an ``ast.arguments`` node into a list of parameters.

    Only plain positional arguments without defaults are accepted.

    Raises:
        ValueError: if the signature uses ``*args``, ``**kwargs``, default
            values or (on Python 3) keyword-only arguments.
    """
    has_star_args = py_args.vararg is not None or py_args.kwarg is not None
    if has_star_args or py_args.defaults:
        raise ValueError(_vararg_kwarg_err)
    # The keyword-only fields are only looked at when not running on Python 2.
    if not PY2:
        if py_args.kw_defaults or py_args.kwonlyargs:
            raise ValueError(_vararg_kwarg_err)
    params = []
    for arg in py_args.args:
        params.append(build_param(ctx, arg))
    return params
|
|
|
|
|
|
def build_param(ctx, py_arg):
    """Translate a single function argument into a ``Param`` node.

    Raises:
        ValueError: if the argument carries a type annotation.
    """
    # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
    #     In Python2 py_arg is a Name (Expr subclass)
    annotation = getattr(py_arg, 'annotation', None)
    if annotation is not None:
        raise ValueError("Compiled functions don't support annotations")
    if PY2:
        name = py_arg.id
    else:
        name = py_arg.arg
    start_col = py_arg.col_offset
    r = ctx.make_range(py_arg.lineno, start_col, start_col + len(name))
    return Param(TensorType(r), Ident(r, name))
|
|
|
|
|
|
class StmtBuilder(Builder):
    """Translates Python statement nodes into JIT tree statements.

    Every ``build_<NodeName>`` static method handles one ``ast`` statement
    type and takes the source range factory ``ctx`` plus the node. Statement
    types without such a method are rejected by ``Builder.__call__``.
    """

    # Operator of an augmented assignment -> token of the resulting Assign.
    augassign_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
    }

    @staticmethod
    def build_Expr(ctx, stmt):
        """Translate an expression used as a statement."""
        return ExprStmt(build_expr(ctx, stmt.value))

    @staticmethod
    def get_assign_ident(ctx, expr):
        """Return the identifier of an assignment target.

        Raises:
            NotSupportedError: if the target is anything but a variable name.
        """
        var = build_expr(ctx, expr)
        if not isinstance(var, Var):
            # FrontendError takes the source range first and the message second.
            raise NotSupportedError(var.range(),
                                    "the only expressions allowed on the left hand side of "
                                    "assignments are variable names")
        return var.name()

    @staticmethod
    def build_Assign(ctx, stmt):
        """Translate ``a = b = value``; every target must be a variable name."""
        return Assign([StmtBuilder.get_assign_ident(ctx, e) for e in stmt.targets],
                      '=',
                      build_expr(ctx, stmt.value))

    @staticmethod
    def build_AugAssign(ctx, stmt):
        """Translate ``target <op>= value`` for the operators in ``augassign_map``.

        Raises:
            NotSupportedError: for any other operator.
        """
        lhs = [StmtBuilder.get_assign_ident(ctx, stmt.target)]
        rhs = build_expr(ctx, stmt.value)
        op = type(stmt.op)
        if op in StmtBuilder.augassign_map:
            op_token = StmtBuilder.augassign_map[op]
        else:
            # Search backwards from the right hand side for '=' and widen the
            # range by one character to the left.
            raise NotSupportedError(
                find_before(ctx, rhs.range().start, '=', offsets=(-1, 0)),
                "unsupported kind of augumented assignment: " + op.__name__)
        return Assign(lhs, op_token, rhs)

    @staticmethod
    def build_While(ctx, stmt):
        """Translate a ``while`` loop; an ``else`` branch is rejected."""
        if stmt.orelse:
            # TODO: try to recover the location of else:? Python doesn't give us useful
            # annotations in this case
            raise NotSupportedError(None, "else branches of while loops aren't supported")
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
        return While(r, build_expr(ctx, stmt.test), [build_stmt(ctx, s) for s in stmt.body])

    @staticmethod
    def build_If(ctx, stmt):
        """Translate an ``if`` statement together with its ``else`` body."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
        return If(r, build_expr(ctx, stmt.test),
                  [build_stmt(ctx, s) for s in stmt.body],
                  [build_stmt(ctx, s) for s in stmt.orelse])
|
|
|
|
|
|
class ExprBuilder(Builder):
    """Translates Python expression nodes into JIT tree expressions.

    Every ``build_<NodeName>`` static method handles one ``ast`` expression
    type and takes the source range factory ``ctx`` plus the node. Attribute
    accesses are represented by the intermediate ``_MethodRef`` tuple, which is
    only legal as the callee of a call (see ``__call__``).
    """

    # Result of translating `value.name`: the translated receiver and the
    # identifier of the attribute.
    _MethodRef = namedtuple('MethodRef', ['self', 'name'])

    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
    }

    unop_map = {
        ast.Not: 'not',
        ast.USub: '-',
    }

    boolop_map = {
        ast.And: 'and',
        ast.Or: 'or',
    }

    cmpop_map = {
        ast.Eq: '==',
        ast.NotEq: '!=',
        ast.LtE: '<=',
        ast.Lt: '<',
        ast.GtE: '>=',
        ast.Gt: '>',
    }

    @staticmethod
    def build_Attribute(ctx, expr):
        """Translate ``value.attr`` into a ``_MethodRef``."""
        # NB: the only attributes we support are for getting methods
        value = build_expr(ctx, expr.value)
        # <sigh> name is just a string, so it's not annotated in any way.
        source = ctx.source
        source_len = len(source)
        pos = find_after(ctx, value.range().end, '.').end  # Start with the dot
        # Both scans are bounded so that a name at the very end of the source
        # can't index past the end of the string.
        while pos < source_len and source[pos] in string.whitespace:  # Skip whitespace
            pos += 1
        start_pos = pos
        while pos < source_len and source[pos] in _identifier_chars:  # Find the identifier itself
            pos += 1
        name_range = ctx.make_raw_range(start_pos, pos)
        return ExprBuilder._MethodRef(value, Ident(name_range, expr.attr))

    @staticmethod
    def build_Call(ctx, expr):
        """Translate a method call ``receiver.name(args, key=value)``.

        The receiver becomes the first positional argument of the resulting
        ``Apply`` node.

        Raises:
            FrontendTypeError: if the callee isn't an attribute access.
            NotSupportedError: for ``**mapping`` expansion in the call.
        """
        ref = build_expr(ctx, expr.func, allow_methods=True)
        if type(ref) is not ExprBuilder._MethodRef:
            ref_range = ref.range()
            parenthesis_range = find_after(ctx, ref_range.end, '(')
            raise FrontendTypeError(
                ctx.make_raw_range(ref_range.start, parenthesis_range.end),
                "trying to call a non-function object")
        args = [build_expr(ctx, py_arg) for py_arg in expr.args]
        kwargs = []
        for kw in expr.keywords:
            # Each entry is an `ast.keyword` node with `arg` and `value`
            # fields; it can't be unpacked like a tuple.
            kw_expr = build_expr(ctx, kw.value)
            if kw.arg is None:
                # `f(**mapping)` is stored as a keyword without a name.
                raise NotSupportedError(kw_expr.range(),
                                        "keyword-arg expansion isn't supported")
            # The keyword name has no position of its own, so the identifier
            # borrows the range of the value expression.
            kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
        return Apply(ref.name, [ref.self] + args, kwargs)

    @staticmethod
    def build_Name(ctx, expr):
        """Translate a variable reference; reserved names are rejected."""
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
        if expr.id.startswith(_reserved_prefix):
            raise NotSupportedError(r, "names of variables used in JIT-ed functions "
                                       "can't start with " + _reserved_prefix)
        return Var(Ident(r, expr.id))

    @staticmethod
    def build_BinOp(ctx, expr):
        """Translate a binary arithmetic expression listed in ``binop_map``."""
        lhs = build_expr(ctx, expr.left)
        rhs = build_expr(ctx, expr.right)
        op = type(expr.op)
        op_token = ExprBuilder.binop_map.get(op)
        if op_token is None:
            # The operator sits between the two operands; both bounds are raw
            # source positions, hence make_raw_range.
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise NotSupportedError(err_range, "unsupported binary operator: " + op.__name__)
        return BinOp(op_token, lhs, rhs)

    @staticmethod
    def build_UnaryOp(ctx, expr):
        """Translate a unary expression listed in ``unop_map``."""
        sub_expr = build_expr(ctx, expr.operand)
        op = type(expr.op)
        op_token = ExprBuilder.unop_map.get(op)
        if op_token is None:
            # There is no token whose length could be measured, so the error
            # range starts at the first character of the expression.
            start = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1).start
            err_range = ctx.make_raw_range(start, sub_expr.range().end)
            raise NotSupportedError(err_range, "unsupported unary operator: " + op.__name__)
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(op_token))
        return UnaryOp(r, op_token, sub_expr)

    @staticmethod
    def build_BoolOp(ctx, expr):
        """Translate ``a and b and c`` into left-nested binary operations."""
        if len(expr.values) < 2:
            raise AssertionError("expected at least 2 values in BoolOp, but got " + str(len(expr.values)))
        sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
        op = type(expr.op)
        op_token = ExprBuilder.boolop_map.get(op)
        if op_token is None:
            err_range = ctx.make_raw_range(sub_exprs[0].range().end, sub_exprs[1].range().start)
            raise NotSupportedError(err_range, "unsupported boolean operator: " + op.__name__)
        lhs = sub_exprs[0]
        for rhs in sub_exprs[1:]:
            lhs = BinOp(op_token, lhs, rhs)
        return lhs

    @staticmethod
    def build_IfExp(ctx, expr):
        """Translate the conditional expression ``body if test else orelse``."""
        return TernaryIf(build_expr(ctx, expr.test),
                         build_expr(ctx, expr.body),
                         build_expr(ctx, expr.orelse))

    @staticmethod
    def build_Compare(ctx, expr):
        """Translate a (possibly chained) comparison.

        ``a < b < c`` becomes ``(a < b) and (b < c)``.
        """
        operands = [build_expr(ctx, e) for e in [expr.left] + list(expr.comparators)]
        result = None
        for lhs, op_, rhs in zip(operands, expr.ops, operands[1:]):
            op = type(op_)
            op_token = ExprBuilder.cmpop_map.get(op)
            if op_token is None:
                err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
                raise NotSupportedError(err_range, "unsupported comparison operator: " + op.__name__)
            cmp_expr = BinOp(op_token, lhs, rhs)
            if result is None:
                result = cmp_expr
            else:
                result = BinOp('and', result, cmp_expr)
        return result

    @staticmethod
    def build_Num(ctx, expr):
        """Reject numeric literals, which can't be represented yet."""
        # TODO: fix this once we have a nice Number node in our AST
        err_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        raise NotSupportedError(err_range, "scalar constants aren't supported")

    def __call__(self, ctx, expr, allow_methods=False):
        """Translate ``expr``.

        Args:
            ctx: source range factory.
            expr: the ``ast`` expression node.
            allow_methods: when False (the default) a result that is a
                ``_MethodRef`` is an error; only ``build_Call`` passes True.

        Raises:
            FrontendTypeError: if an attribute access is used as a value.
        """
        result = super(ExprBuilder, self).__call__(ctx, expr)
        if type(result) is ExprBuilder._MethodRef and not allow_methods:
            err_range = ctx.make_raw_range(result.self.range().start, result.name.range().end)
            raise FrontendTypeError(err_range, "taking attributes/function values isn't supported")
        return result
|
|
|
|
|
|
# Module-level builder instances. They are called like functions,
# e.g. build_expr(ctx, node) and build_stmt(ctx, node).
build_expr = ExprBuilder()
build_stmt = StmtBuilder()
|
|
|
|
|
|
def find_after(ctx, pos, substr, offsets=(0, 0)):
    """Locate the first occurrence of ``substr`` at or after ``pos``.

    Args:
        ctx: object exposing the text as ``source`` and ``make_raw_range``.
        pos: position in ``ctx.source`` where the search starts.
        substr: text to look for.
        offsets: pair added to the start and the end of the match.

    Returns:
        ``ctx.make_raw_range(start, end)`` for the (shifted) match.

    Raises:
        ValueError: if ``substr`` doesn't occur.
    """
    relative_start = ctx.source[pos:].index(substr)
    match_start = pos + relative_start
    match_end = match_start + len(substr)
    return ctx.make_raw_range(match_start + offsets[0], match_end + offsets[1])
|
|
|
|
|
|
def find_before(ctx, pos, substr, offsets=(0, 0)):
    """Locate the last occurrence of ``substr`` that ends at or before ``pos``.

    Args:
        ctx: object exposing the text as ``source`` and ``make_raw_range``.
        pos: position in ``ctx.source`` where the searched prefix ends.
        substr: text to look for.
        offsets: pair added to the start and the end of the match.

    Returns:
        ``ctx.make_raw_range(start, end)`` for the (shifted) match.

    Raises:
        ValueError: if ``substr`` doesn't occur.
    """
    match_start = ctx.source.rindex(substr, 0, pos)
    match_end = match_start + len(substr)
    return ctx.make_raw_range(match_start + offsets[0], match_end + offsets[1])
|