[inductor] Add type annotations to _inductor/utils.py (#144108)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144108
Approved by: https://github.com/eellison
This commit is contained in:
Tom Ritchford
2025-02-15 19:05:03 +00:00
committed by PyTorch MergeBot
parent 4ab967c44d
commit 44ee9ca593
12 changed files with 357 additions and 239 deletions

View File

@@ -404,6 +404,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
if V.graph.aot_mode:
if V.graph.const_module:
self.header.splice(V.graph.const_module.wrapper_code.header)
assert V.graph.const_code is not None
self.prefix.splice(V.graph.const_code)
if V.graph.is_const_graph:

View File

@@ -177,6 +177,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
if V.graph.const_module:
self.header.splice(V.graph.const_module.wrapper_code.header)
assert V.graph.const_code is not None
self.prefix.splice(V.graph.const_code)
if V.graph.is_const_graph:

View File

@@ -1486,7 +1486,7 @@ class HalideKernel(SIMDKernel):
argtypes,
target="-".join(target),
scheduler=schduler,
scheduler_flags=scheduler_flags,
scheduler_flags=scheduler_flags, # type: ignore[arg-type]
cuda_device=cuda_device,
)

View File

@@ -1364,7 +1364,10 @@ class SIMDScheduling(BaseScheduling):
src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.enabled:
set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name)
set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
kernel_name,
)
log.debug("Generating kernel code with kernel_name: %s", kernel_name)
kernel.kernel_name = kernel_name
kernel.code_hash = code_hash(src_code)

View File

@@ -1792,7 +1792,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
out_devices: OrderedSet[torch.device] = OrderedSet(
arg.meta["val"].device
for arg in output_node(gm).args[0]
for arg in output_node(gm).args[0] # type: ignore[union-attr]
if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor)
)
cuda_devices: OrderedSet[torch.device] = OrderedSet(

View File

@@ -1381,16 +1381,15 @@ class GraphLowering(torch.fx.Interpreter):
buffer_watermark = len(self.buffers)
operation_watermark = len(self.operations)
origins = OrderedSet([n])
# origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n])
origins: OrderedSet[Any] = OrderedSet([n])
is_call_function = n.op == "call_function"
if is_call_function:
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type]
with ir.IRNode.current_origins(origins), self.set_current_node(
n
), V.set_current_node(
n
):
), V.set_current_node(n):
if (
n.op == "call_function"
and n.target is not operator.getitem

View File

@@ -12,7 +12,7 @@ import os
import warnings
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
from unittest.mock import patch
@@ -3777,7 +3777,7 @@ def scatter_fallback(
op_overload,
reduce,
self.get_dtype(),
src.get_dtype() if src_is_tensor else type(src),
cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src)),
src.get_device().type if src_is_tensor else "not impl",
src_is_tensor,
):

View File

@@ -7,7 +7,7 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union
from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
@@ -145,11 +145,11 @@ class MetricTable:
row_dict.keys()
), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
row = [
get_benchmark_name(),
]
row += [row_dict[column_name] for column_name in self.column_names]
self._write_row(row)
bn = get_benchmark_name()
# assert bn is not None
row = [bn] + [row_dict[column_name] for column_name in self.column_names]
assert all(isinstance(i, str) for i in row)
self._write_row(cast(list[str], row))
def output_filename(self) -> str:
return f"metric_table_{self.table_name}.csv"

View File

@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Optional
from typing import Any, Optional, Sequence
import sympy
@@ -31,12 +31,12 @@ def _prepare_convolution_fusion_create(
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding: list[int],
stride: list[int],
dilation: list[int],
padding: Sequence[int],
stride: Sequence[int],
dilation: Sequence[int],
groups: int,
transposed: bool = False,
output_padding: Optional[list[int]] = None,
output_padding: Optional[Sequence[int]] = None,
quantize_args: Optional[list["TensorBox"]] = None,
other: Optional["TensorBox"] = None,
):

View File

@@ -437,7 +437,7 @@ class CompiledFxGraph(OutputCode):
assert len(output.args) == 1
stack_traces = [
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
for arg in output.args[0]
for arg in output.args[0] # type: ignore[union-attr]
]
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
placeholders = tuple(get_placeholder_info(gm.graph))

View File

@@ -823,6 +823,7 @@ class TritonTemplateKernel(TritonKernel):
self.codegen_body()
self.cse.invalidate(OrderedSet())
if input_node.get_name() not in self.prologue_fused_inputs:
assert load_code is not None
self.body.writeline(load_code)
return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
@@ -1768,11 +1769,8 @@ class AlgorithmSelectorCache(PersistentCache):
# different than the original values. we explicitly restore the state
# here to avoid this issue.
initial_stdout = sys.stdout
initial_stderr = sys.stderr
def precompile_with_captured_stdout(choice):
with restore_stdout_stderr(initial_stdout, initial_stderr):
with restore_stdout_stderr():
choice.precompile()
def on_complete(future):
@@ -1805,7 +1803,7 @@ class AlgorithmSelectorCache(PersistentCache):
futures[future] = c
@functools.lru_cache(None)
@restore_stdout_stderr(initial_stdout, initial_stderr)
@restore_stdout_stderr()
def wait_on_futures():
counters["inductor"]["select_algorithm_precompile"] += 1
for future in as_completed(

File diff suppressed because it is too large Load Diff