Enable typechecking in _inductor/ir.py (#110112)

I used a bunch of `# type: ignore` comments, mostly due to
https://github.com/pytorch/pytorch/issues/109963.
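
For illustration, here is a minimal, self-contained sketch (not code from this patch; `Buffer` and `Layout` are stand-ins, not the inductor classes) of the targeted form those comments take: the bracketed error code suppresses a single mypy diagnostic on that line rather than disabling checking entirely.

```python
from typing import Optional

class Layout:
    pass

class Buffer:
    layout: Layout

def get_layout(node: Optional[Buffer]) -> Layout:
    # Without the ignore, mypy reports:
    #   Item "None" of "Optional[Buffer]" has no attribute "layout"  [union-attr]
    return node.layout  # type: ignore[union-attr]
```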

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110112
Approved by: https://github.com/peterbell10
Authored by Jez Ng on 2023-10-06 17:02:54 -07:00; committed by PyTorch MergeBot
parent e8ef8bfdce
commit c77dd684c9
10 changed files with 99 additions and 50 deletions

View File

@ -197,7 +197,6 @@ include_patterns = [
exclude_patterns = [
'**/fb/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/ir.py',
'torch/_inductor/scheduler.py',
]
command = [

View File

@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, TYPE_CHECKING, Union
import torch
from torch import multiprocessing
@ -337,9 +337,9 @@ class TensorMeta:
@classmethod
def from_irnodes(
cls, irnodes: Union[LayoutOrBuffer, Tuple[LayoutOrBuffer], List[LayoutOrBuffer]]
cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
) -> Union[TensorMeta, List[TensorMeta]]:
if isinstance(irnodes, (tuple, list)):
if isinstance(irnodes, Sequence):
result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
assert all(isinstance(x, TensorMeta) for x in result)
return result

View File

@ -1,7 +1,7 @@
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional
from ...autotune_process import CUDABenchmarkRequest
from ...ir import Callable, CUDATemplateBuffer, IRNode, Layout, TensorBox
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
from ...select_algorithm import ChoiceCaller
from ...utils import sympy_product
from ...virtualized import V
@ -270,7 +270,7 @@ class CUDATemplateCaller(ChoiceCaller):
self,
name: str,
category: str,
input_nodes: List[IRNode],
input_nodes: List[Buffer],
layout: Layout,
make_kernel_render: Callable[[str], str],
bmreq: CUDABenchmarkRequest,

View File

@ -26,13 +26,13 @@ class CUDATemplate(KernelTemplate):
def __init__(
self,
name: str,
input_nodes: List[IRNode],
input_nodes: List[Buffer],
layout: Layout,
input_reorder: Optional[List[int]] = None,
):
super().__init__(name)
self.input_nodes = input_nodes
self.output_node = Buffer("buf_out", layout)
self.output_node: Buffer = Buffer("buf_out", layout)
self.input_reorder = input_reorder
def generate(self, **kwargs) -> CUDATemplateCaller:

View File

@ -156,7 +156,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
def __init__(
self,
input_nodes: List[IRNode],
input_nodes: List[Buffer],
layout: Layout,
alpha: float,
beta: float,
@ -478,7 +478,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
self,
kernel: CUDATemplateKernel,
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
output_node: IRNode = None,
output_node: Optional[Buffer] = None,
) -> str:
assert cutlass_utils.try_import_cutlass()
import cutlass_library as cutlass_lib # type: ignore[import]

View File

@ -333,7 +333,7 @@ def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
def extract_read_writes(
fn: Callable[[sympy.Expr], Any],
fn: Callable[..., Any],
*argsizes: Tuple[sympy.Expr, ...],
normalize: bool = False,
prefix: str = "d",

View File

@ -488,7 +488,10 @@ class GraphLowering(torch.fx.Interpreter):
if (
not hasattr(value, "data")
or not isinstance(value.data, ir.IRNode)
or not isinstance(value.data.data, ir.IRNode)
or not (
hasattr(value.data, "data")
and isinstance(value.data.data, ir.IRNode)
)
):
return
@ -675,7 +678,7 @@ class GraphLowering(torch.fx.Interpreter):
for name, value in self.graph_inputs.items():
assert isinstance(
value, (TensorBox, sympy.Expr)
), "Unsupported inductor graph input type: " + type(value)
), f"Unsupported inductor graph input type: {type(value)}"
if not isinstance(value, TensorBox):
continue
value.realize()

View File

@ -16,6 +16,7 @@ from typing import (
Callable,
ClassVar,
Dict,
Iterable,
List,
Optional,
Sequence,
@ -178,12 +179,12 @@ def stride_order2fill_order(order):
return fill_order
def get_stride_order(seq: Sequence[int]):
def get_stride_order(seq: Sequence[int]) -> List[int]:
"""
Convert strides to stride order
"""
sorted_idx: List[int] = argsort(seq)
out = [None for _ in range(len(seq))]
out = [0 for _ in range(len(seq))]
for i, elem in enumerate(sorted_idx):
out[elem] = i
return out
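
For reference, a runnable sketch of the `get_stride_order` helper from the hunk above; `argsort` is assumed to return the indices that would sort its input in ascending order (as the helper of that name does elsewhere in _inductor) and is re-implemented locally so the snippet stands alone:

```python
from typing import List, Sequence

def argsort(seq: Sequence[int]) -> List[int]:
    # Indices that would sort `seq` in ascending order.
    return sorted(range(len(seq)), key=lambda i: seq[i])

def get_stride_order(seq: Sequence[int]) -> List[int]:
    """Convert strides to stride order (rank 0 = smallest stride)."""
    sorted_idx: List[int] = argsort(seq)
    out = [0 for _ in range(len(seq))]
    for i, elem in enumerate(sorted_idx):
        out[elem] = i
    return out

# A contiguous tensor of shape (2, 3, 4) has strides (12, 4, 1), so the last
# dimension gets rank 0 and the first gets rank 2:
print(get_stride_order([12, 4, 1]))  # -> [2, 1, 0]
```

Switching the placeholder list from `None` to `0` is what lets the declared `List[int]` return type check; every slot is overwritten by the loop, so the placeholder value never escapes.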
@ -236,7 +237,6 @@ def is_cpu(x):
return get_device_type(x) == "cpu"
@dataclasses.dataclass
class IRNode:
_current_origins: ClassVar[Set[Any]] = set()
@ -306,6 +306,21 @@ class IRNode:
"""
raise NotImplementedError(f"realize NYI on {type(self)}")
# The abstract method declarations below serve to convince mypy that all IRNode instances have these functions
# defined, while having no effect at runtime. We cannot create stub implementations here because other parts of
# the code dynamically check for defined attributes.
get_device: Callable[[], torch.device]
get_dtype: Callable[[], torch.dtype]
get_name: Callable[[], str]
get_reads: Callable[[], Any]
get_stride: Callable[[], Any]
get_storage_numel: Callable[[], Any]
has_exceeded_max_reads: Callable[[], bool]
make_loader: Callable[[], Callable[[Any], Any]]
make_indexer: Callable[[], Callable[[Any], Any]]
mark_reuse: Callable[[List[Any]], None]
realize_hint: Callable[[], None]
@dataclasses.dataclass
class Loops(IRNode):
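
A minimal sketch of the annotation trick described in the comment above, using made-up `Node`/`NamedNode` classes: class-level names annotated with `Callable` types but never assigned convince mypy that every instance exposes those methods, while no attribute is created at runtime, so `hasattr`-style probes elsewhere in the code behave as before.

```python
from typing import Callable

class Node:
    # Declaration only: this lands in __annotations__ but creates no attribute.
    get_name: Callable[[], str]

class NamedNode(Node):
    def get_name(self) -> str:
        return "buf0"

def describe(node: Node) -> str:
    # mypy accepts the call because of the declaration on the base class.
    return node.get_name()

print(hasattr(Node, "get_name"))  # False: nothing exists at runtime
print(describe(NamedNode()))      # buf0
```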
@ -386,6 +401,21 @@ class Loops(IRNode):
self.get_size(),
).reads
def get_reduction_size(self):
raise NotImplementedError(
f"get_reduction_size() is not implemented by {type(self)}!"
)
def get_reduction_type(self):
raise NotImplementedError(
f"get_reduction_type() is not implemented by {type(self)}!"
)
def constant_to_device(self, device):
raise NotImplementedError(
f"constant_to_device() is not implemented by {type(self)}!"
)
def nop_loader_fn(idx, *, dtype):
if dtype.is_floating_point:
@ -1146,9 +1176,12 @@ class WelfordReduction(Reduction):
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
def const(idx, val):
def const(val):
def inner_fn(idx):
return ops.constant(val, dst_dtype)
return ops.constant(
val,
dtype,
)
return Pointwise.create(
device=device,
@ -1172,7 +1205,7 @@ class WelfordReduction(Reduction):
return Pointwise.create(
device=device,
dtype=dst_dtype,
dtype=dtype,
inner_fn=inner_fn,
ranges=list(ranges),
)
@ -1251,7 +1284,7 @@ class WelfordReduction(Reduction):
return (0, 0, 0)
@classmethod
def create_multilayer(
def create_multilayer( # type: ignore[override]
cls,
device: torch.device,
dtype: torch.dtype,
@ -1377,7 +1410,7 @@ def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=No
return x, x.data.layout
if isinstance(x, ReinterpretView):
# making the base of x contiguous or stride_ordered will not necessarily make
# the ReinterpretedView either, so dont pass along those arguments
# the ReinterpretView either, so don't pass along those arguments
buffer, _ = as_storage_and_layout(
x.data,
freeze=freeze,
@ -1455,7 +1488,7 @@ class BaseView(IRNode):
return self.data.get_storage_numel()
def is_extern(self):
return self.data.is_extern()
return self.data.is_extern() # type: ignore[attr-defined]
def get_reads(self):
with patch.object(FlexibleLayout, "allow_indexing", True):
@ -1465,7 +1498,7 @@ class BaseView(IRNode):
).reads
def unwrap_view(self):
x = self
x: IRNode = self
while isinstance(x, BaseView):
x = x.data
return x
@ -1689,7 +1722,7 @@ class View(GenericView):
def fake_reindex(index):
return tuple([0] * len(old_size))
return cls(x, tuple(new_size), fake_reindex)
return cls(x, list(new_size), fake_reindex)
# TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
elif is_contiguous_storage_and_layout(x):
storage, old_layout = as_contiguous_storage_and_layout(x)
@ -1914,6 +1947,9 @@ class SliceView(View):
class BaseConstant(IRNode):
dtype: torch.dtype
device: torch.device
def get_size(self):
return ()
@ -1981,7 +2017,7 @@ class Layout(IRNode):
device: torch.device,
dtype: torch.dtype,
size: List[Expr],
stride: List[Expr],
stride: Optional[Sequence[Union[Expr, int]]],
offset: Expr = Integer(0),
):
assert stride is None or len(size) == len(
@ -2092,7 +2128,7 @@ class FixedLayout(Layout):
device: torch.device,
dtype: torch.dtype,
size: Union[List[Expr], List[int]],
stride: Optional[Union[List[Expr], List[int]]] = None,
stride: Optional[Sequence[Union[Expr, int]]] = None,
offset: Union[Expr, int] = Integer(0),
):
if stride is None:
@ -2213,7 +2249,7 @@ class FlexibleLayout(Layout):
class AliasedLayout(Layout):
"""Shares the same storage as another tensor"""
def __init__(self, view: "ReinterpretView"):
def __init__(self, view: IRNode):
layout = view.get_layout()
super().__init__(
layout.device,
@ -2260,13 +2296,13 @@ class MutationLayout(Layout):
target.get_device(),
target.get_dtype(),
target.get_size(),
None, # type: ignore[arg-type]
None,
)
self.target = target
name = self.get_buffer().get_name()
V.graph.mark_buffer_mutated(name)
@Layout.stride.getter
@Layout.stride.getter # type: ignore[attr-defined]
def stride(self):
return self.real_layout().stride
@ -2620,20 +2656,21 @@ class ComputedBuffer(Buffer):
def make_loader(self):
# Inline constants and index_expressions
can_inline = (
if (
hasattr(self.data, "make_loader")
and self.name not in V.graph.mutated_buffers
and self.num_reads() == 0
)
if can_inline:
):
# can be inlined
return self.data.make_loader()
return super().make_loader()
def get_store_function(self):
indexer = self.layout.as_fixed().make_indexer()
if self.data.get_reduction_type():
if isinstance(self.data, Reduction):
return partial(self.data.store_reduction, self.name, indexer)
else:
assert isinstance(self.data, Pointwise)
return partial(self.data.store_output, self.name, indexer)
def get_fill_order(self):
@ -3020,14 +3057,14 @@ class ConcatKernel(NopKernel):
)
kernel = StorageBox(concat_kernel)
for i in range(len(inputs)):
kernel.data.inputs.append(
concat_kernel.inputs.append(
cls.realize_into(
inputs[i],
SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
)
)
kernel.data.name = V.graph.register_buffer(kernel.data)
kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)
concat_kernel.name = V.graph.register_buffer(concat_kernel)
concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)
return kernel
@ -3047,6 +3084,7 @@ class ConcatKernel(NopKernel):
if isinstance(src, StorageBox):
src.realize()
# ExternKernelAlloc has specific requirements for output layout, should create a copy
assert hasattr(src.data, "layout")
if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
src.data, ExternKernelAlloc
):
@ -3073,6 +3111,9 @@ class ExternKernel(InputsKernel):
constant_args: Tuple[Any, ...] = ()
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None
ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
default_factory=list
)
def decide_layout(self):
if isinstance(self.layout, FlexibleLayout):
@ -3715,7 +3756,7 @@ class DynamicScalar(ExternKernel):
return False
def __init__(self, sym, data):
super().__init__(None, NoneLayout(), [data])
super().__init__(None, NoneLayout(), [data]) # type: ignore[arg-type]
self.sym = sym
def get_unbacked_symbol_defs(self):
@ -4117,7 +4158,7 @@ def _prepare_convolution_fusion_create(
dilation: List[int],
groups: int,
transposed: bool = False,
output_padding: List[int] = None,
output_padding: Optional[List[int]] = None,
):
"""
This function is a helper function to prepare inputs, layout and constant args
@ -4296,7 +4337,7 @@ def _prepare_linear_fusion_create(
convert_shape_to_inductor(output_size),
convert_shape_to_inductor(output_stride),
)
constant_args = []
constant_args: List[Any] = []
if bias is not None:
inputs.append(bias)
@ -4908,7 +4949,6 @@ class MkldnnRnnLayer(ExternKernelAlloc):
assert len(output_shape) == 3, "Expect output_shape to be 3D"
return make_contiguous_strides_for(output_shape)
indices = []
output_sizes = [output_shape, hy_shape, cy_shape]
output_strides = [
get_strides_of_lstm_output(output_shape, batch_first),
@ -4924,7 +4964,7 @@ class MkldnnRnnLayer(ExternKernelAlloc):
output_stride,
),
packed,
indices + [(tuple, i)],
[(tuple, i)],
)
for i, (output_size, output_stride) in enumerate(
zip(output_sizes, output_strides)
@ -5362,7 +5402,7 @@ class MutableBox(IRNode):
@property
def layout(self):
return self.data.layout
return self.data.layout # type: ignore[attr-defined]
def get_layout(self):
return self.layout
@ -5653,7 +5693,7 @@ class LoopBodyBlock:
{},
)
class CaptureIndexing(V.WrapperHandler):
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
self.name = "CaptureIndexing"
def load(self, name: str, index: sympy.Expr):
@ -5699,6 +5739,8 @@ class LoopBodyBlock:
Recursively capture the masked out body in another LoopBodyBlock
"""
subblock: LoopBodyBlock
def shim(mask, other):
return V.ops.masked(mask, subblock, other)
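
The `subblock: LoopBodyBlock` line added above is an annotation-only declaration: it tells mypy the variable's type so the `shim` closure, which is defined before the actual assignment, typechecks. A self-contained sketch of the same pattern with made-up names:

```python
from typing import Callable, List

def build_summer() -> Callable[[], int]:
    items: List[int]  # declared here, assigned a few lines below

    def total() -> int:
        # mypy already knows `items` is List[int], even though the
        # assignment happens after this closure is defined.
        return sum(items)

    items = [1, 2, 3]
    return total

print(build_summer()())  # -> 6
```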
@ -5716,12 +5758,13 @@ class LoopBodyBlock:
Introduce a call_module to update the indexing.
"""
var = self.body.add_indirect(size)
def set_indirect(new_var):
self.body.replace_indirect(
var, V.ops.indirect_indexing(new_var, size, check)
)
var = self.body.add_indirect(size)
tracer.create_proxy(
"call_module",
self.body.add_submodule(set_indirect, f"set_{var}"),
@ -5741,7 +5784,9 @@ class LoopBodyBlock:
from .index_propagation import IndexPropagation
from .sizevars import SimplifyIndexing
handler = SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
handler: Any = SimplifyIndexing(
CaptureIndexing(proxy_ops), self.body.var_ranges
)
if config.constant_and_index_propagation:
handler = IndexPropagation(handler)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import functools
import logging
from typing import cast, List, Tuple, TypedDict
from typing import cast, List, Optional, Sequence, Tuple, TypedDict
import torch
from .. import config, ir
@ -228,8 +228,8 @@ class ConvLayoutParams(TypedDict):
def conv_layout(
x: TensorBox,
weight: TensorBox,
bias: TensorBox,
stride: tuple[int, ...],
bias: Optional[TensorBox],
stride: Sequence[int],
padding: tuple[int, ...],
dilation: tuple[int, ...],
transposed: bool,

View File

@ -261,6 +261,8 @@ class OpsWrapper:
ops = OpsWrapper()
_MockHandler = MockHandler
class _V:
MockHandler = MockHandler
@ -281,7 +283,7 @@ class _V:
get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
@property
def ops(self) -> MockHandler: # type: ignore[valid-type]
def ops(self) -> _MockHandler:
"""The operator handler specific to the current codegen task"""
return _ops._get_handler()