[inductor] Clean typing in codegen/common.py and codecache.py (#150767)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150767
Approved by: https://github.com/aorenste
Tom Ritchford
2025-05-16 11:08:00 +00:00
committed by PyTorch MergeBot
parent 27f7b65a69
commit 8568dbce1d
2 changed files with 46 additions and 39 deletions

torch/_inductor/codecache.py

@ -112,16 +112,16 @@ if config.is_fbcode():
)
else:
def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
pass
def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
pass
def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
pass
def use_global_cache() -> bool: # type: ignore[misc]
def use_global_cache() -> bool:
return False
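
This hunk drops the # type: ignore[misc] from each no-op fallback. As a rough, self-contained sketch of the conditional-fallback shape involved (the fb_logging module name is invented for illustration):

    from typing import Any

    try:
        from fb_logging import (  # hypothetical internal logging module
            log_global_cache_errors,
            log_global_cache_stats,
            log_global_cache_vals,
            use_global_cache,
        )
    except ImportError:
        # No-op stubs with matching signatures, so callers and the type
        # checker see the same interface either way.
        def log_global_cache_errors(*args: Any, **kwargs: Any) -> None:
            pass

        def log_global_cache_stats(*args: Any, **kwargs: Any) -> None:
            pass

        def log_global_cache_vals(*args: Any, **kwargs: Any) -> None:
            pass

        def use_global_cache() -> bool:
            return False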
@ -2451,7 +2451,8 @@ class CppPythonBindingsCodeCache(CppCodeCache):
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore[union-attr]
assert spec.loader is not None
spec.loader.exec_module(module)
return module
@classmethod
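
The union-attr suppression goes away because the added assert narrows spec.loader from Optional. A minimal self-contained sketch of the same pattern (the path argument is illustrative):

    import importlib.util
    import sys
    import types

    def load_module(module_name: str, path: str) -> types.ModuleType:
        spec = importlib.util.spec_from_file_location(module_name, path)
        assert spec is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        # spec.loader is Optional[Loader]; the assert narrows it so
        # exec_module can be called without a type: ignore[union-attr].
        assert spec.loader is not None
        spec.loader.exec_module(module)
        return module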
@ -2945,6 +2946,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
job()
except subprocess.SubprocessError as e:
if os.environ.get("HALIDE_REPRO") == "1":
cmd: list[Any]
python, script, *cmd = getattr(e, "cmd", ("", "", ""))
if os.path.basename(python).startswith("python"):
code = open(script).read()
@ -2955,7 +2957,9 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
def __repr__(self) -> str:
return "out"
cmd[cmd.index("-o") + 1] = Out() # type: ignore[call-overload]
ci = cmd.index("-o")
assert isinstance(ci, int)
cmd[ci + 1] = Out()
repl = textwrap.indent(
textwrap.dedent(
f"""\
@ -3565,7 +3569,7 @@ class LambdaFuture(CodeCacheFuture):
self.result_fn = result_fn
self.future = future
def result(self) -> Callable[..., Any]: # type: ignore[override]
def result(self) -> Callable[..., Any]:
return self.result_fn()
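
The ignore[override] on LambdaFuture.result can be dropped once the override's signature is compatible with what the base class declares. A sketch of that idea, using an illustrative stand-in for the base class:

    from typing import Any, Callable

    class CodeCacheFutureSketch:  # stand-in; not the real base class
        def result(self) -> Callable[..., Any]:
            raise NotImplementedError

    class LambdaFutureSketch(CodeCacheFutureSketch):
        def __init__(self, result_fn: Callable[[], Callable[..., Any]]) -> None:
            self.result_fn = result_fn

        # Same return type as the base method, so no ignore[override] is needed.
        def result(self) -> Callable[..., Any]:
            return self.result_fn()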

torch/_inductor/codegen/common.py

@ -12,7 +12,6 @@ import operator
import os
import re
import tempfile
import typing
from abc import ABC, abstractmethod
from enum import auto, Enum
from itertools import chain
@ -27,7 +26,7 @@ from typing import (
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeVar
from typing_extensions import Self, TypeVar
import sympy
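
The import hunks drop the module-level import typing and take Self from typing_extensions rather than typing, since typing.Self only exists on Python 3.11+ while the backport works on older interpreters. A minimal sketch of the Self-returning methods this enables:

    from typing_extensions import Self  # backport of typing.Self (3.11+)

    class Builder:
        def __enter__(self) -> Self:
            return self

        def __exit__(self, *args: object) -> None:
            pass

        def clone(self) -> Self:
            # Subclasses get their own type back without re-annotating.
            return type(self)()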
@ -408,7 +407,7 @@ def get_backend_features(
if isinstance(device, torch.device):
device_type = device.type
else:
assert isinstance(device, str)
assert isinstance(device, str), type(device)
device_type = device
device = torch.device(device_type)
scheduling_ctor = get_scheduling_for_device(device_type)
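
Most of the remaining assert changes in common.py follow one pattern: attach type(x) (or the offending value) as the assert message, so a failure reports what was actually passed rather than raising a bare AssertionError. For illustration:

    import torch

    def normalize_device(device: object) -> torch.device:
        if isinstance(device, torch.device):
            return device
        # Failure reads "AssertionError: <class 'int'>" instead of nothing.
        assert isinstance(device, str), type(device)
        return torch.device(device)

    normalize_device("cpu")   # ok
    # normalize_device(3)     # AssertionError: <class 'int'>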
@ -538,7 +537,7 @@ def register_device_op_overrides(
def get_device_op_overrides(device: str) -> DeviceOpOverrides:
assert isinstance(device, str)
assert isinstance(device, str), type(device)
if not device_op_overrides_dict:
from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401
@ -621,7 +620,7 @@ def check_dtype(
elif config.test_configs.static_cpp_dtype_assert and backend == "cpp":
from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP
assert isinstance(var, CppCSEVariable)
assert isinstance(var, CppCSEVariable), type(var)
if dtype == torch.bool:
if var.is_vec:
is_same_dt = f"IsVecMaskType<decltype({var})>::value"
@ -682,9 +681,11 @@ class DataTypePropagation:
return None
if node.target == operator.getitem:
return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
node_arg = node.args[0]
assert isinstance(node_arg, torch.fx.Node), type(node_arg)
return self.deduce_node_dtype(node_arg)
assert isinstance(node.target, str)
assert isinstance(node.target, str), type(node.target)
if node.target.startswith("masked_subblock"):
return self.deduce_node_dtype_by_subgraph(node)
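
For the getitem case, node.args[0] is first bound to a local and asserted to be an fx.Node, replacing the arg-type suppression: node.args holds the broad FX Argument union, and the assert narrows it before the recursive call. A self-contained sketch on a traced graph (the example module is illustrative):

    import operator
    from typing import Optional

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.max(x, dim=0)[0]  # produces a getitem node

    def deduce_dtype(node: torch.fx.Node) -> Optional[torch.dtype]:
        if node.target == operator.getitem:
            node_arg = node.args[0]
            # Narrow the Argument union to Node before recursing.
            assert isinstance(node_arg, torch.fx.Node), type(node_arg)
            return deduce_dtype(node_arg)
        return None  # the real code inspects the op here

    gm = torch.fx.symbolic_trace(M())
    for n in gm.graph.nodes:
        deduce_dtype(n)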
@ -730,8 +731,8 @@ class DataTypePropagation:
from ..loop_body import LoopBody
from ..scheduler import SchedulerNode
assert isinstance(node, SchedulerNode)
assert isinstance(node._body, LoopBody)
assert isinstance(node, SchedulerNode), type(node)
assert isinstance(node._body, LoopBody), type(node._body)
return DataTypePropagation.propagate_loopbody(node._body)
@ -1428,7 +1429,7 @@ class KernelArgs:
def make_inplace(self, input_name: str, output_name: str) -> None:
if input_name in V.graph.unaligned_buffers:
V.graph.unaligned_buffers.add(output_name)
assert output_name not in self.inplace_buffers
assert output_name not in self.inplace_buffers, output_name
if input_name in self.inplace_buffers:
buf = self.inplace_buffers[input_name]
assert not isinstance(buf, RemovedArg)
@ -1490,7 +1491,7 @@ class KernelArgs:
assert (
existing_arg.inner_name != arg.inner_name
and existing_arg.outer_name != arg.outer_name
)
), existing_arg
self.workspace_args.append(arg)
return arg.inner_name, 0
@ -1518,7 +1519,7 @@ class KernelArgs:
)
for existing_arg in self.workspace_args:
if existing_arg.inner_name == arg.inner_name:
assert arg == existing_arg
assert arg == existing_arg, (arg, existing_arg)
self.workspace_args.append(arg)
return arg.inner_name
@ -1618,7 +1619,7 @@ class KernelArgs:
) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
arg_defs: list[ArgName] = []
call_args: list[str] = []
arg_types: list[torch.dtype] = []
arg_types: list[Any] = []
precompile_args: list[KernelArgType] = []
for inplaced in unique(self.inplace_buffers.values()):
if isinstance(inplaced, RemovedArg):
@ -1651,7 +1652,7 @@ class KernelArgs:
for outer, inner in self.sizevars.items():
arg_defs.append(ArgName(inner))
call_args.append(outer)
arg_types.append(type(outer)) # type: ignore[arg-type]
arg_types.append(type(outer))
precompile_args.append(SizeArg(inner, outer))
if V.graph.wrapper_code:
V.graph.wrapper_code.ensure_size_computed(outer)
@ -1686,7 +1687,7 @@ class KernelArgs:
# after you do a call into this kernel, which buffers actually contain
# updated data? Modeled off of python_argdefs.
def live_output_buffers(self) -> OrderedSet[str]:
live_outs = OrderedSet() # type: ignore[var-annotated]
live_outs = OrderedSet[str]()
for inplaced in unique(self.inplace_buffers.values()):
if isinstance(inplaced, RemovedArg):
continue
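
Constructing the set through the parameterized class, OrderedSet[str](), tells the checker the element type up front, so the var-annotated suppression is no longer needed. A generic stand-in class shows the two equivalent spellings:

    from typing import Generic, TypeVar

    T = TypeVar("T")

    class OrderedSetSketch(Generic[T]):  # illustrative stand-in for OrderedSet
        def __init__(self) -> None:
            self._items: dict[T, None] = {}

        def add(self, item: T) -> None:
            self._items[item] = None

    live_outs = OrderedSetSketch[str]()   # element type carried by the call itself
    # equivalent: live_outs: OrderedSetSketch[str] = OrderedSetSketch()
    live_outs.add("buf0")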
@ -1712,7 +1713,7 @@ class CSEVariable:
dtype: Optional[torch.dtype] = None,
):
super().__init__()
assert isinstance(bounds, ValueRanges)
assert isinstance(bounds, ValueRanges), type(bounds)
self.name = name
self.bounds = bounds
self.use_count = 1 # track how many times this expression is used
@ -1782,7 +1783,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
else:
self._cache = {}
def clone(self) -> typing.Self:
def clone(self) -> Self:
return type(self)(
prefix=self.prefix,
suffix=self.suffix,
@ -1793,7 +1794,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
reduction_cache=self.reduction_cache,
)
def scoped_copy(self) -> typing.Self:
def scoped_copy(self) -> Self:
"""Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
new_cse = self.clone()
new_cse._cache = ScopedDict(self._cache)
@ -1918,7 +1919,7 @@ class CodeGen:
super().__init__()
self.exit_stack = contextlib.ExitStack()
def __enter__(self) -> typing.Self:
def __enter__(self) -> Self:
self.exit_stack.__enter__()
return self
@ -2084,7 +2085,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
) -> str:
if isinstance(var, CSEVariable):
var = str(var)
assert isinstance(var, str)
assert isinstance(var, str), type(var)
assert lower is None or isinstance(lower, str)
assert upper is None or isinstance(upper, str)
if lower and upper:
@ -2113,7 +2114,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
def index_to_str(self, index: sympy.Expr) -> str:
raise NotImplementedError
def __enter__(self) -> typing.Self:
def __enter__(self) -> Self:
super().__enter__()
assert self.overrides
self.exit_stack.enter_context(
@ -2184,7 +2185,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
# adds the necessary kernel args for index expressions
# and renames variables in index expressions to kernel arg names
if isinstance(index, (list, tuple)):
return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
return [self.rename_indexing(x) for x in index]
index = V.graph.sizevars.simplify(index)
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
replacements = {
@ -2362,7 +2363,7 @@ class CSEProxy(DefaultHandler):
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
bounds = self._bound_variable(name, *args, **kwargs)
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
value = getattr(self.parent_handler, name)(*args, **kwargs)
dtype_handler = DtypePropagationOpsHandler()
backend = get_current_backend()
@ -2387,8 +2388,8 @@ class CSEProxy(DefaultHandler):
def do_cse(v: str) -> CSEVariable:
# we tree_map over the output, so we need to fetch corresponding dtype
nonlocal output_idx
var_dtype: torch.dtype = (
output_dtype[output_idx] # type: ignore[assignment]
var_dtype: Optional[torch.dtype] = (
output_dtype[output_idx]
if isinstance(output_dtype, (list, tuple))
else output_dtype
)
@ -2411,6 +2412,7 @@ class CSEProxy(DefaultHandler):
config.test_configs.runtime_triton_dtype_assert
or config.test_configs.static_cpp_dtype_assert
):
assert var_dtype is not None
check_dtype(V.kernel.compute, csevar, var_dtype)
return csevar
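
Here the dtype plumbing is annotated as Optional[torch.dtype], matching what can actually flow through, and the None case is ruled out with an assert right before the dtype is consumed, instead of hiding the mismatch behind ignore[assignment]. A condensed sketch:

    from typing import Optional, Union

    import torch

    def pick_dtype(
        output_dtype: Union[torch.dtype, list[Optional[torch.dtype]], None],
        output_idx: int,
    ) -> Optional[torch.dtype]:
        # Annotate what can really happen (None included) rather than
        # forcing torch.dtype with an ignore[assignment].
        var_dtype: Optional[torch.dtype] = (
            output_dtype[output_idx]
            if isinstance(output_dtype, (list, tuple))
            else output_dtype
        )
        return var_dtype

    var_dtype = pick_dtype([torch.float32, None], 0)
    assert var_dtype is not None   # narrow just before a dtype is required
    print(var_dtype)               # torch.float32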
@ -2433,7 +2435,9 @@ class CSEProxy(DefaultHandler):
fx_node = V.interpreter.current_node
if fx_node.target == name and self.kernel.node_to_bounds is not None:
assert isinstance(self.kernel.node_to_bounds, dict)
assert isinstance(self.kernel.node_to_bounds, dict), type(
self.kernel.node_to_bounds
)
return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown())
elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
# These create lots of inner strings. We would need to compute the bounds at the ops
@ -2468,14 +2472,14 @@ class CSEProxy(DefaultHandler):
) -> sympy.Symbol:
if isinstance(size, int):
size = sympy.Integer(size)
assert isinstance(size, sympy.Expr), size
assert isinstance(size, sympy.Expr), (type(size), size)
# Skip CSE since this doesn't return an expression
if var.bounds.lower < 0: # type: ignore[operator]
if var.bounds.lower < 0:
if wrap_neg:
stm = ops.add(var, ops.index_expr(size, torch.long))
# Mixed negative and non-negative
if var.bounds.upper >= 0: # type: ignore[operator]
if var.bounds.upper >= 0:
lt = ops.lt(var, 0)
stm = ops.where(lt, stm, var)
else:
@ -2492,7 +2496,7 @@ class CSEProxy(DefaultHandler):
neg_bounds.lower + size, neg_bounds.upper + size
)
# We don't have a good way of representing the empty range
if var.bounds.upper >= 0: # type: ignore[operator]
if var.bounds.upper >= 0:
pos = var.bounds & ValueRanges(0, int_oo)
new_bounds = new_bounds | pos
@ -2544,8 +2548,7 @@ class CSEProxy(DefaultHandler):
if mode is None:
self._update_store_cache(name, value)
if name not in V.graph.removed_buffers:
return self.kernel.store(name, index, value, mode=mode)
return None # type: ignore[return-value]
self.kernel.store(name, index, value, mode=mode)
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
self.kernel.store_buffer_names.add(name)