[inductor] Clean typing in codegen/common.py and codecache.py (#150767)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150767
Approved by: https://github.com/aorenste
Committed by: PyTorch MergeBot
Parent: 27f7b65a69
Commit: 8568dbce1d
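The diff below is a mechanical typing cleanup rather than a behavioral change: the now-unneeded `import typing` is dropped in favor of `Self` from `typing_extensions`, bare `assert isinstance(...)` statements gain the offending object's type (or value) as the assertion message, and several `# type: ignore` comments become unnecessary once annotations are made precise (e.g. `OrderedSet[str]()` instead of an unparametrized `OrderedSet()`). A minimal, stand-alone sketch of the two recurring patterns, using hypothetical `Registry`/`register` names that are not part of the PyTorch sources:

from typing_extensions import Self  # backport of typing.Self for Python < 3.11


class Registry:
    def __init__(self) -> None:
        self.entries: dict[str, int] = {}

    def clone(self) -> Self:
        # Annotating the return type with Self keeps it correct for subclasses
        # without string forward references or TypeVar boilerplate.
        new = type(self)()
        new.entries = dict(self.entries)
        return new

    def register(self, name: str) -> None:
        # Putting type(name) in the assert message means a failure reports what
        # was actually passed (e.g. "AssertionError: <class 'bytes'>") instead
        # of a bare AssertionError with no context.
        assert isinstance(name, str), type(name)
        self.entries[name] = len(self.entries)

Presumably the motivation is that `typing.Self` only exists on Python 3.11+, while the enriched assert messages make failures in these checks easier to diagnose from CI logs.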
@@ -12,7 +12,6 @@ import operator
 import os
 import re
 import tempfile
-import typing
 from abc import ABC, abstractmethod
 from enum import auto, Enum
 from itertools import chain
@@ -27,7 +26,7 @@ from typing import (
     TYPE_CHECKING,
     Union,
 )
-from typing_extensions import TypeVar
+from typing_extensions import Self, TypeVar

 import sympy

@@ -408,7 +407,7 @@ def get_backend_features(
     if isinstance(device, torch.device):
         device_type = device.type
     else:
-        assert isinstance(device, str)
+        assert isinstance(device, str), type(device)
         device_type = device
         device = torch.device(device_type)
     scheduling_ctor = get_scheduling_for_device(device_type)
@@ -538,7 +537,7 @@ def register_device_op_overrides(


 def get_device_op_overrides(device: str) -> DeviceOpOverrides:
-    assert isinstance(device, str)
+    assert isinstance(device, str), type(device)

     if not device_op_overrides_dict:
         from . import cpu_device_op_overrides, mps_device_op_overrides  # noqa: F401
@@ -621,7 +620,7 @@ def check_dtype(
     elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
         from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP

-        assert isinstance(var, CppCSEVariable)
+        assert isinstance(var, CppCSEVariable), type(var)
         if dtype == torch.bool:
             if var.is_vec:
                 is_same_dt = f"IsVecMaskType<decltype({var})>::value"
@@ -682,9 +681,11 @@ class DataTypePropagation:
             return None

         if node.target == operator.getitem:
-            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]
+            node_arg = node.args[0]
+            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
+            return self.deduce_node_dtype(node_arg)

-        assert isinstance(node.target, str)
+        assert isinstance(node.target, str), type(node.target)

         if node.target.startswith("masked_subblock"):
             return self.deduce_node_dtype_by_subgraph(node)
@@ -730,8 +731,8 @@ class DataTypePropagation:
         from ..loop_body import LoopBody
         from ..scheduler import SchedulerNode

-        assert isinstance(node, SchedulerNode)
-        assert isinstance(node._body, LoopBody)
+        assert isinstance(node, SchedulerNode), type(node)
+        assert isinstance(node._body, LoopBody), type(node._body)
         return DataTypePropagation.propagate_loopbody(node._body)


@@ -1428,7 +1429,7 @@ class KernelArgs:
     def make_inplace(self, input_name: str, output_name: str) -> None:
         if input_name in V.graph.unaligned_buffers:
             V.graph.unaligned_buffers.add(output_name)
-        assert output_name not in self.inplace_buffers
+        assert output_name not in self.inplace_buffers, output_name
         if input_name in self.inplace_buffers:
             buf = self.inplace_buffers[input_name]
             assert not isinstance(buf, RemovedArg)
@@ -1490,7 +1491,7 @@ class KernelArgs:
             assert (
                 existing_arg.inner_name != arg.inner_name
                 and existing_arg.outer_name != arg.outer_name
-            )
+            ), existing_arg
         self.workspace_args.append(arg)
         return arg.inner_name, 0

@@ -1518,7 +1519,7 @@ class KernelArgs:
         )
         for existing_arg in self.workspace_args:
             if existing_arg.inner_name == arg.inner_name:
-                assert arg == existing_arg
+                assert arg == existing_arg, (arg, existing_arg)
         self.workspace_args.append(arg)
         return arg.inner_name

@@ -1618,7 +1619,7 @@ class KernelArgs:
     ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
         arg_defs: list[ArgName] = []
         call_args: list[str] = []
-        arg_types: list[torch.dtype] = []
+        arg_types: list[Any] = []
         precompile_args: list[KernelArgType] = []
         for inplaced in unique(self.inplace_buffers.values()):
             if isinstance(inplaced, RemovedArg):
@@ -1651,7 +1652,7 @@ class KernelArgs:
         for outer, inner in self.sizevars.items():
             arg_defs.append(ArgName(inner))
             call_args.append(outer)
-            arg_types.append(type(outer))  # type: ignore[arg-type]
+            arg_types.append(type(outer))
             precompile_args.append(SizeArg(inner, outer))
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
@@ -1686,7 +1687,7 @@ class KernelArgs:
     # after you do a call into this kernel, which buffers actually contain
    # updated data? Modeled off of python_argdefs.
     def live_output_buffers(self) -> OrderedSet[str]:
-        live_outs = OrderedSet()  # type: ignore[var-annotated]
+        live_outs = OrderedSet[str]()
         for inplaced in unique(self.inplace_buffers.values()):
             if isinstance(inplaced, RemovedArg):
                 continue
@@ -1712,7 +1713,7 @@ class CSEVariable:
         dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
-        assert isinstance(bounds, ValueRanges)
+        assert isinstance(bounds, ValueRanges), type(bounds)
         self.name = name
         self.bounds = bounds
         self.use_count = 1  # track how many times this expression is used
@@ -1782,7 +1783,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
         else:
             self._cache = {}

-    def clone(self) -> typing.Self:
+    def clone(self) -> Self:
         return type(self)(
             prefix=self.prefix,
             suffix=self.suffix,
@@ -1793,7 +1794,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
             reduction_cache=self.reduction_cache,
         )

-    def scoped_copy(self) -> typing.Self:
+    def scoped_copy(self) -> Self:
         """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
         new_cse = self.clone()
         new_cse._cache = ScopedDict(self._cache)
@@ -1918,7 +1919,7 @@ class CodeGen:
         super().__init__()
         self.exit_stack = contextlib.ExitStack()

-    def __enter__(self) -> typing.Self:
+    def __enter__(self) -> Self:
         self.exit_stack.__enter__()
         return self

@@ -2084,7 +2085,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> str:
         if isinstance(var, CSEVariable):
             var = str(var)
-        assert isinstance(var, str)
+        assert isinstance(var, str), type(var)
         assert lower is None or isinstance(lower, str)
         assert upper is None or isinstance(upper, str)
         if lower and upper:
@@ -2113,7 +2114,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     def index_to_str(self, index: sympy.Expr) -> str:
         raise NotImplementedError

-    def __enter__(self) -> typing.Self:
+    def __enter__(self) -> Self:
         super().__enter__()
         assert self.overrides
         self.exit_stack.enter_context(
@@ -2184,7 +2185,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         # adds the necessary kernel args for index expressions
         # and renames variables in index expressions to kernel arg names
         if isinstance(index, (list, tuple)):
-            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
+            return [self.rename_indexing(x) for x in index]
         index = V.graph.sizevars.simplify(index)
         sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
         replacements = {
@@ -2362,7 +2363,7 @@ class CSEProxy(DefaultHandler):
     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
         bounds = self._bound_variable(name, *args, **kwargs)

-        value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
+        value = getattr(self.parent_handler, name)(*args, **kwargs)
         dtype_handler = DtypePropagationOpsHandler()

         backend = get_current_backend()
@@ -2387,8 +2388,8 @@ class CSEProxy(DefaultHandler):
         def do_cse(v: str) -> CSEVariable:
             # we tree_map over the output, so we need to fetch corresponding dtype
             nonlocal output_idx
-            var_dtype: torch.dtype = (
-                output_dtype[output_idx]  # type: ignore[assignment]
+            var_dtype: Optional[torch.dtype] = (
+                output_dtype[output_idx]
                 if isinstance(output_dtype, (list, tuple))
                 else output_dtype
             )
@@ -2411,6 +2412,7 @@ class CSEProxy(DefaultHandler):
                 config.test_configs.runtime_triton_dtype_assert
                 or config.test_configs.static_cpp_dtype_assert
             ):
+                assert var_dtype is not None
                 check_dtype(V.kernel.compute, csevar, var_dtype)
             return csevar

@@ -2433,7 +2435,9 @@ class CSEProxy(DefaultHandler):

         fx_node = V.interpreter.current_node
         if fx_node.target == name and self.kernel.node_to_bounds is not None:
-            assert isinstance(self.kernel.node_to_bounds, dict)
+            assert isinstance(self.kernel.node_to_bounds, dict), type(
+                self.kernel.node_to_bounds
+            )
             return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
         elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
             # These create lots of inner strings. We would need to compute the bounds at the ops
@@ -2468,14 +2472,14 @@ class CSEProxy(DefaultHandler):
     ) -> sympy.Symbol:
         if isinstance(size, int):
             size = sympy.Integer(size)
-        assert isinstance(size, sympy.Expr), size
+        assert isinstance(size, sympy.Expr), (type(size), size)
         # Skip CSE since this doesn't return an expression

-        if var.bounds.lower < 0:  # type: ignore[operator]
+        if var.bounds.lower < 0:
             if wrap_neg:
                 stm = ops.add(var, ops.index_expr(size, torch.long))
                 # Mixed negative and non-negative
-                if var.bounds.upper >= 0:  # type: ignore[operator]
+                if var.bounds.upper >= 0:
                     lt = ops.lt(var, 0)
                     stm = ops.where(lt, stm, var)
             else:
@@ -2492,7 +2496,7 @@ class CSEProxy(DefaultHandler):
                 neg_bounds.lower + size, neg_bounds.upper + size
             )
             # We don't have a good way of representing the empty range
-            if var.bounds.upper >= 0:  # type: ignore[operator]
+            if var.bounds.upper >= 0:
                 pos = var.bounds & ValueRanges(0, int_oo)
                 new_bounds = new_bounds | pos

@@ -2544,8 +2548,7 @@ class CSEProxy(DefaultHandler):
         if mode is None:
             self._update_store_cache(name, value)
         if name not in V.graph.removed_buffers:
-            return self.kernel.store(name, index, value, mode=mode)
-        return None  # type: ignore[return-value]
+            self.kernel.store(name, index, value, mode=mode)

     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         self.kernel.store_buffer_names.add(name)