From 9642c7568967ab424c5d0e04ef2cd1e82a54b5f8 Mon Sep 17 00:00:00 2001 From: morrison-turnansky Date: Wed, 25 Jun 2025 06:15:13 +0000 Subject: [PATCH] added stubs for jit tree views (#156504) Fixes #156488 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156504 Approved by: https://github.com/ezyang --- mypy.ini | 3 - torch/_C/_jit_tree_views.pyi | 202 +++++++++++++++++++++++++++++++++++ torch/jit/frontend.py | 10 +- 3 files changed, 209 insertions(+), 6 deletions(-) create mode 100644 torch/_C/_jit_tree_views.pyi diff --git a/mypy.ini b/mypy.ini index 1a32b414820f..e6a8af4c88c2 100644 --- a/mypy.ini +++ b/mypy.ini @@ -55,9 +55,6 @@ python_version = 3.11 # Extension modules without stubs. # -[mypy-torch._C._jit_tree_views] -ignore_missing_imports = True - [mypy-torch.for_onnx.onnx] ignore_missing_imports = True diff --git a/torch/_C/_jit_tree_views.pyi b/torch/_C/_jit_tree_views.pyi new file mode 100644 index 000000000000..cf4cffc05a9c --- /dev/null +++ b/torch/_C/_jit_tree_views.pyi @@ -0,0 +1,202 @@ +from typing import Any, Optional + +# Defined in torch/csrc/jit/python/python_tree_views.cpp + +class SourceRange: + def highlight(self) -> str: ... + @property + def start(self) -> int: ... + @property + def end(self) -> int: ... + +class SourceRangeFactory: + def __init__( + self, + text: str, + filename: Any, + file_lineno: int, + leading_whitespace_chars: int, + ) -> None: ... + def make_range(self, line: int, start_col: int, end_col: int) -> SourceRange: ... + def make_raw_range(self, start: int, end: int) -> SourceRange: ... + @property + def source(self) -> str: ... + +class TreeView: + def range(self) -> SourceRange: ... + def dump(self) -> None: ... + +class Ident(TreeView): + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + @property + def name(self) -> str: ... + +class Param(TreeView): + def __init__(self, type: Optional[Any], name: Ident, kwarg_only: bool) -> None: ... + +class Attribute(TreeView): + def __init__(self, name: Ident, value: Any) -> None: ... + +# Literals +def TrueLiteral(range: SourceRange) -> Any: ... +def FalseLiteral(range: SourceRange) -> Any: ... +def NoneLiteral(range: SourceRange) -> Any: ... + +# Tree nodes +class Stmt(TreeView): + def __init__(self, thing: TreeView) -> None: ... + +class Expr(TreeView): ... + +class Def(TreeView): + def __init__(self, name: Ident, decl: Any, body: list[Stmt]) -> None: ... + def decl(self) -> Any: ... + def name(self) -> Ident: ... + +class Property(TreeView): + def __init__( + self, r: SourceRange, name: Ident, getter: Def, setter: Optional[Def] + ) -> None: ... + def name(self) -> Ident: ... + def getter_name(self) -> str: ... + def setter_name(self) -> Optional[Ident]: ... + +class ClassDef(TreeView): + def __init__( + self, name: Ident, body: list[Stmt], props: list[Property], assigns: list[Any] + ) -> None: ... + +class Decl(TreeView): + def __init__( + self, r: SourceRange, params: list[Param], return_type: Optional[Expr] + ) -> None: ... + +class Delete(Stmt): + def __init__(self, range: SourceRange, targets: list[Expr]) -> None: ... + +class WithItem(Expr): + def __init__( + self, range: SourceRange, target: Expr, var: Optional[Any] + ) -> None: ... + +class Assign(Stmt): + def __init__( + self, lhs: list[Expr], rhs: Expr, type: Optional[Expr] = None + ) -> None: ... + +class AugAssign(Stmt): + def __init__(self, lhs: Expr, kind_str: str, rhs: Expr) -> None: ... + +class Return(Stmt): + def __init__(self, range: SourceRange, value: Optional[Expr]) -> None: ... + +class Raise(Stmt): + def __init__(self, range: SourceRange, expr: Expr) -> None: ... + +class Assert(Stmt): + def __init__(self, range: SourceRange, test: Expr, msg: Optional[Expr]) -> None: ... + +class Pass(Stmt): + def __init__(self, range: SourceRange) -> None: ... + +class Break(Stmt): ... +class Continue(Stmt): ... + +class Dots(Expr, TreeView): + def __init__(self, range: SourceRange) -> None: ... + +class If(Stmt): + def __init__( + self, + range: SourceRange, + cond: Expr, + true_branch: list[Stmt], + false_branch: list[Stmt], + ) -> None: ... + +class While(Stmt): + def __init__(self, range: SourceRange, cond: Expr, body: list[Stmt]) -> None: ... + +class With(Stmt): + def __init__( + self, range: SourceRange, targets: list[WithItem], body: list[Stmt] + ) -> None: ... + +class For(Stmt): + def __init__( + self, + range: SourceRange, + targets: list[Expr], + itrs: list[Expr], + body: list[Stmt], + ) -> None: ... + +class ExprStmt(Stmt): + def __init__(self, expr: Expr) -> None: ... + +class Var(Expr): + def __init__(self, name: Ident) -> None: ... + @property + def name(self) -> str: ... + +class BinOp(Expr): + def __init__(self, kind: str, lhs: Expr, rhs: Expr) -> None: ... + +class UnaryOp(Expr): + def __init__(self, range: SourceRange, kind: str, expr: Expr) -> None: ... + +class Const(Expr): + def __init__(self, range: SourceRange, value: str) -> None: ... + +class StringLiteral(Expr): + def __init__(self, range: SourceRange, value: str) -> None: ... + +class Apply(Expr): + def __init__( + self, expr: Expr, args: list[Expr], kwargs: list[Attribute] + ) -> None: ... + +class Select(Expr): + def __init__(self, expr: Expr, field: Ident) -> None: ... + +class TernaryIf(Expr): + def __init__(self, cond: Expr, true_expr: Expr, false_expr: Expr) -> None: ... + +class ListComp(Expr): + def __init__( + self, range: SourceRange, elt: Expr, target: Expr, iter: Expr + ) -> None: ... + +class DictComp(Expr): + def __init__( + self, range: SourceRange, key: Expr, value: Expr, target: Expr, iter: Expr + ) -> None: ... + +class ListLiteral(Expr): + def __init__(self, range: SourceRange, args: list[Expr]) -> None: ... + +class TupleLiteral(Expr): + def __init__(self, range: SourceRange, args: list[Expr]) -> None: ... + +class DictLiteral(Expr): + def __init__( + self, range: SourceRange, keys: list[Expr], values: list[Expr] + ) -> None: ... + +class Subscript(Expr): + def __init__(self, base: Expr, subscript_exprs: list[Expr]) -> None: ... + +class SliceExpr(Expr): + def __init__( + self, + range: SourceRange, + lower: Optional[Expr], + upper: Optional[Expr], + step: Optional[Expr], + ) -> None: ... + +class Starred(Expr): + def __init__(self, range: SourceRange, expr: Expr) -> None: ... + +class EmptyTypeAnnotation(TreeView): + def __init__(self, range: SourceRange) -> None: ... diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 76682e752299..ce6ab33db2a1 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -438,7 +438,11 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=No is_method = self_name is not None if type_line is not None: type_comment_decl = torch._C.parse_type_comment(type_line) - decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method) + decl = torch._C.merge_type_from_type_comment( + decl, # type: ignore[arg-type] + type_comment_decl, + is_method, # type: ignore[assignment] + ) return Def(Ident(r, def_name), decl, build_stmts(ctx, body)) @@ -1055,12 +1059,12 @@ class ExprBuilder(Builder): in_expr = BinOp("in", lhs, rhs) cmp_expr = UnaryOp(r, "not", in_expr) else: - cmp_expr = BinOp(op_token, lhs, rhs) + cmp_expr = BinOp(op_token, lhs, rhs) # type: ignore[assignment] if result is None: result = cmp_expr else: - result = BinOp("and", result, cmp_expr) + result = BinOp("and", result, cmp_expr) # type: ignore[assignment] return result @staticmethod