mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Some minor type stub improvements (#118529)"
This reverts commit c978f38bd4aedeff4ee9ae693349217daea01412. Reverted https://github.com/pytorch/pytorch/pull/118529 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/118529#issuecomment-1922362331))
This commit is contained in:
@ -279,12 +279,6 @@ class SymInt:
|
||||
def __ge__(self, other) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __add__(self, other) -> "SymInt":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __mul__(self, other) -> "SymInt":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_max__(self, other):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
|
@ -2066,18 +2066,18 @@ def defake(x):
|
||||
size: "torch._prims_common.ShapeType"
|
||||
stride: "torch._prims_common.StrideType"
|
||||
if x._has_symbolic_sizes_strides:
|
||||
size = []
|
||||
for s in x.size():
|
||||
if isinstance(s, torch.SymInt):
|
||||
size.append(s.node.shape_env.size_hint(s.node.expr))
|
||||
else:
|
||||
size.append(s)
|
||||
stride = []
|
||||
for s in x.stride():
|
||||
if isinstance(s, torch.SymInt):
|
||||
stride.append(s.node.shape_env.size_hint(s.node.expr))
|
||||
else:
|
||||
stride.append(s)
|
||||
size = [
|
||||
s.node.shape_env.size_hint(s.node.expr)
|
||||
if isinstance(s, torch.SymInt)
|
||||
else s
|
||||
for s in x.size()
|
||||
]
|
||||
stride = [
|
||||
s.node.shape_env.size_hint(s.node.expr)
|
||||
if isinstance(s, torch.SymInt)
|
||||
else s
|
||||
for s in x.stride()
|
||||
]
|
||||
else:
|
||||
size = x.size()
|
||||
stride = x.stride()
|
||||
|
@ -24,7 +24,7 @@ class BoundVars:
|
||||
def __init__(self, loop_body: LoopBody) -> None:
|
||||
self.loop_body = loop_body
|
||||
self.replacement_vals = {
|
||||
k: ValueRanges[Expr](0, v - 1)
|
||||
k: ValueRanges(0, v - 1)
|
||||
if (isinstance(v, int) or v.is_number)
|
||||
else bound_sympy(v)
|
||||
for k, v in loop_body.var_ranges.items()
|
||||
@ -37,10 +37,10 @@ class BoundVars:
|
||||
or "masked_subblock" in node.target
|
||||
)
|
||||
# To access this variable call `get_bounds()`
|
||||
self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
|
||||
self._bounds: Dict[torch.fx.Node, ValueRanges] = {}
|
||||
|
||||
@cache_on_self
|
||||
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
|
||||
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges]:
|
||||
submodules = self.swap_submodules(self.loop_body.submodules)
|
||||
|
||||
# Initialize the environment with the unbounded variables
|
||||
@ -50,7 +50,7 @@ class BoundVars:
|
||||
"masked_subblock" not in node.target
|
||||
and "set_indirect" not in node.target
|
||||
):
|
||||
self._bounds[node] = ValueRanges[Expr].unknown()
|
||||
self._bounds[node] = ValueRanges.unknown()
|
||||
|
||||
with V.set_ops_handler(ValueRangeAnalysis()):
|
||||
interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
|
||||
@ -59,8 +59,8 @@ class BoundVars:
|
||||
|
||||
def swap_submodules(
|
||||
self, submodules: Dict[str, Callable[..., Any]]
|
||||
) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
|
||||
result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
|
||||
) -> Dict[str, Callable[..., ValueRanges]]:
|
||||
result: Dict[str, Callable[..., ValueRanges]] = {}
|
||||
for key in submodules.keys():
|
||||
if key == "get_index":
|
||||
result[key] = self.get_index
|
||||
@ -94,11 +94,11 @@ class BoundVars:
|
||||
def masked_subblock(
|
||||
self,
|
||||
subblock: LoopBodyBlock,
|
||||
env: Dict[torch.fx.Node, ValueRanges[Expr]],
|
||||
env: Dict[torch.fx.Node, ValueRanges],
|
||||
mask: Any,
|
||||
value: Any,
|
||||
submodules: Dict[str, Callable[..., Any]],
|
||||
) -> ValueRanges[Expr]:
|
||||
) -> ValueRanges:
|
||||
interp = InterpreterShim(subblock.graph, submodules)
|
||||
interp.run(V.get_ops_handler(), initial_env=env)
|
||||
output = [node for node in subblock.graph.nodes if node.target == "output"]
|
||||
@ -107,12 +107,12 @@ class BoundVars:
|
||||
# pessimistically assumed to be inf anyway
|
||||
return interp.env[output[0]]
|
||||
|
||||
def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
|
||||
def set_indirect(self, old: Expr, new: ValueRanges) -> ValueRanges:
|
||||
assert isinstance(new, ValueRanges)
|
||||
self.replacement_vals[old] = new
|
||||
return new
|
||||
|
||||
def get_index(self, name: Expr) -> ValueRanges[Expr]:
|
||||
def get_index(self, name: Expr) -> ValueRanges:
|
||||
expr = self.loop_body.indexing_exprs[name]
|
||||
bound = self.replacement_vals.get(expr)
|
||||
if bound is None:
|
||||
|
@ -797,7 +797,7 @@ class CSEVariable:
|
||||
See example of TritonCSEVariable in triton.py
|
||||
"""
|
||||
|
||||
def __init__(self, name, bounds: ValueRanges[Any]):
|
||||
def __init__(self, name, bounds: ValueRanges):
|
||||
assert isinstance(bounds, ValueRanges)
|
||||
self.name = name
|
||||
self.bounds = bounds
|
||||
@ -876,7 +876,7 @@ class CSE:
|
||||
buffer: IndentedBuffer,
|
||||
expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
|
||||
*,
|
||||
bounds: ValueRanges[Any] = ValueRanges.unknown(),
|
||||
bounds: ValueRanges = ValueRanges.unknown(),
|
||||
write=True,
|
||||
assignment=True,
|
||||
) -> CSEVariable:
|
||||
@ -917,7 +917,7 @@ class CSE:
|
||||
|
||||
return var
|
||||
|
||||
def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
|
||||
def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable:
|
||||
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
|
||||
var = V.kernel.create_cse_var(var_name, bounds)
|
||||
self.varname_map[var_name] = var
|
||||
@ -1002,7 +1002,7 @@ class Kernel(CodeGen):
|
||||
self._load_mask = None
|
||||
# set in set_current_node
|
||||
self.current_node = None
|
||||
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
|
||||
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None
|
||||
# Upper bounds for indirect_indexing and their str representation
|
||||
# NB: None, None is never stored in map, but it is the assumed
|
||||
# "not set" value for the dict
|
||||
|
@ -7,7 +7,7 @@ import math
|
||||
import re
|
||||
import sys
|
||||
from copy import copy, deepcopy
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import sympy
|
||||
|
||||
@ -568,7 +568,7 @@ def get_current_node_opt_ctx() -> OptimizationContext:
|
||||
|
||||
|
||||
class CppCSEVariable(CSEVariable):
|
||||
def __init__(self, name, bounds: ValueRanges[Any]):
|
||||
def __init__(self, name, bounds: ValueRanges):
|
||||
super().__init__(name, bounds)
|
||||
self.is_vec = False
|
||||
self.dtype: Optional[torch.dtype] = None
|
||||
|
@ -425,7 +425,7 @@ def triton_constant(value):
|
||||
|
||||
|
||||
class TritonCSEVariable(CSEVariable):
|
||||
def __init__(self, name, bounds: ValueRanges[Any]):
|
||||
def __init__(self, name, bounds: ValueRanges):
|
||||
super().__init__(name, bounds)
|
||||
# We'll use this to track which masks the variable needs when used for indirect indexing
|
||||
self.mask_vars: Set[str] = set()
|
||||
|
@ -71,7 +71,7 @@ def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_va
|
||||
# TODO - not sure if we should be doing int/float casts while tracing,
|
||||
# might interfere with sympy.
|
||||
|
||||
index_val_int = ValueRanges[sympy.Expr](
|
||||
index_val_int = ValueRanges(
|
||||
int(index_val.lower), int(index_val.upper)
|
||||
)
|
||||
if not range_expressable_in_32_bits(index_val_int):
|
||||
|
@ -93,7 +93,7 @@ CURRENT_NODE_KEY = "current_node"
|
||||
# These are modules that contain generic code for interacting with ShapeEnv
|
||||
# which are unlikely to identify a particular interesting guard statement
|
||||
@lru_cache(None)
|
||||
def uninteresting_files() -> Set[str]:
|
||||
def uninteresting_files():
|
||||
import torch._inductor.sizevars
|
||||
import torch._library.abstract_impl
|
||||
import torch._subclasses.meta_utils
|
||||
@ -120,18 +120,16 @@ def uninteresting_files() -> Set[str]:
|
||||
class ConstraintViolationError(RuntimeError):
|
||||
pass
|
||||
|
||||
def has_symbolic_sizes_strides(elem) -> bool:
|
||||
def has_symbolic_sizes_strides(elem):
|
||||
return elem._has_symbolic_sizes_strides
|
||||
|
||||
Int = Union[torch.SymInt, int]
|
||||
|
||||
def create_contiguous(shape: Sequence[Int]) -> List[Int]:
|
||||
strides: List[Int] = [1]
|
||||
def create_contiguous(shape):
|
||||
strides = [1]
|
||||
for dim in reversed(shape[:-1]):
|
||||
strides.append(dim * strides[-1])
|
||||
return list(reversed(strides))
|
||||
|
||||
def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
|
||||
def hint_int(a, fallback=None):
|
||||
"""
|
||||
Retrieve the hint for an int (based on the underlying real values as observed
|
||||
at runtime). If no hint is available (e.g., because data dependent shapes),
|
||||
@ -142,14 +140,12 @@ def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int
|
||||
assert type(a) is int, a
|
||||
return a
|
||||
|
||||
Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
|
||||
|
||||
def has_hint(a: Scalar) -> bool:
|
||||
def has_hint(a):
|
||||
if isinstance(a, SymTypes):
|
||||
return a.node.has_hint()
|
||||
return True
|
||||
|
||||
def is_concrete_int(a: Union[int, SymInt]) -> bool:
|
||||
def is_concrete_int(a: Union[int, SymInt]):
|
||||
r""" Utility to check if underlying object
|
||||
in SymInt is concrete value. Also returns
|
||||
true if integer is passed in.
|
||||
@ -167,9 +163,7 @@ def is_concrete_int(a: Union[int, SymInt]) -> bool:
|
||||
|
||||
return False
|
||||
|
||||
SympyBoolean = sympy.logic.boolalg.Boolean
|
||||
|
||||
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
|
||||
def canonicalize_bool_expr(expr: sympy.Expr):
|
||||
r""" Canonicalize a boolean expression by transforming it into a lt / le
|
||||
inequality and moving all the non-constant terms to the rhs.
|
||||
We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr
|
||||
@ -192,7 +186,7 @@ def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
|
||||
expr = sympy.logic.boolalg.to_cnf(expr)
|
||||
return _canonicalize_bool_expr_impl(expr)
|
||||
|
||||
def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
|
||||
def _canonicalize_bool_expr_impl(expr: sympy.Expr):
|
||||
if isinstance(expr, (sympy.And, sympy.Or)):
|
||||
return type(expr)(*map(canonicalize_bool_expr, expr.args))
|
||||
|
||||
@ -217,7 +211,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
|
||||
rhs = -sympy.Add(*cts)
|
||||
return t(lhs, rhs)
|
||||
|
||||
def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
|
||||
def is_concrete_bool(a: Union[bool, SymBool]):
|
||||
r""" Utility to check if underlying object
|
||||
in SymBool is concrete value. Also returns
|
||||
true if integer is passed in.
|
||||
@ -234,7 +228,7 @@ def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
|
||||
|
||||
return False
|
||||
|
||||
def is_singleton(s: Int) -> bool:
|
||||
def is_singleton(s):
|
||||
# check for SingletonSymNode
|
||||
if not isinstance(s, torch.SymInt):
|
||||
return False
|
||||
|
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import sympy
|
||||
@ -8,8 +6,7 @@ import operator
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, cast, Union
|
||||
from typing_extensions import TypeGuard
|
||||
from typing import Union, Dict, Optional, SupportsFloat
|
||||
|
||||
from torch._prims_common import dtype_to_type
|
||||
from .interp import sympy_interp
|
||||
@ -19,8 +16,6 @@ log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]
|
||||
|
||||
_T = TypeVar('_T', sympy.Expr, SympyBoolean)
|
||||
|
||||
class ValueRangeError(RuntimeError):
|
||||
pass
|
||||
|
||||
@ -62,24 +57,16 @@ def sympy_generic_le(lower, upper):
|
||||
return not (lower and not upper)
|
||||
|
||||
|
||||
def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
|
||||
return vr.is_bool
|
||||
|
||||
|
||||
def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
|
||||
return not vr.is_bool
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ValueRanges(Generic[_T]):
|
||||
class ValueRanges:
|
||||
# Although the type signature here suggests you can pass any
|
||||
# sympy expression, in practice the analysis here only works
|
||||
# with constant sympy expressions
|
||||
lower: _T
|
||||
upper: _T
|
||||
lower: Union[sympy.Expr, SympyBoolean]
|
||||
upper: Union[sympy.Expr, SympyBoolean]
|
||||
is_bool: bool
|
||||
|
||||
def __init__(self, lower: Union[_T, bool, int, float], upper: Union[_T, bool, int, float]) -> None:
|
||||
def __init__(self, lower, upper):
|
||||
lower = simple_sympify(lower)
|
||||
upper = simple_sympify(upper)
|
||||
# TODO: when the bounds have free variables, this may be
|
||||
@ -104,47 +91,45 @@ class ValueRanges(Generic[_T]):
|
||||
x = simple_sympify(x)
|
||||
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
|
||||
|
||||
def tighten(self, other) -> ValueRanges:
|
||||
def tighten(self, other) -> "ValueRanges":
|
||||
"""Given two ValueRanges, returns their intersection"""
|
||||
return self & other
|
||||
|
||||
# Intersection
|
||||
def __and__(self: ValueRanges[_T], other: ValueRanges[_T]) -> ValueRanges[_T]:
|
||||
def __and__(self, other) -> "ValueRanges":
|
||||
if other == ValueRanges.unknown():
|
||||
return self
|
||||
if self == ValueRanges.unknown():
|
||||
return other
|
||||
assert self.is_bool == other.is_bool, (self, other)
|
||||
if vr_is_bool(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)))
|
||||
elif vr_is_expr(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)))
|
||||
if self.is_bool:
|
||||
range = ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper))
|
||||
else:
|
||||
raise AssertionError("impossible")
|
||||
range = ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))
|
||||
return range
|
||||
|
||||
# Union
|
||||
def __or__(self, other) -> ValueRanges:
|
||||
def __or__(self, other) -> "ValueRanges":
|
||||
if ValueRanges.unknown() in (self, other):
|
||||
return ValueRanges.unknown()
|
||||
assert self.is_bool == other.is_bool, (self, other)
|
||||
if vr_is_bool(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)))
|
||||
elif vr_is_expr(self):
|
||||
return cast(ValueRanges[_T], ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)))
|
||||
if self.is_bool:
|
||||
range = ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))
|
||||
else:
|
||||
raise AssertionError("impossible")
|
||||
range = ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))
|
||||
return range
|
||||
|
||||
def is_singleton(self) -> bool:
|
||||
return self.lower == self.upper
|
||||
|
||||
# TODO: this doesn't work with bools but arguably it should
|
||||
@staticmethod
|
||||
def unknown() -> ValueRanges[sympy.Expr]:
|
||||
return ValueRanges(-sympy.oo, sympy.oo)
|
||||
@classmethod
|
||||
def unknown(cls):
|
||||
return cls(-sympy.oo, sympy.oo)
|
||||
|
||||
@staticmethod
|
||||
def unknown_bool() -> ValueRanges[SympyBoolean]:
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
@classmethod
|
||||
def unknown_bool(cls):
|
||||
return cls(sympy.false, sympy.true)
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, arg):
|
||||
|
Reference in New Issue
Block a user