mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable typechecking in _inductor/ir.py (#110112)
I used a bunch of `# type: ignore` comments, mostly due to https://github.com/pytorch/pytorch/issues/109963.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110112
Approved by: https://github.com/peterbell10
This commit is contained in:
@ -197,7 +197,6 @@ include_patterns = [
|
||||
exclude_patterns = [
|
||||
'**/fb/**',
|
||||
'torch/_inductor/fx_passes/serialized_patterns/**',
|
||||
'torch/_inductor/ir.py',
|
||||
'torch/_inductor/scheduler.py',
|
||||
]
|
||||
command = [
|
||||
|
@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from ctypes import byref, c_size_t, c_void_p
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.queues import Queue
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import multiprocessing
|
||||
@ -337,9 +337,9 @@ class TensorMeta:
|
||||
|
||||
@classmethod
|
||||
def from_irnodes(
|
||||
cls, irnodes: Union[LayoutOrBuffer, Tuple[LayoutOrBuffer], List[LayoutOrBuffer]]
|
||||
cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
|
||||
) -> Union[TensorMeta, List[TensorMeta]]:
|
||||
if isinstance(irnodes, (tuple, list)):
|
||||
if isinstance(irnodes, Sequence):
|
||||
result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
|
||||
assert all(isinstance(x, TensorMeta) for x in result)
|
||||
return result
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from ...autotune_process import CUDABenchmarkRequest
|
||||
from ...ir import Callable, CUDATemplateBuffer, IRNode, Layout, TensorBox
|
||||
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
|
||||
from ...select_algorithm import ChoiceCaller
|
||||
from ...utils import sympy_product
|
||||
from ...virtualized import V
|
||||
@ -270,7 +270,7 @@ class CUDATemplateCaller(ChoiceCaller):
|
||||
self,
|
||||
name: str,
|
||||
category: str,
|
||||
input_nodes: List[IRNode],
|
||||
input_nodes: List[Buffer],
|
||||
layout: Layout,
|
||||
make_kernel_render: Callable[[str], str],
|
||||
bmreq: CUDABenchmarkRequest,
|
||||
|
@ -26,13 +26,13 @@ class CUDATemplate(KernelTemplate):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
input_nodes: List[IRNode],
|
||||
input_nodes: List[Buffer],
|
||||
layout: Layout,
|
||||
input_reorder: Optional[List[int]] = None,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.input_nodes = input_nodes
|
||||
self.output_node = Buffer("buf_out", layout)
|
||||
self.output_node: Buffer = Buffer("buf_out", layout)
|
||||
self.input_reorder = input_reorder
|
||||
|
||||
def generate(self, **kwargs) -> CUDATemplateCaller:
|
||||
|
@ -156,7 +156,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_nodes: List[IRNode],
|
||||
input_nodes: List[Buffer],
|
||||
layout: Layout,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
@ -478,7 +478,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
self,
|
||||
kernel: CUDATemplateKernel,
|
||||
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
||||
output_node: IRNode = None,
|
||||
output_node: Optional[Buffer] = None,
|
||||
) -> str:
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
|
@ -333,7 +333,7 @@ def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
|
||||
|
||||
|
||||
def extract_read_writes(
|
||||
fn: Callable[[sympy.Expr], Any],
|
||||
fn: Callable[..., Any],
|
||||
*argsizes: Tuple[sympy.Expr, ...],
|
||||
normalize: bool = False,
|
||||
prefix: str = "d",
|
||||
|
@ -488,7 +488,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if (
|
||||
not hasattr(value, "data")
|
||||
or not isinstance(value.data, ir.IRNode)
|
||||
or not isinstance(value.data.data, ir.IRNode)
|
||||
or not (
|
||||
hasattr(value.data, "data")
|
||||
and isinstance(value.data.data, ir.IRNode)
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
@ -675,7 +678,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
for name, value in self.graph_inputs.items():
|
||||
assert isinstance(
|
||||
value, (TensorBox, sympy.Expr)
|
||||
), "Unsupported inductor graph input type: " + type(value)
|
||||
), f"Unsupported inductor graph input type: {type(value)}"
|
||||
if not isinstance(value, TensorBox):
|
||||
continue
|
||||
value.realize()
|
||||
|
@ -16,6 +16,7 @@ from typing import (
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -178,12 +179,12 @@ def stride_order2fill_order(order):
|
||||
return fill_order
|
||||
|
||||
|
||||
def get_stride_order(seq: Sequence[int]):
|
||||
def get_stride_order(seq: Sequence[int]) -> List[int]:
|
||||
"""
|
||||
Convert strides to stride order
|
||||
"""
|
||||
sorted_idx: List[int] = argsort(seq)
|
||||
out = [None for _ in range(len(seq))]
|
||||
out = [0 for _ in range(len(seq))]
|
||||
for i, elem in enumerate(sorted_idx):
|
||||
out[elem] = i
|
||||
return out
|
||||
@ -236,7 +237,6 @@ def is_cpu(x):
|
||||
return get_device_type(x) == "cpu"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class IRNode:
|
||||
_current_origins: ClassVar[Set[Any]] = set()
|
||||
|
||||
@ -306,6 +306,21 @@ class IRNode:
|
||||
"""
|
||||
raise NotImplementedError(f"realize NYI on {type(self)}")
|
||||
|
||||
# The abstract method declarations below serve to convince mypy that all IRNode instances have these functions
|
||||
# defined, while having no effect at runtime. We cannot create stub implementations here because other parts of
|
||||
# the code dynamically check for defined attributes.
|
||||
get_device: Callable[[], torch.device]
|
||||
get_dtype: Callable[[], torch.dtype]
|
||||
get_name: Callable[[], str]
|
||||
get_reads: Callable[[], Any]
|
||||
get_stride: Callable[[], Any]
|
||||
get_storage_numel: Callable[[], Any]
|
||||
has_exceeded_max_reads: Callable[[], bool]
|
||||
make_loader: Callable[[], Callable[[Any], Any]]
|
||||
make_indexer: Callable[[], Callable[[Any], Any]]
|
||||
mark_reuse: Callable[[List[Any]], None]
|
||||
realize_hint: Callable[[], None]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Loops(IRNode):
|
||||
@ -386,6 +401,21 @@ class Loops(IRNode):
|
||||
self.get_size(),
|
||||
).reads
|
||||
|
||||
def get_reduction_size(self):
|
||||
raise NotImplementedError(
|
||||
f"get_reduction_size() is not implemented by {type(self)}!"
|
||||
)
|
||||
|
||||
def get_reduction_type(self):
|
||||
raise NotImplementedError(
|
||||
f"get_reduction_type() is not implemented by {type(self)}!"
|
||||
)
|
||||
|
||||
def constant_to_device(self, device):
|
||||
raise NotImplementedError(
|
||||
f"constant_to_device() is not implemented by {type(self)}!"
|
||||
)
|
||||
|
||||
|
||||
def nop_loader_fn(idx, *, dtype):
|
||||
if dtype.is_floating_point:
|
||||
@ -1146,9 +1176,12 @@ class WelfordReduction(Reduction):
|
||||
|
||||
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
|
||||
|
||||
def const(idx, val):
|
||||
def const(val):
|
||||
def inner_fn(idx):
|
||||
return ops.constant(val, dst_dtype)
|
||||
return ops.constant(
|
||||
val,
|
||||
dtype,
|
||||
)
|
||||
|
||||
return Pointwise.create(
|
||||
device=device,
|
||||
@ -1172,7 +1205,7 @@ class WelfordReduction(Reduction):
|
||||
|
||||
return Pointwise.create(
|
||||
device=device,
|
||||
dtype=dst_dtype,
|
||||
dtype=dtype,
|
||||
inner_fn=inner_fn,
|
||||
ranges=list(ranges),
|
||||
)
|
||||
@ -1251,7 +1284,7 @@ class WelfordReduction(Reduction):
|
||||
return (0, 0, 0)
|
||||
|
||||
@classmethod
|
||||
def create_multilayer(
|
||||
def create_multilayer( # type: ignore[override]
|
||||
cls,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
@ -1377,7 +1410,7 @@ def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=No
|
||||
return x, x.data.layout
|
||||
if isinstance(x, ReinterpretView):
|
||||
# making the base of x contiguous or stride_ordered will not necessarily make
|
||||
# the ReinterpretedView either, so dont pass along those arguments
|
||||
# the ReinterpretView either, so don't pass along those arguments
|
||||
buffer, _ = as_storage_and_layout(
|
||||
x.data,
|
||||
freeze=freeze,
|
||||
@ -1455,7 +1488,7 @@ class BaseView(IRNode):
|
||||
return self.data.get_storage_numel()
|
||||
|
||||
def is_extern(self):
|
||||
return self.data.is_extern()
|
||||
return self.data.is_extern() # type: ignore[attr-defined]
|
||||
|
||||
def get_reads(self):
|
||||
with patch.object(FlexibleLayout, "allow_indexing", True):
|
||||
@ -1465,7 +1498,7 @@ class BaseView(IRNode):
|
||||
).reads
|
||||
|
||||
def unwrap_view(self):
|
||||
x = self
|
||||
x: IRNode = self
|
||||
while isinstance(x, BaseView):
|
||||
x = x.data
|
||||
return x
|
||||
@ -1689,7 +1722,7 @@ class View(GenericView):
|
||||
def fake_reindex(index):
|
||||
return tuple([0] * len(old_size))
|
||||
|
||||
return cls(x, tuple(new_size), fake_reindex)
|
||||
return cls(x, list(new_size), fake_reindex)
|
||||
# TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
|
||||
elif is_contiguous_storage_and_layout(x):
|
||||
storage, old_layout = as_contiguous_storage_and_layout(x)
|
||||
@ -1914,6 +1947,9 @@ class SliceView(View):
|
||||
|
||||
|
||||
class BaseConstant(IRNode):
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
|
||||
def get_size(self):
|
||||
return ()
|
||||
|
||||
@ -1981,7 +2017,7 @@ class Layout(IRNode):
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
size: List[Expr],
|
||||
stride: List[Expr],
|
||||
stride: Optional[Sequence[Union[Expr, int]]],
|
||||
offset: Expr = Integer(0),
|
||||
):
|
||||
assert stride is None or len(size) == len(
|
||||
@ -2092,7 +2128,7 @@ class FixedLayout(Layout):
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
size: Union[List[Expr], List[int]],
|
||||
stride: Optional[Union[List[Expr], List[int]]] = None,
|
||||
stride: Optional[Sequence[Union[Expr, int]]] = None,
|
||||
offset: Union[Expr, int] = Integer(0),
|
||||
):
|
||||
if stride is None:
|
||||
@ -2213,7 +2249,7 @@ class FlexibleLayout(Layout):
|
||||
class AliasedLayout(Layout):
|
||||
"""Shares the same storage as another tensor"""
|
||||
|
||||
def __init__(self, view: "ReinterpretView"):
|
||||
def __init__(self, view: IRNode):
|
||||
layout = view.get_layout()
|
||||
super().__init__(
|
||||
layout.device,
|
||||
@ -2260,13 +2296,13 @@ class MutationLayout(Layout):
|
||||
target.get_device(),
|
||||
target.get_dtype(),
|
||||
target.get_size(),
|
||||
None, # type: ignore[arg-type]
|
||||
None,
|
||||
)
|
||||
self.target = target
|
||||
name = self.get_buffer().get_name()
|
||||
V.graph.mark_buffer_mutated(name)
|
||||
|
||||
@Layout.stride.getter
|
||||
@Layout.stride.getter # type: ignore[attr-defined]
|
||||
def stride(self):
|
||||
return self.real_layout().stride
|
||||
|
||||
@ -2620,20 +2656,21 @@ class ComputedBuffer(Buffer):
|
||||
|
||||
def make_loader(self):
|
||||
# Inline constants and index_expressions
|
||||
can_inline = (
|
||||
if (
|
||||
hasattr(self.data, "make_loader")
|
||||
and self.name not in V.graph.mutated_buffers
|
||||
and self.num_reads() == 0
|
||||
)
|
||||
if can_inline:
|
||||
):
|
||||
# can be inlined
|
||||
return self.data.make_loader()
|
||||
return super().make_loader()
|
||||
|
||||
def get_store_function(self):
|
||||
indexer = self.layout.as_fixed().make_indexer()
|
||||
if self.data.get_reduction_type():
|
||||
if isinstance(self.data, Reduction):
|
||||
return partial(self.data.store_reduction, self.name, indexer)
|
||||
else:
|
||||
assert isinstance(self.data, Pointwise)
|
||||
return partial(self.data.store_output, self.name, indexer)
|
||||
|
||||
def get_fill_order(self):
|
||||
@ -3020,14 +3057,14 @@ class ConcatKernel(NopKernel):
|
||||
)
|
||||
kernel = StorageBox(concat_kernel)
|
||||
for i in range(len(inputs)):
|
||||
kernel.data.inputs.append(
|
||||
concat_kernel.inputs.append(
|
||||
cls.realize_into(
|
||||
inputs[i],
|
||||
SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
|
||||
)
|
||||
)
|
||||
kernel.data.name = V.graph.register_buffer(kernel.data)
|
||||
kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)
|
||||
concat_kernel.name = V.graph.register_buffer(concat_kernel)
|
||||
concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)
|
||||
|
||||
return kernel
|
||||
|
||||
@ -3047,6 +3084,7 @@ class ConcatKernel(NopKernel):
|
||||
if isinstance(src, StorageBox):
|
||||
src.realize()
|
||||
# ExternKernelAlloc has specific requirements for output layout, should create a copy
|
||||
assert hasattr(src.data, "layout")
|
||||
if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
|
||||
src.data, ExternKernelAlloc
|
||||
):
|
||||
@ -3073,6 +3111,9 @@ class ExternKernel(InputsKernel):
|
||||
constant_args: Tuple[Any, ...] = ()
|
||||
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
||||
output_view: Optional[ReinterpretView] = None
|
||||
ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
def decide_layout(self):
|
||||
if isinstance(self.layout, FlexibleLayout):
|
||||
@ -3715,7 +3756,7 @@ class DynamicScalar(ExternKernel):
|
||||
return False
|
||||
|
||||
def __init__(self, sym, data):
|
||||
super().__init__(None, NoneLayout(), [data])
|
||||
super().__init__(None, NoneLayout(), [data]) # type: ignore[arg-type]
|
||||
self.sym = sym
|
||||
|
||||
def get_unbacked_symbol_defs(self):
|
||||
@ -4117,7 +4158,7 @@ def _prepare_convolution_fusion_create(
|
||||
dilation: List[int],
|
||||
groups: int,
|
||||
transposed: bool = False,
|
||||
output_padding: List[int] = None,
|
||||
output_padding: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
This function is a helper function to prepare inputs, layout and constant args
|
||||
@ -4296,7 +4337,7 @@ def _prepare_linear_fusion_create(
|
||||
convert_shape_to_inductor(output_size),
|
||||
convert_shape_to_inductor(output_stride),
|
||||
)
|
||||
constant_args = []
|
||||
constant_args: List[Any] = []
|
||||
|
||||
if bias is not None:
|
||||
inputs.append(bias)
|
||||
@ -4908,7 +4949,6 @@ class MkldnnRnnLayer(ExternKernelAlloc):
|
||||
assert len(output_shape) == 3, "Expect output_shape to be 3D"
|
||||
return make_contiguous_strides_for(output_shape)
|
||||
|
||||
indices = []
|
||||
output_sizes = [output_shape, hy_shape, cy_shape]
|
||||
output_strides = [
|
||||
get_strides_of_lstm_output(output_shape, batch_first),
|
||||
@ -4924,7 +4964,7 @@ class MkldnnRnnLayer(ExternKernelAlloc):
|
||||
output_stride,
|
||||
),
|
||||
packed,
|
||||
indices + [(tuple, i)],
|
||||
[(tuple, i)],
|
||||
)
|
||||
for i, (output_size, output_stride) in enumerate(
|
||||
zip(output_sizes, output_strides)
|
||||
@ -5362,7 +5402,7 @@ class MutableBox(IRNode):
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
return self.data.layout
|
||||
return self.data.layout # type: ignore[attr-defined]
|
||||
|
||||
def get_layout(self):
|
||||
return self.layout
|
||||
@ -5653,7 +5693,7 @@ class LoopBodyBlock:
|
||||
{},
|
||||
)
|
||||
|
||||
class CaptureIndexing(V.WrapperHandler):
|
||||
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
|
||||
self.name = "CaptureIndexing"
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
@ -5699,6 +5739,8 @@ class LoopBodyBlock:
|
||||
Recursively capture the masked out body in another LoopBodyBlock
|
||||
"""
|
||||
|
||||
subblock: LoopBodyBlock
|
||||
|
||||
def shim(mask, other):
|
||||
return V.ops.masked(mask, subblock, other)
|
||||
|
||||
@ -5716,12 +5758,13 @@ class LoopBodyBlock:
|
||||
Introduce a call_module to update the indexing.
|
||||
"""
|
||||
|
||||
var = self.body.add_indirect(size)
|
||||
|
||||
def set_indirect(new_var):
|
||||
self.body.replace_indirect(
|
||||
var, V.ops.indirect_indexing(new_var, size, check)
|
||||
)
|
||||
|
||||
var = self.body.add_indirect(size)
|
||||
tracer.create_proxy(
|
||||
"call_module",
|
||||
self.body.add_submodule(set_indirect, f"set_{var}"),
|
||||
@ -5741,7 +5784,9 @@ class LoopBodyBlock:
|
||||
from .index_propagation import IndexPropagation
|
||||
from .sizevars import SimplifyIndexing
|
||||
|
||||
handler = SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
|
||||
handler: Any = SimplifyIndexing(
|
||||
CaptureIndexing(proxy_ops), self.body.var_ranges
|
||||
)
|
||||
if config.constant_and_index_propagation:
|
||||
handler = IndexPropagation(handler)
|
||||
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import cast, List, Tuple, TypedDict
|
||||
from typing import cast, List, Optional, Sequence, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
from .. import config, ir
|
||||
@ -228,8 +228,8 @@ class ConvLayoutParams(TypedDict):
|
||||
def conv_layout(
|
||||
x: TensorBox,
|
||||
weight: TensorBox,
|
||||
bias: TensorBox,
|
||||
stride: tuple[int, ...],
|
||||
bias: Optional[TensorBox],
|
||||
stride: Sequence[int],
|
||||
padding: tuple[int, ...],
|
||||
dilation: tuple[int, ...],
|
||||
transposed: bool,
|
||||
|
@ -261,6 +261,8 @@ class OpsWrapper:
|
||||
|
||||
ops = OpsWrapper()
|
||||
|
||||
_MockHandler = MockHandler
|
||||
|
||||
|
||||
class _V:
|
||||
MockHandler = MockHandler
|
||||
@ -281,7 +283,7 @@ class _V:
|
||||
get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
|
||||
|
||||
@property
|
||||
def ops(self) -> MockHandler: # type: ignore[valid-type]
|
||||
def ops(self) -> _MockHandler:
|
||||
"""The operator handler specific to the current codegen task"""
|
||||
return _ops._get_handler()
|
||||
|
||||
|
Reference in New Issue
Block a user