mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Add type annotations to _inductor/utils.py (#144108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144108 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
4ab967c44d
commit
44ee9ca593
@ -404,6 +404,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||
if V.graph.aot_mode:
|
||||
if V.graph.const_module:
|
||||
self.header.splice(V.graph.const_module.wrapper_code.header)
|
||||
assert V.graph.const_code is not None
|
||||
self.prefix.splice(V.graph.const_code)
|
||||
|
||||
if V.graph.is_const_graph:
|
||||
|
||||
@ -177,6 +177,7 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
||||
|
||||
if V.graph.const_module:
|
||||
self.header.splice(V.graph.const_module.wrapper_code.header)
|
||||
assert V.graph.const_code is not None
|
||||
self.prefix.splice(V.graph.const_code)
|
||||
|
||||
if V.graph.is_const_graph:
|
||||
|
||||
@ -1486,7 +1486,7 @@ class HalideKernel(SIMDKernel):
|
||||
argtypes,
|
||||
target="-".join(target),
|
||||
scheduler=schduler,
|
||||
scheduler_flags=scheduler_flags,
|
||||
scheduler_flags=scheduler_flags, # type: ignore[arg-type]
|
||||
cuda_device=cuda_device,
|
||||
)
|
||||
|
||||
|
||||
@ -1364,7 +1364,10 @@ class SIMDScheduling(BaseScheduling):
|
||||
src_code = kernel.codegen_kernel()
|
||||
kernel_name = self.define_kernel(src_code, node_schedule, kernel)
|
||||
if config.trace.enabled:
|
||||
set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name)
|
||||
set_kernel_post_grad_provenance_tracing(
|
||||
node_schedule, # type: ignore[arg-type]
|
||||
kernel_name,
|
||||
)
|
||||
log.debug("Generating kernel code with kernel_name: %s", kernel_name)
|
||||
kernel.kernel_name = kernel_name
|
||||
kernel.code_hash = code_hash(src_code)
|
||||
|
||||
@ -1792,7 +1792,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
|
||||
|
||||
out_devices: OrderedSet[torch.device] = OrderedSet(
|
||||
arg.meta["val"].device
|
||||
for arg in output_node(gm).args[0]
|
||||
for arg in output_node(gm).args[0] # type: ignore[union-attr]
|
||||
if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor)
|
||||
)
|
||||
cuda_devices: OrderedSet[torch.device] = OrderedSet(
|
||||
|
||||
@ -1381,16 +1381,15 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
buffer_watermark = len(self.buffers)
|
||||
operation_watermark = len(self.operations)
|
||||
|
||||
origins = OrderedSet([n])
|
||||
# origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n])
|
||||
origins: OrderedSet[Any] = OrderedSet([n])
|
||||
is_call_function = n.op == "call_function"
|
||||
if is_call_function:
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
origins |= gather_origins(args, kwargs)
|
||||
with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type]
|
||||
with ir.IRNode.current_origins(origins), self.set_current_node(
|
||||
n
|
||||
), V.set_current_node(
|
||||
n
|
||||
):
|
||||
), V.set_current_node(n):
|
||||
if (
|
||||
n.op == "call_function"
|
||||
and n.target is not operator.getitem
|
||||
|
||||
@ -12,7 +12,7 @@ import os
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -3777,7 +3777,7 @@ def scatter_fallback(
|
||||
op_overload,
|
||||
reduce,
|
||||
self.get_dtype(),
|
||||
src.get_dtype() if src_is_tensor else type(src),
|
||||
cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src)),
|
||||
src.get_device().type if src_is_tensor else "not impl",
|
||||
src_is_tensor,
|
||||
):
|
||||
|
||||
@ -7,7 +7,7 @@ import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._inductor import config
|
||||
from torch._inductor.utils import get_benchmark_name
|
||||
@ -145,11 +145,11 @@ class MetricTable:
|
||||
row_dict.keys()
|
||||
), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
|
||||
|
||||
row = [
|
||||
get_benchmark_name(),
|
||||
]
|
||||
row += [row_dict[column_name] for column_name in self.column_names]
|
||||
self._write_row(row)
|
||||
bn = get_benchmark_name()
|
||||
# assert bn is not None
|
||||
row = [bn] + [row_dict[column_name] for column_name in self.column_names]
|
||||
assert all(isinstance(i, str) for i in row)
|
||||
self._write_row(cast(list[str], row))
|
||||
|
||||
def output_filename(self) -> str:
|
||||
return f"metric_table_{self.table_name}.csv"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import sympy
|
||||
|
||||
@ -31,12 +31,12 @@ def _prepare_convolution_fusion_create(
|
||||
x: "TensorBox",
|
||||
weight: "TensorBox",
|
||||
bias: "TensorBox",
|
||||
padding: list[int],
|
||||
stride: list[int],
|
||||
dilation: list[int],
|
||||
padding: Sequence[int],
|
||||
stride: Sequence[int],
|
||||
dilation: Sequence[int],
|
||||
groups: int,
|
||||
transposed: bool = False,
|
||||
output_padding: Optional[list[int]] = None,
|
||||
output_padding: Optional[Sequence[int]] = None,
|
||||
quantize_args: Optional[list["TensorBox"]] = None,
|
||||
other: Optional["TensorBox"] = None,
|
||||
):
|
||||
|
||||
@ -437,7 +437,7 @@ class CompiledFxGraph(OutputCode):
|
||||
assert len(output.args) == 1
|
||||
stack_traces = [
|
||||
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
|
||||
for arg in output.args[0]
|
||||
for arg in output.args[0] # type: ignore[union-attr]
|
||||
]
|
||||
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
|
||||
placeholders = tuple(get_placeholder_info(gm.graph))
|
||||
|
||||
@ -823,6 +823,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
self.codegen_body()
|
||||
self.cse.invalidate(OrderedSet())
|
||||
if input_node.get_name() not in self.prologue_fused_inputs:
|
||||
assert load_code is not None
|
||||
self.body.writeline(load_code)
|
||||
|
||||
return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
|
||||
@ -1768,11 +1769,8 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# different than the original values. we explicitly restore the state
|
||||
# here to avoid this issue.
|
||||
|
||||
initial_stdout = sys.stdout
|
||||
initial_stderr = sys.stderr
|
||||
|
||||
def precompile_with_captured_stdout(choice):
|
||||
with restore_stdout_stderr(initial_stdout, initial_stderr):
|
||||
with restore_stdout_stderr():
|
||||
choice.precompile()
|
||||
|
||||
def on_complete(future):
|
||||
@ -1805,7 +1803,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
futures[future] = c
|
||||
|
||||
@functools.lru_cache(None)
|
||||
@restore_stdout_stderr(initial_stdout, initial_stderr)
|
||||
@restore_stdout_stderr()
|
||||
def wait_on_futures():
|
||||
counters["inductor"]["select_algorithm_precompile"] += 1
|
||||
for future in as_completed(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user