added stubs for jit tree views (#156504)

Fixes #156488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156504
Approved by: https://github.com/ezyang
This commit is contained in:
morrison-turnansky
2025-06-25 06:15:13 +00:00
committed by PyTorch MergeBot
parent c60327ba74
commit 9642c75689
3 changed files with 209 additions and 6 deletions

View File

@ -55,9 +55,6 @@ python_version = 3.11
# Extension modules without stubs.
#
[mypy-torch._C._jit_tree_views]
ignore_missing_imports = True
[mypy-torch.for_onnx.onnx]
ignore_missing_imports = True

View File

@ -0,0 +1,202 @@
from typing import Any, Optional
# Defined in torch/csrc/jit/python/python_tree_views.cpp
class SourceRange:
def highlight(self) -> str: ...
@property
def start(self) -> int: ...
@property
def end(self) -> int: ...
class SourceRangeFactory:
def __init__(
self,
text: str,
filename: Any,
file_lineno: int,
leading_whitespace_chars: int,
) -> None: ...
def make_range(self, line: int, start_col: int, end_col: int) -> SourceRange: ...
def make_raw_range(self, start: int, end: int) -> SourceRange: ...
@property
def source(self) -> str: ...
class TreeView:
def range(self) -> SourceRange: ...
def dump(self) -> None: ...
class Ident(TreeView):
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
@property
def name(self) -> str: ...
class Param(TreeView):
def __init__(self, type: Optional[Any], name: Ident, kwarg_only: bool) -> None: ...
class Attribute(TreeView):
def __init__(self, name: Ident, value: Any) -> None: ...
# Literals
def TrueLiteral(range: SourceRange) -> Any: ...
def FalseLiteral(range: SourceRange) -> Any: ...
def NoneLiteral(range: SourceRange) -> Any: ...
# Tree nodes
class Stmt(TreeView):
def __init__(self, thing: TreeView) -> None: ...
class Expr(TreeView): ...
class Def(TreeView):
def __init__(self, name: Ident, decl: Any, body: list[Stmt]) -> None: ...
def decl(self) -> Any: ...
def name(self) -> Ident: ...
class Property(TreeView):
def __init__(
self, r: SourceRange, name: Ident, getter: Def, setter: Optional[Def]
) -> None: ...
def name(self) -> Ident: ...
def getter_name(self) -> str: ...
def setter_name(self) -> Optional[Ident]: ...
class ClassDef(TreeView):
def __init__(
self, name: Ident, body: list[Stmt], props: list[Property], assigns: list[Any]
) -> None: ...
class Decl(TreeView):
def __init__(
self, r: SourceRange, params: list[Param], return_type: Optional[Expr]
) -> None: ...
class Delete(Stmt):
def __init__(self, range: SourceRange, targets: list[Expr]) -> None: ...
class WithItem(Expr):
def __init__(
self, range: SourceRange, target: Expr, var: Optional[Any]
) -> None: ...
class Assign(Stmt):
def __init__(
self, lhs: list[Expr], rhs: Expr, type: Optional[Expr] = None
) -> None: ...
class AugAssign(Stmt):
def __init__(self, lhs: Expr, kind_str: str, rhs: Expr) -> None: ...
class Return(Stmt):
def __init__(self, range: SourceRange, value: Optional[Expr]) -> None: ...
class Raise(Stmt):
def __init__(self, range: SourceRange, expr: Expr) -> None: ...
class Assert(Stmt):
def __init__(self, range: SourceRange, test: Expr, msg: Optional[Expr]) -> None: ...
class Pass(Stmt):
def __init__(self, range: SourceRange) -> None: ...
class Break(Stmt): ...
class Continue(Stmt): ...
class Dots(Expr, TreeView):
def __init__(self, range: SourceRange) -> None: ...
class If(Stmt):
def __init__(
self,
range: SourceRange,
cond: Expr,
true_branch: list[Stmt],
false_branch: list[Stmt],
) -> None: ...
class While(Stmt):
def __init__(self, range: SourceRange, cond: Expr, body: list[Stmt]) -> None: ...
class With(Stmt):
def __init__(
self, range: SourceRange, targets: list[WithItem], body: list[Stmt]
) -> None: ...
class For(Stmt):
def __init__(
self,
range: SourceRange,
targets: list[Expr],
itrs: list[Expr],
body: list[Stmt],
) -> None: ...
class ExprStmt(Stmt):
def __init__(self, expr: Expr) -> None: ...
class Var(Expr):
def __init__(self, name: Ident) -> None: ...
@property
def name(self) -> str: ...
class BinOp(Expr):
def __init__(self, kind: str, lhs: Expr, rhs: Expr) -> None: ...
class UnaryOp(Expr):
def __init__(self, range: SourceRange, kind: str, expr: Expr) -> None: ...
class Const(Expr):
def __init__(self, range: SourceRange, value: str) -> None: ...
class StringLiteral(Expr):
def __init__(self, range: SourceRange, value: str) -> None: ...
class Apply(Expr):
def __init__(
self, expr: Expr, args: list[Expr], kwargs: list[Attribute]
) -> None: ...
class Select(Expr):
def __init__(self, expr: Expr, field: Ident) -> None: ...
class TernaryIf(Expr):
def __init__(self, cond: Expr, true_expr: Expr, false_expr: Expr) -> None: ...
class ListComp(Expr):
def __init__(
self, range: SourceRange, elt: Expr, target: Expr, iter: Expr
) -> None: ...
class DictComp(Expr):
def __init__(
self, range: SourceRange, key: Expr, value: Expr, target: Expr, iter: Expr
) -> None: ...
class ListLiteral(Expr):
def __init__(self, range: SourceRange, args: list[Expr]) -> None: ...
class TupleLiteral(Expr):
def __init__(self, range: SourceRange, args: list[Expr]) -> None: ...
class DictLiteral(Expr):
def __init__(
self, range: SourceRange, keys: list[Expr], values: list[Expr]
) -> None: ...
class Subscript(Expr):
def __init__(self, base: Expr, subscript_exprs: list[Expr]) -> None: ...
class SliceExpr(Expr):
def __init__(
self,
range: SourceRange,
lower: Optional[Expr],
upper: Optional[Expr],
step: Optional[Expr],
) -> None: ...
class Starred(Expr):
def __init__(self, range: SourceRange, expr: Expr) -> None: ...
class EmptyTypeAnnotation(TreeView):
def __init__(self, range: SourceRange) -> None: ...

View File

@ -438,7 +438,11 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=No
is_method = self_name is not None
if type_line is not None:
type_comment_decl = torch._C.parse_type_comment(type_line)
decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
decl = torch._C.merge_type_from_type_comment(
decl, # type: ignore[arg-type]
type_comment_decl,
is_method, # type: ignore[assignment]
)
return Def(Ident(r, def_name), decl, build_stmts(ctx, body))
@ -1055,12 +1059,12 @@ class ExprBuilder(Builder):
in_expr = BinOp("in", lhs, rhs)
cmp_expr = UnaryOp(r, "not", in_expr)
else:
cmp_expr = BinOp(op_token, lhs, rhs)
cmp_expr = BinOp(op_token, lhs, rhs) # type: ignore[assignment]
if result is None:
result = cmp_expr
else:
result = BinOp("and", result, cmp_expr)
result = BinOp("and", result, cmp_expr) # type: ignore[assignment]
return result
@staticmethod