Revert "Some minor type stub improvements (#118529)"

This reverts commit c978f38bd4aedeff4ee9ae693349217daea01412.

Reverted https://github.com/pytorch/pytorch/pull/118529 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/118529#issuecomment-1922362331))
This commit is contained in:
PyTorch MergeBot
2024-02-01 22:18:36 +00:00
parent d4a94ad041
commit dbba1d4bf5
9 changed files with 63 additions and 90 deletions

View File

@ -279,12 +279,6 @@ class SymInt:
def __ge__(self, other) -> builtins.bool:
raise AssertionError("type stub not overridden")
def __add__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __mul__(self, other) -> "SymInt":
raise AssertionError("type stub not overridden")
def __sym_max__(self, other):
raise AssertionError("type stub not overridden")

View File

@ -2066,18 +2066,18 @@ def defake(x):
size: "torch._prims_common.ShapeType"
stride: "torch._prims_common.StrideType"
if x._has_symbolic_sizes_strides:
size = []
for s in x.size():
if isinstance(s, torch.SymInt):
size.append(s.node.shape_env.size_hint(s.node.expr))
else:
size.append(s)
stride = []
for s in x.stride():
if isinstance(s, torch.SymInt):
stride.append(s.node.shape_env.size_hint(s.node.expr))
else:
stride.append(s)
size = [
s.node.shape_env.size_hint(s.node.expr)
if isinstance(s, torch.SymInt)
else s
for s in x.size()
]
stride = [
s.node.shape_env.size_hint(s.node.expr)
if isinstance(s, torch.SymInt)
else s
for s in x.stride()
]
else:
size = x.size()
stride = x.stride()

View File

@ -24,7 +24,7 @@ class BoundVars:
def __init__(self, loop_body: LoopBody) -> None:
self.loop_body = loop_body
self.replacement_vals = {
k: ValueRanges[Expr](0, v - 1)
k: ValueRanges(0, v - 1)
if (isinstance(v, int) or v.is_number)
else bound_sympy(v)
for k, v in loop_body.var_ranges.items()
@ -37,10 +37,10 @@ class BoundVars:
or "masked_subblock" in node.target
)
# To access this variable call `get_bounds()`
self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
self._bounds: Dict[torch.fx.Node, ValueRanges] = {}
@cache_on_self
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges]:
submodules = self.swap_submodules(self.loop_body.submodules)
# Initialize the environment with the unbounded variables
@ -50,7 +50,7 @@ class BoundVars:
"masked_subblock" not in node.target
and "set_indirect" not in node.target
):
self._bounds[node] = ValueRanges[Expr].unknown()
self._bounds[node] = ValueRanges.unknown()
with V.set_ops_handler(ValueRangeAnalysis()):
interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
@ -59,8 +59,8 @@ class BoundVars:
def swap_submodules(
self, submodules: Dict[str, Callable[..., Any]]
) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
) -> Dict[str, Callable[..., ValueRanges]]:
result: Dict[str, Callable[..., ValueRanges]] = {}
for key in submodules.keys():
if key == "get_index":
result[key] = self.get_index
@ -94,11 +94,11 @@ class BoundVars:
def masked_subblock(
self,
subblock: LoopBodyBlock,
env: Dict[torch.fx.Node, ValueRanges[Expr]],
env: Dict[torch.fx.Node, ValueRanges],
mask: Any,
value: Any,
submodules: Dict[str, Callable[..., Any]],
) -> ValueRanges[Expr]:
) -> ValueRanges:
interp = InterpreterShim(subblock.graph, submodules)
interp.run(V.get_ops_handler(), initial_env=env)
output = [node for node in subblock.graph.nodes if node.target == "output"]
@ -107,12 +107,12 @@ class BoundVars:
# pessimistically assumed to be inf anyway
return interp.env[output[0]]
def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
def set_indirect(self, old: Expr, new: ValueRanges) -> ValueRanges:
assert isinstance(new, ValueRanges)
self.replacement_vals[old] = new
return new
def get_index(self, name: Expr) -> ValueRanges[Expr]:
def get_index(self, name: Expr) -> ValueRanges:
expr = self.loop_body.indexing_exprs[name]
bound = self.replacement_vals.get(expr)
if bound is None:

View File

@ -797,7 +797,7 @@ class CSEVariable:
See example of TritonCSEVariable in triton.py
"""
def __init__(self, name, bounds: ValueRanges[Any]):
def __init__(self, name, bounds: ValueRanges):
assert isinstance(bounds, ValueRanges)
self.name = name
self.bounds = bounds
@ -876,7 +876,7 @@ class CSE:
buffer: IndentedBuffer,
expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
*,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
bounds: ValueRanges = ValueRanges.unknown(),
write=True,
assignment=True,
) -> CSEVariable:
@ -917,7 +917,7 @@ class CSE:
return var
def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable:
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
var = V.kernel.create_cse_var(var_name, bounds)
self.varname_map[var_name] = var
@ -1002,7 +1002,7 @@ class Kernel(CodeGen):
self._load_mask = None
# set in set_current_node
self.current_node = None
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None
# Upper bounds for indirect_indexing and their str representation
# NB: None, None is never stored in map, but it is the assumed
# "not set" value for the dict

View File

@ -7,7 +7,7 @@ import math
import re
import sys
from copy import copy, deepcopy
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import sympy
@ -568,7 +568,7 @@ def get_current_node_opt_ctx() -> OptimizationContext:
class CppCSEVariable(CSEVariable):
def __init__(self, name, bounds: ValueRanges[Any]):
def __init__(self, name, bounds: ValueRanges):
super().__init__(name, bounds)
self.is_vec = False
self.dtype: Optional[torch.dtype] = None

View File

@ -425,7 +425,7 @@ def triton_constant(value):
class TritonCSEVariable(CSEVariable):
def __init__(self, name, bounds: ValueRanges[Any]):
def __init__(self, name, bounds: ValueRanges):
super().__init__(name, bounds)
# We'll use this to track which masks the variable needs when used for indirect indexing
self.mask_vars: Set[str] = set()

View File

@ -71,7 +71,7 @@ def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_va
# TODO - not sure if we should be doing int/float casts while tracing,
# might interfere with sympy.
index_val_int = ValueRanges[sympy.Expr](
index_val_int = ValueRanges(
int(index_val.lower), int(index_val.upper)
)
if not range_expressable_in_32_bits(index_val_int):

View File

@ -93,7 +93,7 @@ CURRENT_NODE_KEY = "current_node"
# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement
@lru_cache(None)
def uninteresting_files() -> Set[str]:
def uninteresting_files():
import torch._inductor.sizevars
import torch._library.abstract_impl
import torch._subclasses.meta_utils
@ -120,18 +120,16 @@ def uninteresting_files() -> Set[str]:
class ConstraintViolationError(RuntimeError):
pass
def has_symbolic_sizes_strides(elem) -> bool:
def has_symbolic_sizes_strides(elem):
return elem._has_symbolic_sizes_strides
Int = Union[torch.SymInt, int]
def create_contiguous(shape: Sequence[Int]) -> List[Int]:
strides: List[Int] = [1]
def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
def hint_int(a, fallback=None):
"""
Retrieve the hint for an int (based on the underlying real values as observed
at runtime). If no hint is available (e.g., because data dependent shapes),
@ -142,14 +140,12 @@ def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int
assert type(a) is int, a
return a
Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
def has_hint(a: Scalar) -> bool:
def has_hint(a):
if isinstance(a, SymTypes):
return a.node.has_hint()
return True
def is_concrete_int(a: Union[int, SymInt]) -> bool:
def is_concrete_int(a: Union[int, SymInt]):
r""" Utility to check if underlying object
in SymInt is concrete value. Also returns
true if integer is passed in.
@ -167,9 +163,7 @@ def is_concrete_int(a: Union[int, SymInt]) -> bool:
return False
SympyBoolean = sympy.logic.boolalg.Boolean
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
def canonicalize_bool_expr(expr: sympy.Expr):
r""" Canonicalize a boolean expression by transforming it into a lt / le
inequality and moving all the non-constant terms to the rhs.
We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr
@ -192,7 +186,7 @@ def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
expr = sympy.logic.boolalg.to_cnf(expr)
return _canonicalize_bool_expr_impl(expr)
def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
def _canonicalize_bool_expr_impl(expr: sympy.Expr):
if isinstance(expr, (sympy.And, sympy.Or)):
return type(expr)(*map(canonicalize_bool_expr, expr.args))
@ -217,7 +211,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
rhs = -sympy.Add(*cts)
return t(lhs, rhs)
def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
def is_concrete_bool(a: Union[bool, SymBool]):
r""" Utility to check if underlying object
in SymBool is concrete value. Also returns
true if integer is passed in.
@ -234,7 +228,7 @@ def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
return False
def is_singleton(s: Int) -> bool:
def is_singleton(s):
# check for SingletonSymNode
if not isinstance(s, torch.SymInt):
return False

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import dataclasses
import itertools
import sympy
@ -8,8 +6,7 @@ import operator
import math
import logging
import torch
from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, cast, Union
from typing_extensions import TypeGuard
from typing import Union, Dict, Optional, SupportsFloat
from torch._prims_common import dtype_to_type
from .interp import sympy_interp
@ -19,8 +16,6 @@ log = logging.getLogger(__name__)
__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]
_T = TypeVar('_T', sympy.Expr, SympyBoolean)
class ValueRangeError(RuntimeError):
pass
@ -62,24 +57,16 @@ def sympy_generic_le(lower, upper):
return not (lower and not upper)
def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
return vr.is_bool
def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
return not vr.is_bool
@dataclasses.dataclass(frozen=True)
class ValueRanges(Generic[_T]):
class ValueRanges:
# Although the type signature here suggests you can pass any
# sympy expression, in practice the analysis here only works
# with constant sympy expressions
lower: _T
upper: _T
lower: Union[sympy.Expr, SympyBoolean]
upper: Union[sympy.Expr, SympyBoolean]
is_bool: bool
def __init__(self, lower: Union[_T, bool, int, float], upper: Union[_T, bool, int, float]) -> None:
def __init__(self, lower, upper):
lower = simple_sympify(lower)
upper = simple_sympify(upper)
# TODO: when the bounds have free variables, this may be
@ -104,47 +91,45 @@ class ValueRanges(Generic[_T]):
x = simple_sympify(x)
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
def tighten(self, other) -> ValueRanges:
def tighten(self, other) -> "ValueRanges":
"""Given two ValueRanges, returns their intersection"""
return self & other
# Intersection
def __and__(self: ValueRanges[_T], other: ValueRanges[_T]) -> ValueRanges[_T]:
def __and__(self, other) -> "ValueRanges":
if other == ValueRanges.unknown():
return self
if self == ValueRanges.unknown():
return other
assert self.is_bool == other.is_bool, (self, other)
if vr_is_bool(self):
return cast(ValueRanges[_T], ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)))
elif vr_is_expr(self):
return cast(ValueRanges[_T], ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)))
if self.is_bool:
range = ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper))
else:
raise AssertionError("impossible")
range = ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
return range
# Union
def __or__(self, other) -> ValueRanges:
def __or__(self, other) -> "ValueRanges":
if ValueRanges.unknown() in (self, other):
return ValueRanges.unknown()
assert self.is_bool == other.is_bool, (self, other)
if vr_is_bool(self):
return cast(ValueRanges[_T], ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)))
elif vr_is_expr(self):
return cast(ValueRanges[_T], ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)))
if self.is_bool:
range = ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
else:
raise AssertionError("impossible")
range = ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
return range
def is_singleton(self) -> bool:
return self.lower == self.upper
# TODO: this doesn't work with bools but arguably it should
@staticmethod
def unknown() -> ValueRanges[sympy.Expr]:
return ValueRanges(-sympy.oo, sympy.oo)
@classmethod
def unknown(cls):
return cls(-sympy.oo, sympy.oo)
@staticmethod
def unknown_bool() -> ValueRanges[SympyBoolean]:
return ValueRanges(sympy.false, sympy.true)
@classmethod
def unknown_bool(cls):
return cls(sympy.false, sympy.true)
@classmethod
def wrap(cls, arg):